feat(dataset): add dataset generation feature
- Add the dataset generation page and related logic
- Implement a duplicate-name check for datasets
- Add dataset object creation and saving
- Improve document processing and prompt template application
- Add error handling and response parsing
This commit is contained in:
parent 994d600221
commit 0a4efa5641
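For context, the parsing this commit adds treats every model reply as untrusted JSON. A minimal sketch of that round trip, with a hypothetical payload (the real expected shape is produced by generate_json_example(DatasetItem) in tools):

    import json

    # Hypothetical model output; the real expected shape comes from
    # generate_json_example(DatasetItem).
    raw = '{"question": "What is X?", "answer": "X is ..."}'
    try:
        content = json.loads(raw)
        print(content.get("question", ""), "->", content.get("answer", ""))
    except json.JSONDecodeError as e:
        print(f"Failed to parse response: {e}")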
@@ -1,13 +1,18 @@
 import gradio as gr
 import sys
+import json
+from tinydb import Query
 from pathlib import Path
 from langchain.prompts import PromptTemplate
 from sqlmodel import Session, select
+from schema import Dataset, DatasetItem, Q_A
+from db.dataset_store import get_all_dataset
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
+from db import save_dataset
 from tools import call_openai_api, process_markdown_file, generate_json_example
-from global_var import get_docs, get_prompt_store, get_sql_engine
+from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
 
 def dataset_generate_page():
     with gr.Blocks() as demo:
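The new tinydb import backs the duplicate-name check in the next hunk. A small sketch of the search semantics, assuming get_datasets() hands back a TinyDB instance (the file name here is made up):

    from tinydb import TinyDB, Query

    db = TinyDB("datasets.json")  # stand-in for get_datasets()
    db.insert({"name": "demo"})

    # search() returns a list of matching documents, and an empty list is
    # falsy, so the duplicate check must raise only when a match IS found.
    if db.search(Query().name == "demo"):
        raise ValueError("Dataset name already exists")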
@@ -109,8 +114,13 @@ def dataset_generate_page():
             return selected_prompt, dataframe_value
 
         def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
-            docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files
-            document_slice_list = [process_markdown_file(doc) for doc in docs]
+            dataset_db = get_datasets()
+            if dataset_db.search(Query().name == dataset_name):
+                raise gr.Error("Dataset name already exists")
+
+            doc = [i for i in get_docs() if i.name == doc_state][0]
+            doc_files = doc.markdown_files
+            document_slice_list = [process_markdown_file(doc) for doc in doc_files]
             prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
             prompt = PromptTemplate.from_template(prompt["content"])
             with Session(get_sql_engine()) as session:
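The prompt handling above uses langchain's PromptTemplate; the next hunk partially binds user-supplied variables before formatting once per document slice. A sketch of that from_template / partial / format flow, with made-up template text and variable names:

    from langchain.prompts import PromptTemplate

    template = PromptTemplate.from_template(
        "Write Q&A pairs in a {style} style from this text:\n{document_slice}"
    )
    # partial() pre-fills some variables and returns a new template;
    # the remaining ones are supplied later via format().
    partialed = template.partial(style="concise")
    print(partialed.format(document_slice="sample document slice"))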
@@ -124,15 +134,41 @@ def dataset_generate_page():
                     variables_dict[var_name] = var_value
 
                 prompt = prompt.partial(**variables_dict)
 
+                dataset = Dataset(
+                    name=dataset_name,
+                    model_id=[api_provider.model_id],
+                    source_doc=doc,
+                    dataset_items=[]
+                )
+
                 for document_slice in document_slice_list:
                     request = LLMRequest(api_provider=api_provider,
                                          prompt=prompt.format(document_slice=document_slice),
                                          format=generate_json_example(DatasetItem))
                     call_openai_api(request, rounds)
-
-                return "all done"
+                    for resp in request.response:
+                        try:
+                            content = json.loads(resp.content)
+                            dataset_item = DatasetItem(
+                                message=[Q_A(
+                                    question=content.get("question", ""),
+                                    answer=content.get("answer", "")
+                                )]
+                            )
+                            dataset.dataset_items.append(dataset_item)
+                        except json.JSONDecodeError as e:
+                            print(f"Failed to parse response: {e}")
+
+                # Save the dataset to TinyDB
+                dataset_db.insert(dataset.model_dump())
+
+                save_dataset(dataset_db, get_workdir(), dataset_name)
+
+                return f"Dataset {dataset_name} generated: {len(dataset.dataset_items)} items"
 
         doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
         prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
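The save path serializes the pydantic models before writing to TinyDB. A minimal sketch of that round trip with simplified stand-ins for the schema models (field shapes are assumptions, not the real schema.py definitions):

    from pydantic import BaseModel
    from tinydb import TinyDB

    # Simplified stand-ins for the real models in schema.py.
    class Q_A(BaseModel):
        question: str = ""
        answer: str = ""

    class DatasetItem(BaseModel):
        message: list[Q_A] = []

    # model_dump() converts nested models into plain dicts and lists,
    # which TinyDB can store as a JSON document.
    item = DatasetItem(message=[Q_A(question="q", answer="a")])
    TinyDB("demo.json").insert(item.model_dump())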