diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index e1df948..1765a59 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -5,8 +5,8 @@ from langchain.prompts import PromptTemplate from sqlmodel import Session, select sys.path.append(str(Path(__file__).resolve().parent.parent)) -from schema import APIProvider -from tools import call_openai_api,process_markdown_file +from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem +from tools import call_openai_api, process_markdown_file, generate_example_json from global_var import get_docs, get_prompt_store, get_sql_engine def dataset_generate_page(): @@ -16,13 +16,6 @@ def dataset_generate_page(): with gr.Column(scale=1): docs_list = [str(doc.name) for doc in get_docs()] initial_doc = docs_list[0] if docs_list else None - doc_dropdown = gr.Dropdown( - choices=docs_list, - value=initial_doc, - label="选择文档", - interactive=True - ) - doc_choice = gr.State(value=initial_doc) prompts = get_prompt_store().all() prompt_list = [f"{p['id']} {p['name']}" for p in prompts] initial_prompt = prompt_list[0] if prompt_list else None @@ -37,14 +30,6 @@ def dataset_generate_page(): input_variables = prompt_template.input_variables input_variables.remove("document_slice") initial_dataframe_value = [[var, ""] for var in input_variables] - - prompt_dropdown = gr.Dropdown( - choices=prompt_list, - value=initial_prompt, - label="选择模板", - interactive=True - ) - prompt_choice = gr.State(value=initial_prompt) # 从数据库获取API Provider列表 with Session(get_sql_engine()) as session: providers = session.exec(select(APIProvider)).all() @@ -57,8 +42,18 @@ def dataset_generate_page(): label="选择API", interactive=True ) - api_choice = gr.State(value=initial_api) - + doc_dropdown = gr.Dropdown( + choices=docs_list, + value=initial_doc, + label="选择文档", + interactive=True + ) + prompt_dropdown = gr.Dropdown( + choices=prompt_list, + value=initial_prompt, + label="选择模板", + interactive=True + ) rounds_input = gr.Number( value=1, label="生成轮次", @@ -76,10 +71,16 @@ def dataset_generate_page(): interactive=True, visible=False ) + dataset_name_input = gr.Textbox( + label="数据集名称", + placeholder="输入数据集保存名称", + interactive=True + ) + prompt_choice = gr.State(value=initial_prompt) generate_button = gr.Button("生成数据集",variant="primary") - + doc_choice = gr.State(value=initial_doc) output_text = gr.Textbox(label="生成结果", interactive=False) - + api_choice = gr.State(value=initial_api) with gr.Column(scale=2): variables_dataframe = gr.Dataframe( headers=["变量名", "变量值"], @@ -88,8 +89,6 @@ def dataset_generate_page(): label="变量列表", value=initial_dataframe_value # 设置初始化数据 ) - - def on_doc_change(selected_doc): return selected_doc @@ -109,7 +108,7 @@ def dataset_generate_page(): dataframe_value = [[var, ""] for var in input_variables] return selected_prompt, dataframe_value - def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, progress=gr.Progress()): + def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()): docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files document_slice_list = [process_markdown_file(doc) for doc in docs] prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0] @@ -124,12 +123,14 @@ def dataset_generate_page(): if var_name: variables_dict[var_name] = var_value - # 注入除document_slice以外的所有参数 prompt = prompt.partial(**variables_dict) for document_slice in document_slice_list: - print("~"*20) - print(prompt.format(document_slice=document_slice)) + request = LLMRequest(api_provider=api_provider, + prompt=prompt.format(document_slice=document_slice), + format=generate_example_json(DatasetItem)) + call_openai_api(request, rounds) + return "all done" @@ -139,7 +140,7 @@ def dataset_generate_page(): generate_button.click( on_generate_click, - inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input], + inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input], outputs=output_text )