From 5a386d64011dc8244e54233bcf4359b676356aef Mon Sep 17 00:00:00 2001 From: carry Date: Fri, 18 Apr 2025 15:23:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(dataset=5Fgenerate=5Fpage):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=20API=20=E9=80=89=E6=8B=A9=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在数据集生成页面添加 API 选择下拉框 - 实现 API 选择变更时的处理逻辑 - 更新数据集生成函数,增加 API 选择参数 - 优化页面布局和代码结构 --- frontend/dataset_generate_page.py | 38 ++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index 517450c..ca86fc4 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -2,9 +2,11 @@ import gradio as gr import sys from pathlib import Path from langchain.prompts import PromptTemplate +from sqlmodel import Session, select sys.path.append(str(Path(__file__).resolve().parent.parent)) -from global_var import get_docs, get_prompt_store +from schema import APIProvider +from global_var import get_docs, get_prompt_store, get_sql_engine def dataset_generate_page(): with gr.Blocks() as demo: @@ -19,7 +21,7 @@ def dataset_generate_page(): label="选择文档", interactive=True ) - doc_state = gr.State(value=initial_doc) + doc_choice = gr.State(value=initial_doc) with gr.Column(): prompts = get_prompt_store().all() @@ -42,7 +44,22 @@ def dataset_generate_page(): label="选择模板", interactive=True ) - prompt_state = gr.State(value=initial_prompt) + prompt_choice = gr.State(value=initial_prompt) + + with gr.Column(): + # 从数据库获取API Provider列表 + with Session(get_sql_engine()) as session: + providers = session.exec(select(APIProvider)).all() + api_list = [f"{p.id} {p.model_id}" for p in providers] + initial_api = api_list[0] if api_list else None + + api_dropdown = gr.Dropdown( + choices=api_list, + value=initial_api, + label="选择API", + interactive=True + ) + api_choice = gr.State(value=initial_api) generate_button = gr.Button("生成数据集",variant="primary") @@ -58,8 +75,11 @@ def dataset_generate_page(): def on_doc_change(selected_doc): - # print(f"文档选择已更改为: {selected_doc}") return selected_doc + + def on_api_change(selected_api): + return selected_api + def on_prompt_change(selected_prompt): if not selected_prompt: return None, [] @@ -72,7 +92,8 @@ def dataset_generate_page(): dataframe_value = [] if input_variables is None else input_variables return selected_prompt, dataframe_value - def on_generate_click(doc_state, prompt_state, variables_dataframe, progress=gr.Progress()): + + def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()): variables_dict = {} # 正确遍历DataFrame的行数据 for _, row in variables_dataframe.iterrows(): @@ -96,12 +117,13 @@ def dataset_generate_page(): return "all done" - doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_state) - prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_state, variables_dataframe]) + doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice) + prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe]) + api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice) generate_button.click( on_generate_click, - inputs=[doc_state, prompt_state, variables_dataframe], + inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe], outputs=output_text )