From 9236f49b36224323bff36534c60b1cae41c96c40 Mon Sep 17 00:00:00 2001 From: carry Date: Sun, 20 Apr 2025 01:40:48 +0800 Subject: [PATCH] =?UTF-8?q?feat(frontend):=20=E6=B7=BB=E5=8A=A0=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=88=87=E7=89=87=E5=92=8C=E5=B9=B6=E5=8F=91=E6=95=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增并发数输入框 - 实现文档切片处理 - 更新生成数据集的逻辑,支持并发处理 --- frontend/dataset_generate_page.py | 33 ++++++++++++++++--------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index bd75a58..e1df948 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -6,7 +6,7 @@ from sqlmodel import Session, select sys.path.append(str(Path(__file__).resolve().parent.parent)) from schema import APIProvider -from tools import call_openai_api +from tools import call_openai_api,process_markdown_file from global_var import get_docs, get_prompt_store, get_sql_engine def dataset_generate_page(): @@ -67,6 +67,15 @@ def dataset_generate_page(): step=1, interactive=True ) + concurrency_input = gr.Number( + value=1, + label="并发数", + minimum=1, + maximum=10, + step=1, + interactive=True, + visible=False + ) generate_button = gr.Button("生成数据集",variant="primary") output_text = gr.Textbox(label="生成结果", interactive=False) @@ -100,8 +109,9 @@ def dataset_generate_page(): dataframe_value = [[var, ""] for var in input_variables] return selected_prompt, dataframe_value - def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()): - doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files + def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, progress=gr.Progress()): + docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files + document_slice_list = [process_markdown_file(doc) for doc in docs] prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0] prompt = PromptTemplate.from_template(prompt["content"]) with Session(get_sql_engine()) as session: @@ -117,18 +127,9 @@ def dataset_generate_page(): # 注入除document_slice以外的所有参数 prompt = prompt.partial(**variables_dict) - print(doc) - print(prompt.format(document_slice="test")) - print(variables_dict) - - import time - total_steps = rounds - for i in range(total_steps): - # 模拟每个步骤的工作负载 - time.sleep(0.5) - - current_progress = (i + 1) / total_steps - progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}") + for document_slice in document_slice_list: + print("~"*20) + print(prompt.format(document_slice=document_slice)) return "all done" @@ -138,7 +139,7 @@ def dataset_generate_page(): generate_button.click( on_generate_click, - inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input], + inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input], outputs=output_text )