Compare commits
4 Commits
d764537143
...
5a386d6401
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5a386d6401 | ||
![]() |
feaea1fb64 | ||
![]() |
7242a2ce03 | ||
![]() |
db6e2271dc |
@ -1,9 +1,9 @@
|
||||
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
||||
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||
from .dataset_store import get_all_dataset
|
||||
|
||||
__all__ = [
|
||||
"get_sqlite_engine",
|
||||
"load_sqlite_engine",
|
||||
"initialize_sqlite_db",
|
||||
"get_prompt_tinydb",
|
||||
"initialize_prompt_store",
|
||||
|
@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
|
||||
# 全局变量,用于存储数据库引擎实例
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
def get_sqlite_engine(workdir: str) -> Engine:
|
||||
def load_sqlite_engine(workdir: str) -> Engine:
|
||||
"""
|
||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||
|
||||
@ -74,6 +74,6 @@ if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库引擎
|
||||
engine = get_sqlite_engine(workdir)
|
||||
engine = load_sqlite_engine(workdir)
|
||||
# 初始化数据库
|
||||
initialize_sqlite_db(engine)
|
@ -2,9 +2,11 @@ import gradio as gr
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from langchain.prompts import PromptTemplate
|
||||
from sqlmodel import Session, select
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_docs, get_prompt_store
|
||||
from schema import APIProvider
|
||||
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||
|
||||
def dataset_generate_page():
|
||||
with gr.Blocks() as demo:
|
||||
@ -19,7 +21,7 @@ def dataset_generate_page():
|
||||
label="选择文档",
|
||||
interactive=True
|
||||
)
|
||||
doc_state = gr.State(value=initial_doc)
|
||||
doc_choice = gr.State(value=initial_doc)
|
||||
|
||||
with gr.Column():
|
||||
prompts = get_prompt_store().all()
|
||||
@ -42,9 +44,26 @@ def dataset_generate_page():
|
||||
label="选择模板",
|
||||
interactive=True
|
||||
)
|
||||
prompt_state = gr.State(value=initial_prompt)
|
||||
prompt_choice = gr.State(value=initial_prompt)
|
||||
|
||||
generate_button = gr.Button("生成数据集")
|
||||
with gr.Column():
|
||||
# 从数据库获取API Provider列表
|
||||
with Session(get_sql_engine()) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
api_list = [f"{p.id} {p.model_id}" for p in providers]
|
||||
initial_api = api_list[0] if api_list else None
|
||||
|
||||
api_dropdown = gr.Dropdown(
|
||||
choices=api_list,
|
||||
value=initial_api,
|
||||
label="选择API",
|
||||
interactive=True
|
||||
)
|
||||
api_choice = gr.State(value=initial_api)
|
||||
|
||||
generate_button = gr.Button("生成数据集",variant="primary")
|
||||
|
||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||
|
||||
variables_dataframe = gr.Dataframe(
|
||||
headers=["变量名", "变量值"],
|
||||
@ -54,11 +73,13 @@ def dataset_generate_page():
|
||||
value=initial_dataframe_value # 设置初始化数据
|
||||
)
|
||||
|
||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||
|
||||
def on_doc_change(selected_doc):
|
||||
# print(f"文档选择已更改为: {selected_doc}")
|
||||
return selected_doc
|
||||
|
||||
def on_api_change(selected_api):
|
||||
return selected_api
|
||||
|
||||
def on_prompt_change(selected_prompt):
|
||||
if not selected_prompt:
|
||||
return None, []
|
||||
@ -67,10 +88,12 @@ def dataset_generate_page():
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
dataframe_value = [[var, ""] for var in input_variables]
|
||||
input_variables.remove("document_slice")
|
||||
dataframe_value = [] if input_variables is None else input_variables
|
||||
return selected_prompt, dataframe_value
|
||||
|
||||
def on_generate_click(doc_state, prompt_state, variables_dataframe):
|
||||
|
||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
|
||||
variables_dict = {}
|
||||
# 正确遍历DataFrame的行数据
|
||||
for _, row in variables_dataframe.iterrows():
|
||||
@ -79,13 +102,28 @@ def dataset_generate_page():
|
||||
if var_name:
|
||||
variables_dict[var_name] = var_value
|
||||
|
||||
import time
|
||||
total_steps = 10
|
||||
for i in range(total_steps):
|
||||
# 模拟每个步骤的工作负载
|
||||
time.sleep(0.5)
|
||||
|
||||
# 更新进度条
|
||||
# 第一个参数是当前的进度比例 (0.0 到 1.0)
|
||||
# desc 参数可以动态更新进度条旁边的描述文字
|
||||
current_progress = (i + 1) / total_steps
|
||||
progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
|
||||
|
||||
return "all done"
|
||||
|
||||
|
||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_state)
|
||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=prompt_state)
|
||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
||||
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
||||
|
||||
generate_button.click(
|
||||
on_generate_click,
|
||||
inputs=[doc_state, prompt_state, variables_dataframe],
|
||||
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
|
||||
outputs=output_text
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||
from tools import scan_docs_directory
|
||||
|
||||
_prompt_store = None
|
||||
@ -10,7 +10,7 @@ _workdir = None
|
||||
def init_global_var(workdir="workdir"):
|
||||
global _prompt_store, _sql_engine, _datasets, _workdir
|
||||
_prompt_store = get_prompt_tinydb(workdir)
|
||||
_sql_engine = get_sqlite_engine(workdir)
|
||||
_sql_engine = load_sqlite_engine(workdir)
|
||||
_datasets = get_all_dataset(workdir)
|
||||
_workdir = workdir
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user