refactor(frontend): 重构数据集生成页面
- 调整页面布局,优化用户交互流程
- 新增数据集名称输入框
- 使用 LLMRequest 和 LLMResponse 模型处理请求和响应
- 添加 generate_example_json 函数用于格式化生成数据
- 改进数据集生成逻辑,支持多轮次生成
This commit is contained in:
parent
4c9caff668
commit
e7cf51d662
@ -5,8 +5,8 @@ from langchain.prompts import PromptTemplate
|
|||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from schema import APIProvider
|
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
|
||||||
from tools import call_openai_api,process_markdown_file
|
from tools import call_openai_api, process_markdown_file, generate_example_json
|
||||||
from global_var import get_docs, get_prompt_store, get_sql_engine
|
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||||
|
|
||||||
def dataset_generate_page():
|
def dataset_generate_page():
|
||||||
@ -16,13 +16,6 @@ def dataset_generate_page():
|
|||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
docs_list = [str(doc.name) for doc in get_docs()]
|
docs_list = [str(doc.name) for doc in get_docs()]
|
||||||
initial_doc = docs_list[0] if docs_list else None
|
initial_doc = docs_list[0] if docs_list else None
|
||||||
doc_dropdown = gr.Dropdown(
|
|
||||||
choices=docs_list,
|
|
||||||
value=initial_doc,
|
|
||||||
label="选择文档",
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
doc_choice = gr.State(value=initial_doc)
|
|
||||||
prompts = get_prompt_store().all()
|
prompts = get_prompt_store().all()
|
||||||
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
||||||
initial_prompt = prompt_list[0] if prompt_list else None
|
initial_prompt = prompt_list[0] if prompt_list else None
|
||||||
@ -37,14 +30,6 @@ def dataset_generate_page():
|
|||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
input_variables.remove("document_slice")
|
input_variables.remove("document_slice")
|
||||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||||
|
|
||||||
prompt_dropdown = gr.Dropdown(
|
|
||||||
choices=prompt_list,
|
|
||||||
value=initial_prompt,
|
|
||||||
label="选择模板",
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
prompt_choice = gr.State(value=initial_prompt)
|
|
||||||
# 从数据库获取API Provider列表
|
# 从数据库获取API Provider列表
|
||||||
with Session(get_sql_engine()) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
providers = session.exec(select(APIProvider)).all()
|
providers = session.exec(select(APIProvider)).all()
|
||||||
@ -57,8 +42,18 @@ def dataset_generate_page():
|
|||||||
label="选择API",
|
label="选择API",
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
api_choice = gr.State(value=initial_api)
|
doc_dropdown = gr.Dropdown(
|
||||||
|
choices=docs_list,
|
||||||
|
value=initial_doc,
|
||||||
|
label="选择文档",
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
prompt_dropdown = gr.Dropdown(
|
||||||
|
choices=prompt_list,
|
||||||
|
value=initial_prompt,
|
||||||
|
label="选择模板",
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
rounds_input = gr.Number(
|
rounds_input = gr.Number(
|
||||||
value=1,
|
value=1,
|
||||||
label="生成轮次",
|
label="生成轮次",
|
||||||
@ -76,10 +71,16 @@ def dataset_generate_page():
|
|||||||
interactive=True,
|
interactive=True,
|
||||||
visible=False
|
visible=False
|
||||||
)
|
)
|
||||||
|
dataset_name_input = gr.Textbox(
|
||||||
|
label="数据集名称",
|
||||||
|
placeholder="输入数据集保存名称",
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
prompt_choice = gr.State(value=initial_prompt)
|
||||||
generate_button = gr.Button("生成数据集",variant="primary")
|
generate_button = gr.Button("生成数据集",variant="primary")
|
||||||
|
doc_choice = gr.State(value=initial_doc)
|
||||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||||
|
api_choice = gr.State(value=initial_api)
|
||||||
with gr.Column(scale=2):
|
with gr.Column(scale=2):
|
||||||
variables_dataframe = gr.Dataframe(
|
variables_dataframe = gr.Dataframe(
|
||||||
headers=["变量名", "变量值"],
|
headers=["变量名", "变量值"],
|
||||||
@ -88,8 +89,6 @@ def dataset_generate_page():
|
|||||||
label="变量列表",
|
label="变量列表",
|
||||||
value=initial_dataframe_value # 设置初始化数据
|
value=initial_dataframe_value # 设置初始化数据
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def on_doc_change(selected_doc):
|
def on_doc_change(selected_doc):
|
||||||
return selected_doc
|
return selected_doc
|
||||||
|
|
||||||
@ -109,7 +108,7 @@ def dataset_generate_page():
|
|||||||
dataframe_value = [[var, ""] for var in input_variables]
|
dataframe_value = [[var, ""] for var in input_variables]
|
||||||
return selected_prompt, dataframe_value
|
return selected_prompt, dataframe_value
|
||||||
|
|
||||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, progress=gr.Progress()):
|
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
|
||||||
docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files
|
docs = [i for i in get_docs() if i.name == doc_state][0].markdown_files
|
||||||
document_slice_list = [process_markdown_file(doc) for doc in docs]
|
document_slice_list = [process_markdown_file(doc) for doc in docs]
|
||||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||||
@ -124,12 +123,14 @@ def dataset_generate_page():
|
|||||||
if var_name:
|
if var_name:
|
||||||
variables_dict[var_name] = var_value
|
variables_dict[var_name] = var_value
|
||||||
|
|
||||||
# 注入除document_slice以外的所有参数
|
|
||||||
prompt = prompt.partial(**variables_dict)
|
prompt = prompt.partial(**variables_dict)
|
||||||
|
|
||||||
for document_slice in document_slice_list:
|
for document_slice in document_slice_list:
|
||||||
print("~"*20)
|
request = LLMRequest(api_provider=api_provider,
|
||||||
print(prompt.format(document_slice=document_slice))
|
prompt=prompt.format(document_slice=document_slice),
|
||||||
|
format=generate_example_json(DatasetItem))
|
||||||
|
call_openai_api(request, rounds)
|
||||||
|
|
||||||
|
|
||||||
return "all done"
|
return "all done"
|
||||||
|
|
||||||
@ -139,7 +140,7 @@ def dataset_generate_page():
|
|||||||
|
|
||||||
generate_button.click(
|
generate_button.click(
|
||||||
on_generate_click,
|
on_generate_click,
|
||||||
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input],
|
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
|
||||||
outputs=output_text
|
outputs=output_text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user