Compare commits

..

No commits in common. "86bcf90c66f99cd47afa62cb691f736e8ac6854a" and "5a386d64011dc8244e54233bcf4359b676356aef" have entirely different histories.

View File

@ -12,7 +12,7 @@ def dataset_generate_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集生成")
with gr.Row():
with gr.Column(scale=1):
with gr.Column():
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
@ -22,6 +22,8 @@ def dataset_generate_page():
interactive=True
)
doc_choice = gr.State(value=initial_doc)
with gr.Column():
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
@ -43,6 +45,8 @@ def dataset_generate_page():
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
with gr.Column():
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
@ -57,19 +61,10 @@ def dataset_generate_page():
)
api_choice = gr.State(value=initial_api)
rounds_input = gr.Number(
value=1,
label="生成轮次",
minimum=1,
maximum=100,
step=1,
interactive=True
)
generate_button = gr.Button("生成数据集",variant="primary")
output_text = gr.Textbox(label="生成结果", interactive=False)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
@ -97,7 +92,8 @@ def dataset_generate_page():
dataframe_value = [] if input_variables is None else input_variables
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
variables_dict = {}
# 正确遍历DataFrame的行数据
for _, row in variables_dataframe.iterrows():
@ -107,7 +103,7 @@ def dataset_generate_page():
variables_dict[var_name] = var_value
import time
total_steps = rounds
total_steps = 10
for i in range(total_steps):
# 模拟每个步骤的工作负载
time.sleep(0.5)
@ -120,13 +116,14 @@ def dataset_generate_page():
return "all done"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
outputs=output_text
)