
- 在 prompt_manage_page 和 setting_page 中更新了 select_record 函数 - 使用 DataFrame.iloc 方法获取选中行的数据,并转换为列表 - 添加了将第一列数据转换为整数的逻辑 - 更新了表格选择事件的参数,增加了输入和输出参数 - 将 gradio 版本升级到 5.25.0
131 lines
4.9 KiB
Python
131 lines
4.9 KiB
Python
import gradio as gr
|
|
from typing import List
|
|
from sqlmodel import Session, select
|
|
from schema import APIProvider
|
|
from global_var import get_sql_engine
|
|
|
|
def setting_page():
    """Build the Gradio "API Provider" management page.

    The page shows every ``APIProvider`` row in an editable table and wires
    buttons for adding, refreshing, editing and deleting providers.

    Returns:
        The ``gr.Blocks`` container that holds the page.

    Raises:
        gr.Error: If the initial provider query fails while the table is
            being populated.
    """
    # Row most recently clicked in the table, stored as a plain list
    # ([id, model id, base URL, API key]). It lives in this function's scope
    # and is rebound by the nested handlers through ``nonlocal``; declaring it
    # ``global`` there would target a module-level name that is never
    # initialised, so "edit"/"delete" before any selection raised NameError.
    # NOTE(review): the value is shared by every session served by this page;
    # consider gr.State if per-user selection is required.
    selected_row = None

    def get_providers() -> List[List[str]]:
        """Return every stored provider as one table row each.

        A row is ``[id, model_id, base_url, api_key]``; a missing API key is
        rendered as an empty string.

        Raises:
            gr.Error: If the database query fails.
        """
        try:
            with Session(get_sql_engine()) as session:
                providers = session.exec(select(APIProvider)).all()
                return [
                    [p.id, p.model_id, p.base_url, p.api_key or ""]
                    for p in providers
                ]
        except Exception as e:
            raise gr.Error(f"获取数据失败: {str(e)}") from e

    def add_provider(model_id, base_url, api_key):
        """Insert a new provider and return the refreshed table.

        Empty inputs are stored as ``None``.

        Returns:
            A 4-tuple: the refreshed table rows followed by three empty
            strings that clear the input boxes.

        Raises:
            gr.Error: If the insert (or the following refresh) fails.
        """
        try:
            with Session(get_sql_engine()) as session:
                session.add(
                    APIProvider(
                        model_id=model_id or None,
                        base_url=base_url or None,
                        api_key=api_key or None,
                    )
                )
                session.commit()
        except Exception as e:
            raise gr.Error(f"添加失败: {str(e)}") from e
        # The empty strings reset the three input boxes after a successful add.
        return get_providers(), "", "", ""

    def edit_provider():
        """Write the remembered row back to the database.

        The values come from the row as it was when it was last clicked
        (see ``select_record``); empty cells are stored as ``None``.

        Returns:
            The refreshed table rows.

        Raises:
            gr.Error: If no row has been selected, the record no longer
                exists, or the update fails.
        """
        if not selected_row:
            raise gr.Error("请先选择要编辑的行")
        try:
            with Session(get_sql_engine()) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                provider.model_id = selected_row[1] or None
                provider.base_url = selected_row[2] or None
                provider.api_key = selected_row[3] or None
                session.add(provider)
                session.commit()
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as e:
            raise gr.Error(f"编辑失败: {str(e)}") from e
        return get_providers()

    def delete_provider():
        """Delete the remembered row from the database.

        Returns:
            The refreshed table rows.

        Raises:
            gr.Error: If no row has been selected, the record no longer
                exists, or the delete fails.
        """
        nonlocal selected_row
        if not selected_row:
            raise gr.Error("请先选择要删除的行")
        try:
            with Session(get_sql_engine()) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                session.delete(provider)
                session.commit()
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as e:
            raise gr.Error(f"删除失败: {str(e)}") from e
        # The remembered row no longer exists; forget it so a second click
        # asks for a new selection instead of looking up a deleted id.
        selected_row = None
        return get_providers()

    def select_record(dataFrame, evt: gr.SelectData):
        """Remember the table row that was just clicked.

        Args:
            dataFrame: Current table contents; indexed with ``.iloc``.
            evt: Selection event; ``evt.index[0]`` is used as the row position.
        """
        nonlocal selected_row
        row = dataFrame.iloc[evt.index[0]].tolist()
        # The id is converted to int before it is used as the primary-key lookup.
        row[0] = int(row[0])
        # The row is intentionally not printed: it contains the API key.
        selected_row = row

    with gr.Blocks() as demo:
        gr.Markdown("## API Provider 管理")

        with gr.Row():
            with gr.Column(scale=1):
                model_id_input = gr.Textbox(label="Model ID")
                base_url_input = gr.Textbox(label="Base URL")
                api_key_input = gr.Textbox(label="API Key")
                add_button = gr.Button("添加新API", variant="primary")

            with gr.Column(scale=3):
                # NOTE(review): Gradio documents "fixed"/"dynamic" as the
                # second element of col_count; confirm "auto" is intended.
                provider_table = gr.DataFrame(
                    headers=["id", "model id", "base URL", "API Key"],
                    datatype=["number", "str", "str", "str"],
                    interactive=True,
                    value=get_providers(),
                    wrap=True,
                    col_count=(4, "auto"),
                )

                with gr.Row():
                    refresh_button = gr.Button("刷新数据", variant="secondary")
                    edit_button = gr.Button("编辑选中行", variant="primary")
                    delete_button = gr.Button("删除选中行", variant="stop")

        refresh_button.click(
            fn=get_providers,
            outputs=[provider_table],
            queue=False,  # refresh immediately, no need to queue
        )

        add_button.click(
            fn=add_provider,
            inputs=[model_id_input, base_url_input, api_key_input],
            # The extra outputs clear the input boxes after adding.
            outputs=[provider_table, model_id_input, base_url_input, api_key_input],
        )

        provider_table.select(
            fn=select_record,
            inputs=[provider_table],
            outputs=[],
            show_progress="hidden",
        )

        edit_button.click(
            fn=edit_provider,
            inputs=[],
            outputs=[provider_table],
        )

        delete_button.click(
            fn=delete_provider,
            inputs=[],
            outputs=[provider_table],
        )

    return demo