Compare commits

..

4 Commits

Author SHA1 Message Date
carry
5a386d6401 feat(dataset_generate_page): 添加 API 选择功能
- 在数据集生成页面添加 API 选择下拉框
- 实现 API 选择变更时的处理逻辑
- 更新数据集生成函数,增加 API 选择参数
- 优化页面布局和代码结构
2025-04-18 15:23:33 +08:00
carry
feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00
carry
7242a2ce03 feat(frontend): 添加生成数据集进度条功能并优化了界面布局 2025-04-18 15:07:46 +08:00
carry
db6e2271dc fix(frontend): 修复 prompt_dropdown 变化时,dataframe没有相应的变化
- 将 prompt_dropdown 变化时的输出从 prompt_state 修改为 [prompt_state, variables_dataframe]
- 这个改动可能会在 prompt 变化时同时更新变量数据框
2025-04-18 14:03:26 +08:00
4 changed files with 55 additions and 17 deletions

View File

@ -1,9 +1,9 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db
from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset
__all__ = [
"get_sqlite_engine",
"load_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store",

View File

@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine:
def load_sqlite_engine(workdir: str) -> Engine:
"""
获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回
@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎
engine = get_sqlite_engine(workdir)
engine = load_sqlite_engine(workdir)
# 初始化数据库
initialize_sqlite_db(engine)

View File

@ -2,9 +2,11 @@ import gradio as gr
import sys
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_docs, get_prompt_store
from schema import APIProvider
from global_var import get_docs, get_prompt_store, get_sql_engine
def dataset_generate_page():
with gr.Blocks() as demo:
@ -19,7 +21,7 @@ def dataset_generate_page():
label="选择文档",
interactive=True
)
doc_state = gr.State(value=initial_doc)
doc_choice = gr.State(value=initial_doc)
with gr.Column():
prompts = get_prompt_store().all()
@ -42,9 +44,26 @@ def dataset_generate_page():
label="选择模板",
interactive=True
)
prompt_state = gr.State(value=initial_prompt)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集")
with gr.Column():
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True
)
api_choice = gr.State(value=initial_api)
generate_button = gr.Button("生成数据集",variant="primary")
output_text = gr.Textbox(label="生成结果", interactive=False)
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
@ -54,11 +73,13 @@ def dataset_generate_page():
value=initial_dataframe_value # 设置初始化数据
)
output_text = gr.Textbox(label="生成结果", interactive=False)
def on_doc_change(selected_doc):
# print(f"文档选择已更改为: {selected_doc}")
return selected_doc
def on_api_change(selected_api):
return selected_api
def on_prompt_change(selected_prompt):
if not selected_prompt:
return None, []
@ -67,10 +88,12 @@ def dataset_generate_page():
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
dataframe_value = [[var, ""] for var in input_variables]
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, variables_dataframe):
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
variables_dict = {}
# 正确遍历DataFrame的行数据
for _, row in variables_dataframe.iterrows():
@ -79,13 +102,28 @@ def dataset_generate_page():
if var_name:
variables_dict[var_name] = var_value
import time
total_steps = 10
for i in range(total_steps):
# 模拟每个步骤的工作负载
time.sleep(0.5)
# 更新进度条
# 第一个参数是当前的进度比例 (0.0 到 1.0)
# desc 参数可以动态更新进度条旁边的描述文字
current_progress = (i + 1) / total_steps
progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
return "all done"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_state)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=prompt_state)
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_state, prompt_state, variables_dataframe],
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
outputs=output_text
)

View File

@ -1,4 +1,4 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
_prompt_store = None
@ -10,7 +10,7 @@ _workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_sql_engine = load_sqlite_engine(workdir)
_datasets = get_all_dataset(workdir)
_workdir = workdir