From d827f9758f27cacc7d008b31445d5fe92718dde0 Mon Sep 17 00:00:00 2001 From: carry Date: Sat, 19 Apr 2025 17:30:10 +0800 Subject: [PATCH] =?UTF-8?q?fix(frontend):=20=E4=BF=AE=E5=A4=8Ddataframe=5F?= =?UTF-8?q?value=E8=BF=94=E5=9B=9E=E5=80=BC=E5=8F=AA=E6=9C=89=E4=B8=80?= =?UTF-8?q?=E5=88=97=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/dataset_generate_page.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index b281042..77c280d 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -6,6 +6,7 @@ from sqlmodel import Session, select sys.path.append(str(Path(__file__).resolve().parent.parent)) from schema import APIProvider +from tools import call_openai_api from global_var import get_docs, get_prompt_store, get_sql_engine def dataset_generate_page(): @@ -34,6 +35,7 @@ def dataset_generate_page(): prompt_content = prompt_data["content"] prompt_template = PromptTemplate.from_template(prompt_content) input_variables = prompt_template.input_variables + input_variables.remove("document_slice") initial_dataframe_value = [[var, ""] for var in input_variables] prompt_dropdown = gr.Dropdown( @@ -95,21 +97,28 @@ def dataset_generate_page(): input_variables = prompt_template.input_variables input_variables.remove("document_slice") dataframe_value = [] if input_variables is None else input_variables + dataframe_value = [[var, ""] for var in input_variables] return selected_prompt, dataframe_value def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()): - doc = [i for i in get_docs() if i.name == doc_state][0] + doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0] + prompt = PromptTemplate.from_template(prompt["content"]) with Session(get_sql_engine()) as session: api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first() - + variables_dict = {} - # 正确遍历DataFrame的行数据 for _, row in variables_dataframe.iterrows(): var_name = row['变量名'].strip() var_value = row['变量值'].strip() if var_name: variables_dict[var_name] = var_value + + + print(doc) + print(prompt_content) + + import time total_steps = rounds