diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py
index 7d21123..536b66f 100644
--- a/frontend/dataset_generate_page.py
+++ b/frontend/dataset_generate_page.py
@@ -1,13 +1,18 @@
 import gradio as gr
 import sys
+import json
+from tinydb import Query
 from pathlib import Path
 from langchain.prompts import PromptTemplate
 from sqlmodel import Session, select
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
+from schema import Dataset, Q_A
+from db import save_dataset
+from db.dataset_store import get_all_dataset
 from tools import call_openai_api, process_markdown_file, generate_json_example
-from global_var import get_docs, get_prompt_store, get_sql_engine
+from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
 
 def dataset_generate_page():
     with gr.Blocks() as demo:
@@ -109,8 +114,13 @@ def dataset_generate_page():
             return selected_prompt, dataframe_value
 
         def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
-            docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files
-            document_slice_list = [process_markdown_file(doc) for doc in docs]
+            # Reject duplicate dataset names before doing any expensive work
+            dataset_db = get_datasets()
+            if dataset_db.search(Query().name == dataset_name):
+                raise gr.Error("Dataset name already exists")
+
+            doc = [i for i in get_docs() if i.name == doc_state][0]
+            document_slice_list = [process_markdown_file(f) for f in doc.markdown_files]
             prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
             prompt = PromptTemplate.from_template(prompt["content"])
             with Session(get_sql_engine()) as session:
@@ -124,15 +134,41 @@ def dataset_generate_page():
                 variables_dict[var_name] = var_value
 
             prompt = prompt.partial(**variables_dict)
+
+            # Accumulate generated Q&A pairs into a new Dataset record
+            dataset = Dataset(
+                name=dataset_name,
+                model_id=[api_provider.model_id],
+                source_doc=doc,
+                dataset_items=[]
+            )
+
             for document_slice in document_slice_list:
                 request = LLMRequest(api_provider=api_provider,
                                      prompt=prompt.format(document_slice=document_slice),
                                      format=generate_json_example(DatasetItem))
                 call_openai_api(request, rounds)
-
-            return "all done"
+                for resp in request.response:
+                    try:
+                        content = json.loads(resp.content)
+                        dataset_item = DatasetItem(
+                            message=[Q_A(
+                                question=content.get("question", ""),
+                                answer=content.get("answer", "")
+                            )]
+                        )
+                        dataset.dataset_items.append(dataset_item)
+                    except json.JSONDecodeError as e:
+                        print(f"Failed to parse response: {e}")
+
+            # Persist the dataset to TinyDB
+            dataset_db.insert(dataset.model_dump())
+
+            save_dataset(dataset_db, get_workdir(), dataset_name)
+
+            return f"Dataset {dataset_name} generated with {len(dataset.dataset_items)} items"
 
         doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
         prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
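
A minimal, self-contained sketch of the persistence path this patch adds: the duplicate-name guard, the JSON-parse fallback, and the model_dump()-based TinyDB insert. The Q_A/DatasetItem/Dataset classes below are stand-ins carrying only the fields the diff touches (the real definitions live in schema.py), and pydantic v2 plus tinydb are assumed to be installed.

import json
from pydantic import BaseModel, Field
from tinydb import TinyDB, Query

# Stand-ins for the schema.py models; field names mirror the diff but are
# otherwise assumptions.
class Q_A(BaseModel):
    question: str = ""
    answer: str = ""

class DatasetItem(BaseModel):
    message: list[Q_A] = Field(default_factory=list)

class Dataset(BaseModel):
    name: str
    model_id: list[str] = Field(default_factory=list)
    dataset_items: list[DatasetItem] = Field(default_factory=list)

db = TinyDB("demo_datasets.json")

# Duplicate-name guard, mirroring the check added in on_generate_click
if db.search(Query().name == "demo"):
    raise ValueError("Dataset name already exists")

dataset = Dataset(name="demo")
raw = '{"question": "What does the patch store?", "answer": "Parsed Q&A pairs."}'
try:
    content = json.loads(raw)  # may raise json.JSONDecodeError on bad LLM output
    dataset.dataset_items.append(
        DatasetItem(message=[Q_A(question=content.get("question", ""),
                                 answer=content.get("answer", ""))]))
except json.JSONDecodeError as e:
    print(f"Failed to parse response: {e}")

# TinyDB stores plain dicts, hence model_dump() before insert
db.insert(dataset.model_dump())
print(db.search(Query().name == "demo"))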