import gradio as gr
import sys
import json
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select

from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset

# Make the project root importable before pulling in the remaining local modules.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir


def dataset_generate_page():
    """Build the "dataset generation" Gradio page.

    The page lets the user pick a source document, a prompt template and an
    API provider, fill in the template's free variables, and then generate a
    Q/A dataset by calling the LLM once per document slice.

    Returns:
        gr.Blocks: the assembled Gradio Blocks UI for this page.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集生成")
        with gr.Row():
            with gr.Column(scale=1):
                docs_list = [str(doc.name) for doc in get_docs()]
                initial_doc = docs_list[0] if docs_list else None

                prompts = get_prompt_store().all()
                prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                initial_prompt = prompt_list[0] if prompt_list else None

                # Pre-populate the variables table from the initially selected
                # prompt template (if any).
                initial_dataframe_value = []
                if initial_prompt:
                    selected_prompt_id = int(initial_prompt.split(" ")[0])
                    prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
                    prompt_template = PromptTemplate.from_template(prompt_data["content"])
                    input_variables = prompt_template.input_variables
                    # "document_slice" is filled automatically per slice, so the
                    # user must not edit it.  Guard the removal: a template
                    # without that variable would otherwise raise ValueError.
                    if "document_slice" in input_variables:
                        input_variables.remove("document_slice")
                    initial_dataframe_value = [[var, ""] for var in input_variables]

                # Fetch the API provider list from the SQL database.
                with Session(get_sql_engine()) as session:
                    providers = session.exec(select(APIProvider)).all()
                    api_list = [f"{p.id} {p.model_id}" for p in providers]
                initial_api = api_list[0] if api_list else None

                api_dropdown = gr.Dropdown(
                    choices=api_list,
                    value=initial_api,
                    label="选择API",
                    interactive=True
                )
                doc_dropdown = gr.Dropdown(
                    choices=docs_list,
                    value=initial_doc,
                    label="选择文档",
                    interactive=True
                )
                prompt_dropdown = gr.Dropdown(
                    choices=prompt_list,
                    value=initial_prompt,
                    label="选择模板",
                    interactive=True
                )
                rounds_input = gr.Number(
                    value=1,
                    label="生成轮次",
                    minimum=1,
                    maximum=100,
                    step=1,
                    interactive=True
                )
                # Concurrency is not implemented yet; the control stays hidden
                # so the click handler's signature remains stable.
                concurrency_input = gr.Number(
                    value=1,
                    label="并发数",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    visible=False
                )
                dataset_name_input = gr.Textbox(
                    label="数据集名称",
                    placeholder="输入数据集保存名称",
                    interactive=True
                )
                prompt_choice = gr.State(value=initial_prompt)
                generate_button = gr.Button("生成数据集", variant="primary")
                doc_choice = gr.State(value=initial_doc)
                output_text = gr.Textbox(label="生成结果", interactive=False)
                api_choice = gr.State(value=initial_api)
            with gr.Column(scale=2):
                variables_dataframe = gr.Dataframe(
                    headers=["变量名", "变量值"],
                    datatype=["str", "str"],
                    interactive=True,
                    label="变量列表",
                    value=initial_dataframe_value  # seed with the initial rows
                )

        def on_doc_change(selected_doc):
            # Mirror the dropdown selection into its State holder.
            return selected_doc

        def on_api_change(selected_api):
            # Mirror the dropdown selection into its State holder.
            return selected_api

        def on_prompt_change(selected_prompt):
            """Refresh the variables table when the prompt template changes.

            Returns the (unchanged) selection for the State holder plus the
            new rows for the variables Dataframe.
            """
            if not selected_prompt:
                return None, []
            selected_prompt_id = int(selected_prompt.split(" ")[0])
            prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
            prompt_template = PromptTemplate.from_template(prompt_data["content"])
            input_variables = prompt_template.input_variables
            # Same guard as in the initial population above.
            if "document_slice" in input_variables:
                input_variables.remove("document_slice")
            dataframe_value = [[var, ""] for var in input_variables]
            return selected_prompt, dataframe_value

        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe,
                              rounds, concurrency, dataset_name, progress=gr.Progress()):
            """Generate a dataset from the selected document/template/provider.

            For every markdown slice of the chosen document the prompt is
            formatted and sent to the LLM `rounds` times; each JSON reply is
            parsed into a Q/A pair.  The result is stored in TinyDB and saved
            to the working directory.  `concurrency` is currently unused.

            Raises:
                gr.Error: if a dataset with `dataset_name` already exists.
            """
            dataset_db = get_datasets()
            # BUG FIX: the original check was inverted (`if not ...`), raising
            # "name already exists" exactly when the name did NOT exist.
            if dataset_db.search(Query().name == dataset_name):
                raise gr.Error("数据集名称已存在")

            doc = [d for d in get_docs() if d.name == doc_state][0]
            document_slice_list = [process_markdown_file(f) for f in doc.markdown_files]

            prompt_record = [p for p in get_prompt_store().all()
                             if p["id"] == int(prompt_state.split(" ")[0])][0]
            prompt = PromptTemplate.from_template(prompt_record["content"])

            with Session(get_sql_engine()) as session:
                api_provider = session.exec(
                    select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))
                ).first()

            # Bind the user-supplied variable values into the template;
            # only "document_slice" remains free afterwards.
            variables_dict = {}
            for _, row in variables_dataframe.iterrows():
                var_name = row['变量名'].strip()
                var_value = row['变量值'].strip()
                if var_name:
                    variables_dict[var_name] = var_value
            prompt = prompt.partial(**variables_dict)

            dataset = Dataset(
                name=dataset_name,
                model_id=[api_provider.model_id],
                source_doc=doc,
                dataset_items=[]
            )

            total_slices = len(document_slice_list)
            for i, document_slice in enumerate(document_slice_list):
                progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
                request = LLMRequest(
                    api_provider=api_provider,
                    prompt=prompt.format(document_slice=document_slice),
                    format=generate_json_example(DatasetItem)
                )
                call_openai_api(request, rounds)
                for resp in request.response:
                    try:
                        content = json.loads(resp.content)
                    except json.JSONDecodeError as e:
                        # Best-effort: skip malformed LLM replies instead of
                        # aborting the whole generation run.
                        print(f"Failed to parse response: {e}")
                        continue
                    dataset.dataset_items.append(DatasetItem(
                        message=[Q_A(
                            question=content.get("question", ""),
                            answer=content.get("answer", "")
                        )]
                    ))

            # Persist the dataset both in TinyDB and on disk.
            dataset_db.insert(dataset.model_dump())
            save_dataset(dataset_db, get_workdir(), dataset_name)
            return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"

        doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
        prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown,
                               outputs=[prompt_choice, variables_dataframe])
        api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
        generate_button.click(
            on_generate_click,
            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe,
                    rounds_input, concurrency_input, dataset_name_input],
            outputs=output_text
        )
    return demo


if __name__ == "__main__":
    from global_var import init_global_var
    init_global_var("workdir")
    demo = dataset_generate_page()
    demo.launch()