fix(frontend): 修复dataframe_value返回值只有一列的bug
This commit is contained in:
parent
ff1e9731bc
commit
d827f9758f
@ -6,6 +6,7 @@ from sqlmodel import Session, select
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from schema import APIProvider
|
from schema import APIProvider
|
||||||
|
from tools import call_openai_api
|
||||||
from global_var import get_docs, get_prompt_store, get_sql_engine
|
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||||
|
|
||||||
def dataset_generate_page():
|
def dataset_generate_page():
|
||||||
@ -34,6 +35,7 @@ def dataset_generate_page():
|
|||||||
prompt_content = prompt_data["content"]
|
prompt_content = prompt_data["content"]
|
||||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
|
input_variables.remove("document_slice")
|
||||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||||
|
|
||||||
prompt_dropdown = gr.Dropdown(
|
prompt_dropdown = gr.Dropdown(
|
||||||
@ -95,21 +97,28 @@ def dataset_generate_page():
|
|||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
input_variables.remove("document_slice")
|
input_variables.remove("document_slice")
|
||||||
dataframe_value = [] if input_variables is None else input_variables
|
dataframe_value = [] if input_variables is None else input_variables
|
||||||
|
dataframe_value = [[var, ""] for var in input_variables]
|
||||||
return selected_prompt, dataframe_value
|
return selected_prompt, dataframe_value
|
||||||
|
|
||||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
||||||
doc = [i for i in get_docs() if i.name == doc_state][0]
|
doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
|
||||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||||
|
prompt = PromptTemplate.from_template(prompt["content"])
|
||||||
with Session(get_sql_engine()) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||||
|
|
||||||
variables_dict = {}
|
variables_dict = {}
|
||||||
# 正确遍历DataFrame的行数据
|
|
||||||
for _, row in variables_dataframe.iterrows():
|
for _, row in variables_dataframe.iterrows():
|
||||||
var_name = row['变量名'].strip()
|
var_name = row['变量名'].strip()
|
||||||
var_value = row['变量值'].strip()
|
var_value = row['变量值'].strip()
|
||||||
if var_name:
|
if var_name:
|
||||||
variables_dict[var_name] = var_value
|
variables_dict[var_name] = var_value
|
||||||
|
|
||||||
|
|
||||||
|
print(doc)
|
||||||
|
print(prompt_content)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
total_steps = rounds
|
total_steps = rounds
|
||||||
|
Loading…
x
Reference in New Issue
Block a user