Compare commits

...

3 Commits

Author SHA1 Message Date
carry
475cd033d9 build: 添加 langchain 依赖
- 在 requirements.txt 中添加 langchain>=0.3 版本的依赖
- 保持其他依赖版本不变
2025-04-08 11:53:58 +08:00
carry
3970a67df3 refactor(dataset_generation): 增加 APIProvider 模型字段的最小长度验证
- 为 base_url 和 model_id 字段添加 min_length=1 的验证
- 更新字段描述,明确这些字段不能为空
2025-04-07 23:37:14 +08:00
carry
286db405ca feat(frontend): 优化设置页面并添加数据刷新功能
- 为 get_providers 函数添加异常处理,提高数据获取的稳定性
- 在设置页面添加刷新按钮,用户可手动触发数据刷新
- 优化页面布局,调整组件间距和对齐方式
2025-04-07 23:17:43 +08:00
3 changed files with 26 additions and 18 deletions

View File

@ -1,21 +1,23 @@
import gradio as gr import gradio as gr
from typing import List, Dict from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from db import get_engine from db import get_engine
from schema import APIProvider from schema import APIProvider
import os import os
# 获取数据库引擎
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir")) engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
with Session(engine) as session: try: # 添加异常处理
providers = session.exec(select(APIProvider)).all() with Session(engine) as session:
return [ providers = session.exec(select(APIProvider)).all()
[p.id, p.model_id, p.base_url, p.api_key or ""] return [
for p in providers [p.id, p.model_id, p.base_url, p.api_key or ""]
] for p in providers
]
except Exception as e:
raise gr.Error(f"获取数据失败: {str(e)}")
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
@ -30,34 +32,39 @@ def setting_page():
session.refresh(new_provider) session.refresh(new_provider)
return get_providers() return get_providers()
except Exception as e: except Exception as e:
# 抛出错误提示
raise gr.Error(f"添加失败: {str(e)}") raise gr.Error(f"添加失败: {str(e)}")
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理") gr.Markdown("## API Provider 管理")
with gr.Row(): with gr.Row():
with gr.Column(scale=1): with gr.Column(scale=1):
model_id_input = gr.Textbox(label="Model ID") model_id_input = gr.Textbox(label="Model ID")
base_url_input = gr.Textbox(label="Base URL") base_url_input = gr.Textbox(label="Base URL")
api_key_input = gr.Textbox(label="API Key") api_key_input = gr.Textbox(label="API Key")
add_button = gr.Button("添加新API") add_button = gr.Button("添加新API")
# API Provider列表
with gr.Column(scale=3): with gr.Column(scale=3):
provider_table = gr.DataFrame( provider_table = gr.DataFrame(
headers=["id" , "model id", "base URL", "API Key"], headers=["id", "model id", "base URL", "API Key"],
datatype=["number","str", "str", "str"], datatype=["number", "str", "str", "str"],
interactive=True, interactive=True,
value=get_providers(), value=get_providers(),
wrap=True, wrap=True,
col_count=(4, "auto") col_count=(4, "auto")
) )
with gr.Row(): with gr.Row():
edit_button = gr.Button("编辑选中行") edit_button = gr.Button("编辑选中行")
delete_button = gr.Button("删除选中行") delete_button = gr.Button("删除选中行")
refresh_button = gr.Button("刷新数据", variant="secondary")
# 绑定刷新按钮事件
refresh_button.click(
fn=get_providers,
outputs=[provider_table],
queue=False # 立即刷新不需要排队
)
add_button.click( add_button.click(
fn=add_provider, fn=add_provider,
@ -65,4 +72,4 @@ def setting_page():
outputs=[provider_table] outputs=[provider_table]
) )
return demo return demo

View File

@ -1,4 +1,5 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=3.0.0 gradio>=3.0.0
langchain>=0.3

View File

@ -4,8 +4,8 @@ from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True): class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
base_url: str = Field(..., description="API的基础URL") base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
model_id: str = Field(..., description="API使用的模型ID") model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥") api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field( created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),