Compare commits
3 Commits
d40f5b1f24
...
475cd033d9
Author | SHA1 | Date | |
---|---|---|---|
![]() |
475cd033d9 | ||
![]() |
3970a67df3 | ||
![]() |
286db405ca |
@ -1,21 +1,23 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from typing import List, Dict
|
from typing import List
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from db import get_engine
|
from db import get_engine
|
||||||
from schema import APIProvider
|
from schema import APIProvider
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 获取数据库引擎
|
|
||||||
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
|
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
|
||||||
|
|
||||||
def setting_page():
|
def setting_page():
|
||||||
def get_providers() -> List[List[str]]:
|
def get_providers() -> List[List[str]]:
|
||||||
with Session(engine) as session:
|
try: # 添加异常处理
|
||||||
providers = session.exec(select(APIProvider)).all()
|
with Session(engine) as session:
|
||||||
return [
|
providers = session.exec(select(APIProvider)).all()
|
||||||
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
return [
|
||||||
for p in providers
|
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
||||||
]
|
for p in providers
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
raise gr.Error(f"获取数据失败: {str(e)}")
|
||||||
|
|
||||||
def add_provider(model_id, base_url, api_key):
|
def add_provider(model_id, base_url, api_key):
|
||||||
try:
|
try:
|
||||||
@ -30,34 +32,39 @@ def setting_page():
|
|||||||
session.refresh(new_provider)
|
session.refresh(new_provider)
|
||||||
return get_providers()
|
return get_providers()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 抛出错误提示
|
|
||||||
raise gr.Error(f"添加失败: {str(e)}")
|
raise gr.Error(f"添加失败: {str(e)}")
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
gr.Markdown("## API Provider 管理")
|
gr.Markdown("## API Provider 管理")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
|
|
||||||
model_id_input = gr.Textbox(label="Model ID")
|
model_id_input = gr.Textbox(label="Model ID")
|
||||||
base_url_input = gr.Textbox(label="Base URL")
|
base_url_input = gr.Textbox(label="Base URL")
|
||||||
api_key_input = gr.Textbox(label="API Key")
|
api_key_input = gr.Textbox(label="API Key")
|
||||||
add_button = gr.Button("添加新API")
|
add_button = gr.Button("添加新API")
|
||||||
|
|
||||||
# API Provider列表
|
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
provider_table = gr.DataFrame(
|
provider_table = gr.DataFrame(
|
||||||
headers=["id" , "model id", "base URL", "API Key"],
|
headers=["id", "model id", "base URL", "API Key"],
|
||||||
datatype=["number","str", "str", "str"],
|
datatype=["number", "str", "str", "str"],
|
||||||
interactive=True,
|
interactive=True,
|
||||||
value=get_providers(),
|
value=get_providers(),
|
||||||
wrap=True,
|
wrap=True,
|
||||||
col_count=(4, "auto")
|
col_count=(4, "auto")
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
edit_button = gr.Button("编辑选中行")
|
edit_button = gr.Button("编辑选中行")
|
||||||
delete_button = gr.Button("删除选中行")
|
delete_button = gr.Button("删除选中行")
|
||||||
|
refresh_button = gr.Button("刷新数据", variant="secondary")
|
||||||
|
|
||||||
|
# 绑定刷新按钮事件
|
||||||
|
refresh_button.click(
|
||||||
|
fn=get_providers,
|
||||||
|
outputs=[provider_table],
|
||||||
|
queue=False # 立即刷新不需要排队
|
||||||
|
)
|
||||||
|
|
||||||
add_button.click(
|
add_button.click(
|
||||||
fn=add_provider,
|
fn=add_provider,
|
||||||
@ -65,4 +72,4 @@ def setting_page():
|
|||||||
outputs=[provider_table]
|
outputs=[provider_table]
|
||||||
)
|
)
|
||||||
|
|
||||||
return demo
|
return demo
|
@ -1,4 +1,5 @@
|
|||||||
openai>=1.0.0
|
openai>=1.0.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
pydantic>=2.0.0
|
pydantic>=2.0.0
|
||||||
gradio>=3.0.0
|
gradio>=3.0.0
|
||||||
|
langchain>=0.3
|
@ -4,8 +4,8 @@ from sqlmodel import SQLModel, Relationship, Field
|
|||||||
|
|
||||||
class APIProvider(SQLModel, table=True):
|
class APIProvider(SQLModel, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
base_url: str = Field(..., description="API的基础URL")
|
base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
|
||||||
model_id: str = Field(..., description="API使用的模型ID")
|
model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
|
||||||
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
|
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
|
||||||
created_at: datetime = Field(
|
created_at: datetime = Field(
|
||||||
default_factory=lambda: datetime.now(timezone.utc),
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user