diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index b281042..77c280d 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -6,6 +6,7 @@ from sqlmodel import Session, select sys.path.append(str(Path(__file__).resolve().parent.parent)) from schema import APIProvider +from tools import call_openai_api from global_var import get_docs, get_prompt_store, get_sql_engine def dataset_generate_page(): @@ -34,6 +35,7 @@ def dataset_generate_page(): prompt_content = prompt_data["content"] prompt_template = PromptTemplate.from_template(prompt_content) input_variables = prompt_template.input_variables + input_variables.remove("document_slice") initial_dataframe_value = [[var, ""] for var in input_variables] prompt_dropdown = gr.Dropdown( @@ -95,21 +97,28 @@ def dataset_generate_page(): input_variables = prompt_template.input_variables input_variables.remove("document_slice") dataframe_value = [] if input_variables is None else input_variables + dataframe_value = [[var, ""] for var in input_variables] return selected_prompt, dataframe_value def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()): - doc = [i for i in get_docs() if i.name == doc_state][0] + doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0] + prompt = PromptTemplate.from_template(prompt["content"]) with Session(get_sql_engine()) as session: api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first() - + variables_dict = {} - # 正确遍历DataFrame的行数据 for _, row in variables_dataframe.iterrows(): var_name = row['变量名'].strip() var_value = row['变量值'].strip() if var_name: variables_dict[var_name] = var_value + + + print(doc) + print(prompt_content) + + import time total_steps = rounds