189 lines
8.4 KiB
Python
189 lines
8.4 KiB
Python
import gradio as gr
|
|
import sys
|
|
import json
|
|
from tinydb import Query
|
|
from pathlib import Path
|
|
from langchain.prompts import PromptTemplate
|
|
from sqlmodel import Session, select
|
|
from schema import Dataset, DatasetItem, Q_A
|
|
from db.dataset_store import get_all_dataset
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
|
|
from db import save_dataset
|
|
from tools import call_openai_api, process_markdown_file, generate_json_example
|
|
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
|
|
|
|
def dataset_generate_page():
    """Build the "数据集生成" (dataset generation) Gradio page.

    Lets the user pick a document, a prompt template and an API provider,
    fill in the template's free variables, then generate a Q&A dataset by
    calling the LLM once per document slice and persisting the result to
    TinyDB and the working directory.

    Returns:
        gr.Blocks: the assembled page, ready to be launched or mounted.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集生成")
        with gr.Row():
            with gr.Column(scale=1):
                docs_list = [str(doc.name) for doc in get_docs()]
                initial_doc = docs_list[0] if docs_list else None
                prompts = get_prompt_store().all()
                prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                initial_prompt = prompt_list[0] if prompt_list else None

                # Pre-populate the variables table from the first prompt template.
                initial_dataframe_value = []
                if initial_prompt:
                    selected_prompt_id = int(initial_prompt.split(" ")[0])
                    prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
                    prompt_template = PromptTemplate.from_template(prompt_data["content"])
                    input_variables = prompt_template.input_variables
                    # "document_slice" is injected per slice at generation time, so
                    # it must not appear as a user-editable variable. Guarded: a
                    # template without it would otherwise raise ValueError here.
                    if "document_slice" in input_variables:
                        input_variables.remove("document_slice")
                    initial_dataframe_value = [[var, ""] for var in input_variables]

                # Load the API provider list from the SQL database.
                with Session(get_sql_engine()) as session:
                    providers = session.exec(select(APIProvider)).all()
                api_list = [f"{p.id} {p.model_id}" for p in providers]
                initial_api = api_list[0] if api_list else None

                api_dropdown = gr.Dropdown(
                    choices=api_list,
                    value=initial_api,
                    label="选择API",
                    interactive=True
                )
                doc_dropdown = gr.Dropdown(
                    choices=docs_list,
                    value=initial_doc,
                    label="选择文档",
                    interactive=True
                )
                prompt_dropdown = gr.Dropdown(
                    choices=prompt_list,
                    value=initial_prompt,
                    label="选择模板",
                    interactive=True
                )
                rounds_input = gr.Number(
                    value=1,
                    label="生成轮次",
                    minimum=1,
                    maximum=100,
                    step=1,
                    interactive=True
                )
                concurrency_input = gr.Number(
                    value=1,
                    label="并发数",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    visible=False  # NOTE(review): concurrency is accepted but unused below
                )
                dataset_name_input = gr.Textbox(
                    label="数据集名称",
                    placeholder="输入数据集保存名称",
                    interactive=True
                )
                prompt_choice = gr.State(value=initial_prompt)
                generate_button = gr.Button("生成数据集", variant="primary")
                doc_choice = gr.State(value=initial_doc)
                output_text = gr.Textbox(label="生成结果", interactive=False)
                api_choice = gr.State(value=initial_api)
            with gr.Column(scale=2):
                variables_dataframe = gr.Dataframe(
                    headers=["变量名", "变量值"],
                    datatype=["str", "str"],
                    interactive=True,
                    label="变量列表",
                    value=initial_dataframe_value  # rows derived from the first template
                )

        def on_doc_change(selected_doc):
            # Mirror the dropdown selection into its State holder.
            return selected_doc

        def on_api_change(selected_api):
            # Mirror the dropdown selection into its State holder.
            return selected_api

        def on_prompt_change(selected_prompt):
            """Refresh the variables table when another template is chosen.

            Returns:
                tuple: (selected prompt label, list of [name, ""] rows for every
                template variable except the reserved "document_slice").
            """
            if not selected_prompt:
                return None, []
            selected_prompt_id = int(selected_prompt.split(" ")[0])
            prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
            prompt_template = PromptTemplate.from_template(prompt_data["content"])
            input_variables = prompt_template.input_variables
            # Guarded removal, consistent with the initial-population logic above.
            if "document_slice" in input_variables:
                input_variables.remove("document_slice")
            dataframe_value = [[var, ""] for var in input_variables]
            return selected_prompt, dataframe_value

        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe,
                              rounds, concurrency, dataset_name, progress=gr.Progress()):
            """Generate a dataset from every slice of the selected document.

            Raises:
                gr.Error: on an empty or duplicate dataset name, or when the
                    selected document can no longer be resolved.
            """
            dataset_db = get_datasets()
            if not dataset_name or not dataset_name.strip():
                raise gr.Error("数据集名称不能为空")
            # Bug fix: this check was inverted ("if not ...") — it used to raise
            # "name already exists" exactly when the name did NOT exist, making
            # every first save fail and allowing real duplicates through.
            if dataset_db.search(Query().name == dataset_name):
                raise gr.Error("数据集名称已存在")

            doc = next((d for d in get_docs() if d.name == doc_state), None)
            if doc is None:
                # Previously a bare IndexError on a stale selection.
                raise gr.Error("未找到选中的文档")
            document_slice_list = [process_markdown_file(f) for f in doc.markdown_files]

            prompt_id = int(prompt_state.split(" ")[0])
            prompt_record = [p for p in get_prompt_store().all() if p["id"] == prompt_id][0]
            prompt = PromptTemplate.from_template(prompt_record["content"])

            with Session(get_sql_engine()) as session:
                api_provider = session.exec(
                    select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))
                ).first()

            # Collect user-supplied template variables; rows with a blank name
            # are skipped so empty dataframe rows are harmless.
            variables_dict = {}
            for _, row in variables_dataframe.iterrows():
                var_name = row['变量名'].strip()
                var_value = row['变量值'].strip()
                if var_name:
                    variables_dict[var_name] = var_value
            prompt = prompt.partial(**variables_dict)

            dataset = Dataset(
                name=dataset_name,
                model_id=[api_provider.model_id],
                source_doc=doc,
                dataset_items=[]
            )

            total_slices = len(document_slice_list)
            for i, document_slice in enumerate(document_slice_list):
                progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
                request = LLMRequest(api_provider=api_provider,
                                     prompt=prompt.format(document_slice=document_slice),
                                     format=generate_json_example(DatasetItem))
                # call_openai_api fills request.response in place (one per round).
                call_openai_api(request, rounds)

                for resp in request.response:
                    try:
                        content = json.loads(resp.content)
                        dataset_item = DatasetItem(
                            message=[Q_A(
                                question=content.get("question", ""),
                                answer=content.get("answer", "")
                            )]
                        )
                        dataset.dataset_items.append(dataset_item)
                    except json.JSONDecodeError as e:
                        # Best effort: skip malformed LLM output, keep the rest.
                        print(f"Failed to parse response: {e}")

            # Persist to TinyDB, then export to the working directory.
            dataset_db.insert(dataset.model_dump())
            save_dataset(dataset_db, get_workdir(), dataset_name)

            return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"

        doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
        prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown,
                               outputs=[prompt_choice, variables_dataframe])
        api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)

        generate_button.click(
            on_generate_click,
            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe,
                    rounds_input, concurrency_input, dataset_name_input],
            outputs=output_text
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Standalone entry point: initialize module-level state against the
    # default working directory, then serve the page.
    from global_var import init_global_var

    init_global_var("workdir")
    demo = dataset_generate_page()
    demo.launch()