Compare commits

..

No commits in common. "475cd033d9480dba04a5abbf5edda08f80565790" and "d40f5b1f24e184c363e0ebaf7d7365b3e6683cf7" have entirely different histories.

3 changed files with 18 additions and 26 deletions

View File

@ -1,23 +1,21 @@
import gradio as gr import gradio as gr
from typing import List from typing import List, Dict
from sqlmodel import Session, select from sqlmodel import Session, select
from db import get_engine from db import get_engine
from schema import APIProvider from schema import APIProvider
import os import os
# 获取数据库引擎
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir")) engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
try: # 添加异常处理 with Session(engine) as session:
with Session(engine) as session: providers = session.exec(select(APIProvider)).all()
providers = session.exec(select(APIProvider)).all() return [
return [ [p.id, p.model_id, p.base_url, p.api_key or ""]
[p.id, p.model_id, p.base_url, p.api_key or ""] for p in providers
for p in providers ]
]
except Exception as e:
raise gr.Error(f"获取数据失败: {str(e)}")
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
@ -32,39 +30,34 @@ def setting_page():
session.refresh(new_provider) session.refresh(new_provider)
return get_providers() return get_providers()
except Exception as e: except Exception as e:
# 抛出错误提示
raise gr.Error(f"添加失败: {str(e)}") raise gr.Error(f"添加失败: {str(e)}")
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理") gr.Markdown("## API Provider 管理")
with gr.Row(): with gr.Row():
with gr.Column(scale=1): with gr.Column(scale=1):
model_id_input = gr.Textbox(label="Model ID") model_id_input = gr.Textbox(label="Model ID")
base_url_input = gr.Textbox(label="Base URL") base_url_input = gr.Textbox(label="Base URL")
api_key_input = gr.Textbox(label="API Key") api_key_input = gr.Textbox(label="API Key")
add_button = gr.Button("添加新API") add_button = gr.Button("添加新API")
# API Provider列表
with gr.Column(scale=3): with gr.Column(scale=3):
provider_table = gr.DataFrame( provider_table = gr.DataFrame(
headers=["id", "model id", "base URL", "API Key"], headers=["id" , "model id", "base URL", "API Key"],
datatype=["number", "str", "str", "str"], datatype=["number","str", "str", "str"],
interactive=True, interactive=True,
value=get_providers(), value=get_providers(),
wrap=True, wrap=True,
col_count=(4, "auto") col_count=(4, "auto")
) )
with gr.Row(): with gr.Row():
edit_button = gr.Button("编辑选中行") edit_button = gr.Button("编辑选中行")
delete_button = gr.Button("删除选中行") delete_button = gr.Button("删除选中行")
refresh_button = gr.Button("刷新数据", variant="secondary")
# 绑定刷新按钮事件
refresh_button.click(
fn=get_providers,
outputs=[provider_table],
queue=False # 立即刷新不需要排队
)
add_button.click( add_button.click(
fn=add_provider, fn=add_provider,
@ -72,4 +65,4 @@ def setting_page():
outputs=[provider_table] outputs=[provider_table]
) )
return demo return demo

View File

@ -1,5 +1,4 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=3.0.0 gradio>=3.0.0
langchain>=0.3

View File

@ -4,8 +4,8 @@ from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True): class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空") base_url: str = Field(..., description="API的基础URL")
model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空") model_id: str = Field(..., description="API使用的模型ID")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥") api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field( created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),