Compare commits
2 Commits
5a386d6401
...
86bcf90c66
Author | SHA1 | Date | |
---|---|---|---|
![]() |
86bcf90c66 | ||
![]() |
961a017f19 |
@ -12,7 +12,7 @@ def dataset_generate_page():
|
|||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
gr.Markdown("## 数据集生成")
|
gr.Markdown("## 数据集生成")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column(scale=1):
|
||||||
docs_list = [str(doc.name) for doc in get_docs()]
|
docs_list = [str(doc.name) for doc in get_docs()]
|
||||||
initial_doc = docs_list[0] if docs_list else None
|
initial_doc = docs_list[0] if docs_list else None
|
||||||
doc_dropdown = gr.Dropdown(
|
doc_dropdown = gr.Dropdown(
|
||||||
@ -22,8 +22,6 @@ def dataset_generate_page():
|
|||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
doc_choice = gr.State(value=initial_doc)
|
doc_choice = gr.State(value=initial_doc)
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
prompts = get_prompt_store().all()
|
prompts = get_prompt_store().all()
|
||||||
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
||||||
initial_prompt = prompt_list[0] if prompt_list else None
|
initial_prompt = prompt_list[0] if prompt_list else None
|
||||||
@ -45,8 +43,6 @@ def dataset_generate_page():
|
|||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
prompt_choice = gr.State(value=initial_prompt)
|
prompt_choice = gr.State(value=initial_prompt)
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
# 从数据库获取API Provider列表
|
# 从数据库获取API Provider列表
|
||||||
with Session(get_sql_engine()) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
providers = session.exec(select(APIProvider)).all()
|
providers = session.exec(select(APIProvider)).all()
|
||||||
@ -61,17 +57,26 @@ def dataset_generate_page():
|
|||||||
)
|
)
|
||||||
api_choice = gr.State(value=initial_api)
|
api_choice = gr.State(value=initial_api)
|
||||||
|
|
||||||
generate_button = gr.Button("生成数据集",variant="primary")
|
rounds_input = gr.Number(
|
||||||
|
value=1,
|
||||||
|
label="生成轮次",
|
||||||
|
minimum=1,
|
||||||
|
maximum=100,
|
||||||
|
step=1,
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
generate_button = gr.Button("生成数据集",variant="primary")
|
||||||
|
|
||||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||||
|
|
||||||
variables_dataframe = gr.Dataframe(
|
with gr.Column(scale=2):
|
||||||
headers=["变量名", "变量值"],
|
variables_dataframe = gr.Dataframe(
|
||||||
datatype=["str", "str"],
|
headers=["变量名", "变量值"],
|
||||||
interactive=True,
|
datatype=["str", "str"],
|
||||||
label="变量列表",
|
interactive=True,
|
||||||
value=initial_dataframe_value # 设置初始化数据
|
label="变量列表",
|
||||||
)
|
value=initial_dataframe_value # 设置初始化数据
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def on_doc_change(selected_doc):
|
def on_doc_change(selected_doc):
|
||||||
@ -92,8 +97,7 @@ def dataset_generate_page():
|
|||||||
dataframe_value = [] if input_variables is None else input_variables
|
dataframe_value = [] if input_variables is None else input_variables
|
||||||
return selected_prompt, dataframe_value
|
return selected_prompt, dataframe_value
|
||||||
|
|
||||||
|
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
||||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
|
|
||||||
variables_dict = {}
|
variables_dict = {}
|
||||||
# 正确遍历DataFrame的行数据
|
# 正确遍历DataFrame的行数据
|
||||||
for _, row in variables_dataframe.iterrows():
|
for _, row in variables_dataframe.iterrows():
|
||||||
@ -103,7 +107,7 @@ def dataset_generate_page():
|
|||||||
variables_dict[var_name] = var_value
|
variables_dict[var_name] = var_value
|
||||||
|
|
||||||
import time
|
import time
|
||||||
total_steps = 10
|
total_steps = rounds
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
# 模拟每个步骤的工作负载
|
# 模拟每个步骤的工作负载
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
@ -116,14 +120,13 @@ def dataset_generate_page():
|
|||||||
|
|
||||||
return "all done"
|
return "all done"
|
||||||
|
|
||||||
|
|
||||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
||||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
||||||
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
||||||
|
|
||||||
generate_button.click(
|
generate_button.click(
|
||||||
on_generate_click,
|
on_generate_click,
|
||||||
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
|
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
|
||||||
outputs=output_text
|
outputs=output_text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user