"""Gradio page for generating a dataset from a document and a prompt template."""

import sys
import time
from pathlib import Path

import gradio as gr
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select

# Make the parent directory importable so the project-local modules below resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from schema import APIProvider
from tools import call_openai_api  # noqa: F401  (kept: not used in this file yet)
from global_var import get_docs, get_prompt_store, get_sql_engine

# Template placeholder that is filled per document slice at generation time and
# is therefore never shown to the user as an editable variable.
DOCUMENT_SLICE_VAR = "document_slice"


def _parse_leading_id(choice):
    """Return the integer id that prefixes an ``"<id> <label>"`` dropdown choice.

    Returns ``None`` when *choice* is empty or does not start with an integer.
    """
    if not choice:
        return None
    try:
        return int(str(choice).split(" ")[0])
    except ValueError:
        return None


def _template_variable_rows(selected_prompt):
    """Build the ``[name, ""]`` rows shown in the variables table for a prompt.

    The prompt is looked up in the prompt store by the id prefix of
    *selected_prompt*; every input variable of its template except
    ``document_slice`` becomes one row with an empty value. Returns an empty
    list when nothing is selected or the prompt cannot be found.
    """
    prompt_id = _parse_leading_id(selected_prompt)
    if prompt_id is None:
        return []
    prompt_data = get_prompt_store().get(doc_id=prompt_id)
    if not prompt_data:
        return []
    template = PromptTemplate.from_template(prompt_data["content"])
    # Filter instead of list.remove(): remove() raises ValueError when the
    # template has no document_slice placeholder and mutates the template's list.
    return [
        [name, ""]
        for name in template.input_variables
        if name != DOCUMENT_SLICE_VAR
    ]


def _cell_text(value):
    """Return a dataframe cell as stripped text, mapping None/NaN to ``""``."""
    if value is None:
        return ""
    if isinstance(value, float) and value != value:  # NaN is the only value != itself
        return ""
    return str(value).strip()


def dataset_generate_page():
    """Build and return the "dataset generation" Gradio Blocks page.

    The left column holds the document / prompt-template / API selectors, the
    round count, the generate button and the result box; the right column holds
    an editable table of the selected template's variables.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集生成")
        with gr.Row():
            with gr.Column(scale=1):
                docs_list = [str(doc.name) for doc in get_docs()]
                initial_doc = docs_list[0] if docs_list else None
                doc_dropdown = gr.Dropdown(
                    choices=docs_list,
                    value=initial_doc,
                    label="选择文档",
                    interactive=True,
                )
                doc_choice = gr.State(value=initial_doc)

                prompts = get_prompt_store().all()
                prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                initial_prompt = prompt_list[0] if prompt_list else None

                # Initial contents of the variables table.
                initial_dataframe_value = _template_variable_rows(initial_prompt)

                prompt_dropdown = gr.Dropdown(
                    choices=prompt_list,
                    value=initial_prompt,
                    label="选择模板",
                    interactive=True,
                )
                prompt_choice = gr.State(value=initial_prompt)

                # Load the list of API providers from the database.
                with Session(get_sql_engine()) as session:
                    providers = session.exec(select(APIProvider)).all()
                    api_list = [f"{p.id} {p.model_id}" for p in providers]
                initial_api = api_list[0] if api_list else None
                api_dropdown = gr.Dropdown(
                    choices=api_list,
                    value=initial_api,
                    label="选择API",
                    interactive=True,
                )
                api_choice = gr.State(value=initial_api)

                rounds_input = gr.Number(
                    value=1,
                    label="生成轮次",
                    minimum=1,
                    maximum=100,
                    step=1,
                    interactive=True,
                )
                generate_button = gr.Button("生成数据集", variant="primary")
                output_text = gr.Textbox(label="生成结果", interactive=False)

            with gr.Column(scale=2):
                variables_dataframe = gr.Dataframe(
                    headers=["变量名", "变量值"],
                    datatype=["str", "str"],
                    interactive=True,
                    label="变量列表",
                    value=initial_dataframe_value,  # pre-filled from the initial prompt
                )

        def on_doc_change(selected_doc):
            """Mirror the document dropdown value into its State."""
            return selected_doc

        def on_api_change(selected_api):
            """Mirror the API dropdown value into its State."""
            return selected_api

        def on_prompt_change(selected_prompt):
            """Update the prompt State and rebuild the variables table."""
            if not selected_prompt:
                return None, []
            return selected_prompt, _template_variable_rows(selected_prompt)

        def on_generate_click(
            doc_state,
            prompt_state,
            api_state,
            variables_df,
            rounds,
            progress=gr.Progress(),
        ):
            """Validate the selections, bind template variables and run the rounds.

            The per-round work is currently a placeholder (a short sleep) that
            only drives the progress bar. Returns the text shown in the result box.

            Raises:
                gr.Error: when the document, template or API selection is
                    missing or no longer exists.
            """
            # Compare on str(name) because the dropdown choices were built that way.
            selected_doc = next(
                (d for d in get_docs() if str(d.name) == doc_state), None
            )
            if selected_doc is None:
                raise gr.Error("请先选择有效的文档")
            doc = selected_doc.markdown_files

            prompt_id = _parse_leading_id(prompt_state)
            # NOTE(review): this matches on the stored "id" field while the
            # variables table uses get(doc_id=...); presumably both ids agree.
            prompt_record = next(
                (p for p in get_prompt_store().all() if p["id"] == prompt_id),
                None,
            )
            if prompt_id is None or prompt_record is None:
                raise gr.Error("请先选择有效的模板")
            prompt = PromptTemplate.from_template(prompt_record["content"])

            api_id = _parse_leading_id(api_state)
            if api_id is None:
                raise gr.Error("请先选择有效的API")
            with Session(get_sql_engine()) as session:
                api_provider = session.exec(
                    select(APIProvider).where(APIProvider.id == api_id)
                ).first()
            if api_provider is None:
                raise gr.Error("请先选择有效的API")

            # Cells of an editable table may be None/NaN/numbers, so normalise
            # them to text before stripping; rows without a name are skipped.
            variables_dict = {}
            for _, row in variables_df.iterrows():
                var_name = _cell_text(row["变量名"])
                if var_name:
                    variables_dict[var_name] = _cell_text(row["变量值"])

            # Bind every variable except document_slice.
            prompt = prompt.partial(**variables_dict)

            # Debug output of the resolved inputs.
            print(doc)
            print(prompt.format(document_slice="test"))
            print(variables_dict)

            # gr.Number may deliver a float (or None when cleared); range() needs an int.
            try:
                total_steps = int(rounds)
            except (TypeError, ValueError):
                total_steps = 1
            total_steps = max(total_steps, 1)

            for step in range(1, total_steps + 1):
                # Simulated per-step workload.
                time.sleep(0.5)
                progress(
                    step / total_steps,
                    desc=f"处理步骤 {step}/{total_steps}",
                )

            return "all done"

        doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
        prompt_dropdown.change(
            on_prompt_change,
            inputs=prompt_dropdown,
            outputs=[prompt_choice, variables_dataframe],
        )
        api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)

        generate_button.click(
            on_generate_click,
            inputs=[
                doc_choice,
                prompt_choice,
                api_choice,
                variables_dataframe,
                rounds_input,
            ],
            outputs=output_text,
        )

    return demo


if __name__ == "__main__":
    from global_var import init_global_var

    init_global_var("workdir")
    demo = dataset_generate_page()
    demo.launch()