feat(dataset): 添加数据集生成功能

- 新增数据集生成页面和相关逻辑
- 实现数据集名称重复性检查
- 添加数据集对象创建和保存功能
- 优化文档处理和提示模板应用
- 增加错误处理和数据解析
This commit is contained in:
carry 2025-04-20 21:25:51 +08:00
parent 994d600221
commit 0a4efa5641

View File

@@ -1,13 +1,18 @@
import gradio as gr import gradio as gr
import sys import sys
import json
from tinydb import Query
from pathlib import Path from pathlib import Path
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from sqlmodel import Session, select from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -109,8 +114,13 @@ def dataset_generate_page():
return selected_prompt, dataframe_value return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()): def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files dataset_db = get_datasets()
document_slice_list = [process_markdown_file(doc) for doc in docs] if not dataset_db.search(Query().name == dataset_name):
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0] prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"]) prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session: with Session(get_sql_engine()) as session:
@@ -124,15 +134,41 @@ def dataset_generate_page():
variables_dict[var_name] = var_value variables_dict[var_name] = var_value
prompt = prompt.partial(**variables_dict) prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
for document_slice in document_slice_list: for document_slice in document_slice_list:
request = LLMRequest(api_provider=api_provider, request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice), prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem)) format=generate_json_example(DatasetItem))
call_openai_api(request, rounds) call_openai_api(request, rounds)
return "all done" for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
# 保存数据集到TinyDB
dataset_db.insert(dataset.model_dump())
save_dataset(dataset_db,get_workdir(),dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice) doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe]) prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])