gzhu-biyesheji/frontend/dataset_generate_page.py
carry 3718c75cee feat(frontend): 添加数据集生成页面的处理进度显示
- 在处理文档片段时添加进度条,提升用户体验
- 优化代码格式,调整缩进和空行
2025-04-22 00:14:16 +08:00

189 lines
8.4 KiB
Python

import gradio as gr
import sys
import json
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集生成")
with gr.Row():
with gr.Column(scale=1):
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
# 初始化Dataframe的值
initial_dataframe_value = []
if initial_prompt:
selected_prompt_id = int(initial_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True
)
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc,
label="选择文档",
interactive=True
)
prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
rounds_input = gr.Number(
value=1,
label="生成轮次",
minimum=1,
maximum=100,
step=1,
interactive=True
)
concurrency_input = gr.Number(
value=1,
label="并发数",
minimum=1,
maximum=10,
step=1,
interactive=True,
visible=False
)
dataset_name_input = gr.Textbox(
label="数据集名称",
placeholder="输入数据集保存名称",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集",variant="primary")
doc_choice = gr.State(value=initial_doc)
output_text = gr.Textbox(label="生成结果", interactive=False)
api_choice = gr.State(value=initial_api)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
interactive=True,
label="变量列表",
value=initial_dataframe_value # 设置初始化数据
)
def on_doc_change(selected_doc):
return selected_doc
def on_api_change(selected_api):
return selected_api
def on_prompt_change(selected_prompt):
if not selected_prompt:
return None, []
selected_prompt_id = int(selected_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
dataset_db = get_datasets()
if not dataset_db.search(Query().name == dataset_name):
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
variables_dict = {}
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
total_slices = len(document_slice_list)
for i, document_slice in enumerate(document_slice_list):
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
call_openai_api(request, rounds)
for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
# 保存数据集到TinyDB
dataset_db.insert(dataset.model_dump())
save_dataset(dataset_db,get_workdir(),dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
outputs=output_text
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
init_global_var("workdir")
demo = dataset_generate_page()
demo.launch()