Compare commits
4 Commits
d764537143
...
5a386d6401
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5a386d6401 | ||
![]() |
feaea1fb64 | ||
![]() |
7242a2ce03 | ||
![]() |
db6e2271dc |
@ -1,9 +1,9 @@
|
|||||||
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
||||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||||
from .dataset_store import get_all_dataset
|
from .dataset_store import get_all_dataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_sqlite_engine",
|
"load_sqlite_engine",
|
||||||
"initialize_sqlite_db",
|
"initialize_sqlite_db",
|
||||||
"get_prompt_tinydb",
|
"get_prompt_tinydb",
|
||||||
"initialize_prompt_store",
|
"initialize_prompt_store",
|
||||||
|
@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
|
|||||||
# 全局变量,用于存储数据库引擎实例
|
# 全局变量,用于存储数据库引擎实例
|
||||||
_engine: Optional[Engine] = None
|
_engine: Optional[Engine] = None
|
||||||
|
|
||||||
def get_sqlite_engine(workdir: str) -> Engine:
|
def load_sqlite_engine(workdir: str) -> Engine:
|
||||||
"""
|
"""
|
||||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||||
|
|
||||||
@ -74,6 +74,6 @@ if __name__ == "__main__":
|
|||||||
# 定义工作目录路径
|
# 定义工作目录路径
|
||||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||||
# 获取数据库引擎
|
# 获取数据库引擎
|
||||||
engine = get_sqlite_engine(workdir)
|
engine = load_sqlite_engine(workdir)
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
initialize_sqlite_db(engine)
|
initialize_sqlite_db(engine)
|
@ -2,9 +2,11 @@ import gradio as gr
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import get_docs, get_prompt_store
|
from schema import APIProvider
|
||||||
|
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||||
|
|
||||||
def dataset_generate_page():
|
def dataset_generate_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
@ -19,7 +21,7 @@ def dataset_generate_page():
|
|||||||
label="选择文档",
|
label="选择文档",
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
doc_state = gr.State(value=initial_doc)
|
doc_choice = gr.State(value=initial_doc)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
prompts = get_prompt_store().all()
|
prompts = get_prompt_store().all()
|
||||||
@ -42,9 +44,26 @@ def dataset_generate_page():
|
|||||||
label="选择模板",
|
label="选择模板",
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
prompt_state = gr.State(value=initial_prompt)
|
prompt_choice = gr.State(value=initial_prompt)
|
||||||
|
|
||||||
generate_button = gr.Button("生成数据集")
|
with gr.Column():
|
||||||
|
# 从数据库获取API Provider列表
|
||||||
|
with Session(get_sql_engine()) as session:
|
||||||
|
providers = session.exec(select(APIProvider)).all()
|
||||||
|
api_list = [f"{p.id} {p.model_id}" for p in providers]
|
||||||
|
initial_api = api_list[0] if api_list else None
|
||||||
|
|
||||||
|
api_dropdown = gr.Dropdown(
|
||||||
|
choices=api_list,
|
||||||
|
value=initial_api,
|
||||||
|
label="选择API",
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
api_choice = gr.State(value=initial_api)
|
||||||
|
|
||||||
|
generate_button = gr.Button("生成数据集",variant="primary")
|
||||||
|
|
||||||
|
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||||
|
|
||||||
variables_dataframe = gr.Dataframe(
|
variables_dataframe = gr.Dataframe(
|
||||||
headers=["变量名", "变量值"],
|
headers=["变量名", "变量值"],
|
||||||
@ -54,11 +73,13 @@ def dataset_generate_page():
|
|||||||
value=initial_dataframe_value # 设置初始化数据
|
value=initial_dataframe_value # 设置初始化数据
|
||||||
)
|
)
|
||||||
|
|
||||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
|
||||||
|
|
||||||
def on_doc_change(selected_doc):
|
def on_doc_change(selected_doc):
|
||||||
# print(f"文档选择已更改为: {selected_doc}")
|
|
||||||
return selected_doc
|
return selected_doc
|
||||||
|
|
||||||
|
def on_api_change(selected_api):
|
||||||
|
return selected_api
|
||||||
|
|
||||||
def on_prompt_change(selected_prompt):
|
def on_prompt_change(selected_prompt):
|
||||||
if not selected_prompt:
|
if not selected_prompt:
|
||||||
return None, []
|
return None, []
|
||||||
@ -67,10 +88,12 @@ def dataset_generate_page():
|
|||||||
prompt_content = prompt_data["content"]
|
prompt_content = prompt_data["content"]
|
||||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
dataframe_value = [[var, ""] for var in input_variables]
|
input_variables.remove("document_slice")
|
||||||
|
dataframe_value = [] if input_variables is None else input_variables
|
||||||
return selected_prompt, dataframe_value
|
return selected_prompt, dataframe_value
|
||||||
|
|
||||||
def on_generate_click(doc_state, prompt_state, variables_dataframe):
|
|
||||||
|
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
|
||||||
variables_dict = {}
|
variables_dict = {}
|
||||||
# 正确遍历DataFrame的行数据
|
# 正确遍历DataFrame的行数据
|
||||||
for _, row in variables_dataframe.iterrows():
|
for _, row in variables_dataframe.iterrows():
|
||||||
@ -79,13 +102,28 @@ def dataset_generate_page():
|
|||||||
if var_name:
|
if var_name:
|
||||||
variables_dict[var_name] = var_value
|
variables_dict[var_name] = var_value
|
||||||
|
|
||||||
|
import time
|
||||||
|
total_steps = 10
|
||||||
|
for i in range(total_steps):
|
||||||
|
# 模拟每个步骤的工作负载
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
# 更新进度条
|
||||||
|
# 第一个参数是当前的进度比例 (0.0 到 1.0)
|
||||||
|
# desc 参数可以动态更新进度条旁边的描述文字
|
||||||
|
current_progress = (i + 1) / total_steps
|
||||||
|
progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
|
||||||
|
|
||||||
|
return "all done"
|
||||||
|
|
||||||
|
|
||||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_state)
|
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
||||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=prompt_state)
|
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
||||||
|
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
||||||
|
|
||||||
generate_button.click(
|
generate_button.click(
|
||||||
on_generate_click,
|
on_generate_click,
|
||||||
inputs=[doc_state, prompt_state, variables_dataframe],
|
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
|
||||||
outputs=output_text
|
outputs=output_text
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||||
from tools import scan_docs_directory
|
from tools import scan_docs_directory
|
||||||
|
|
||||||
_prompt_store = None
|
_prompt_store = None
|
||||||
@ -10,7 +10,7 @@ _workdir = None
|
|||||||
def init_global_var(workdir="workdir"):
|
def init_global_var(workdir="workdir"):
|
||||||
global _prompt_store, _sql_engine, _datasets, _workdir
|
global _prompt_store, _sql_engine, _datasets, _workdir
|
||||||
_prompt_store = get_prompt_tinydb(workdir)
|
_prompt_store = get_prompt_tinydb(workdir)
|
||||||
_sql_engine = get_sqlite_engine(workdir)
|
_sql_engine = load_sqlite_engine(workdir)
|
||||||
_datasets = get_all_dataset(workdir)
|
_datasets = get_all_dataset(workdir)
|
||||||
_workdir = workdir
|
_workdir = workdir
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user