Compare commits


No commits in common. "0a4efa5641008c7f5b332b17f622e6e162178af3" and "868fcd45ba00676178cb0859abc74fccbf899e93" have entirely different histories.

12 changed files with 72 additions and 138 deletions

View File

@@ -1,12 +1,11 @@
 from .init_db import load_sqlite_engine, initialize_sqlite_db
 from .prompt_store import get_prompt_tinydb, initialize_prompt_store
-from .dataset_store import get_all_dataset, save_dataset
+from .dataset_store import get_all_dataset

 __all__ = [
     "load_sqlite_engine",
     "initialize_sqlite_db",
     "get_prompt_tinydb",
     "initialize_prompt_store",
-    "get_all_dataset",
-    "save_dataset"
+    "get_all_dataset"
 ]

View File

@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
 # Add the project root to the system path so other project modules can be imported
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema.dataset import Dataset, DatasetItem, Q_A
+from schema.dataset import dataset, dataset_item, Q_A

 def get_all_dataset(workdir: str) -> TinyDB:
     """
@@ -39,30 +39,6 @@ def get_all_dataset(workdir: str) -> TinyDB:
     return db

-def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
-    """
-    Save the datasets held in TinyDB as individual json files
-
-    Args:
-        db (TinyDB): TinyDB instance containing the dataset objects
-        workdir (str): working directory path
-        name (str, optional): name of the dataset to save; None saves all of them
-    """
-    dataset_dir = os.path.join(workdir, "dataset")
-    os.makedirs(dataset_dir, exist_ok=True)
-
-    datasets = db.all() if name is None else db.search(Query().name == name)
-    for dataset in datasets:
-        try:
-            filename = f"{dataset.get(dataset['name'])}.json"
-            filepath = os.path.join(dataset_dir, filename)
-            with open(filepath, "w", encoding="utf-8") as f:
-                json.dump(dataset, f, ensure_ascii=False, indent=2)
-        except Exception as e:
-            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
-
 if __name__ == "__main__":
     # Define the working directory path
     workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -72,10 +48,3 @@ if __name__ == "__main__":
     print(f"Found {len(datasets)} datasets:")
     for ds in datasets.all():
         print(f"- {ds['name']} (ID: {ds['id']})")
-
-    # Ask for the name of the dataset to save
-    name = input("Enter the name of the dataset to save (press Enter to save all): ").strip() or None
-
-    # Save the datasets to files
-    save_dataset(datasets, workdir, name)
-    print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -1,18 +1,13 @@
 import gradio as gr
 import sys
-import json
-from tinydb import Query
 from pathlib import Path
 from langchain.prompts import PromptTemplate
 from sqlmodel import Session, select
-from schema import Dataset, DatasetItem, Q_A
-from db.dataset_store import get_all_dataset

 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
-from db import save_dataset
-from tools import call_openai_api, process_markdown_file, generate_json_example
-from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
+from schema import APIProvider
+from tools import call_openai_api
+from global_var import get_docs, get_prompt_store, get_sql_engine

 def dataset_generate_page():
     with gr.Blocks() as demo:
@@ -21,6 +16,13 @@ def dataset_generate_page():
             with gr.Column(scale=1):
                 docs_list = [str(doc.name) for doc in get_docs()]
                 initial_doc = docs_list[0] if docs_list else None
+                doc_dropdown = gr.Dropdown(
+                    choices=docs_list,
+                    value=initial_doc,
+                    label="Select document",
+                    interactive=True
+                )
+                doc_choice = gr.State(value=initial_doc)
                 prompts = get_prompt_store().all()
                 prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                 initial_prompt = prompt_list[0] if prompt_list else None
@@ -35,6 +37,14 @@ def dataset_generate_page():
                 input_variables = prompt_template.input_variables
                 input_variables.remove("document_slice")
                 initial_dataframe_value = [[var, ""] for var in input_variables]
+                prompt_dropdown = gr.Dropdown(
+                    choices=prompt_list,
+                    value=initial_prompt,
+                    label="Select prompt template",
+                    interactive=True
+                )
+                prompt_choice = gr.State(value=initial_prompt)
+
                 # Fetch the list of API providers from the database
                 with Session(get_sql_engine()) as session:
                     providers = session.exec(select(APIProvider)).all()
@@ -47,18 +57,8 @@ def dataset_generate_page():
                     label="Select API",
                     interactive=True
                 )
-                doc_dropdown = gr.Dropdown(
-                    choices=docs_list,
-                    value=initial_doc,
-                    label="Select document",
-                    interactive=True
-                )
-                prompt_dropdown = gr.Dropdown(
-                    choices=prompt_list,
-                    value=initial_prompt,
-                    label="Select prompt template",
-                    interactive=True
-                )
+                api_choice = gr.State(value=initial_api)
                 rounds_input = gr.Number(
                     value=1,
                     label="Generation rounds",
@@ -67,25 +67,10 @@ def dataset_generate_page():
                     step=1,
                     interactive=True
                 )
-                concurrency_input = gr.Number(
-                    value=1,
-                    label="Concurrency",
-                    minimum=1,
-                    maximum=10,
-                    step=1,
-                    interactive=True,
-                    visible=False
-                )
-                dataset_name_input = gr.Textbox(
-                    label="Dataset name",
-                    placeholder="Enter a name for the saved dataset",
-                    interactive=True
-                )
-                prompt_choice = gr.State(value=initial_prompt)
                 generate_button = gr.Button("Generate dataset", variant="primary")
-                doc_choice = gr.State(value=initial_doc)
                 output_text = gr.Textbox(label="Generation result", interactive=False)
-                api_choice = gr.State(value=initial_api)
             with gr.Column(scale=2):
                 variables_dataframe = gr.Dataframe(
                     headers=["Variable name", "Variable value"],
@@ -94,6 +79,8 @@ def dataset_generate_page():
                     label="Variable list",
                    value=initial_dataframe_value  # set the initial data
                 )
+
+
         def on_doc_change(selected_doc):
             return selected_doc
@@ -113,14 +100,8 @@ def dataset_generate_page():
                 dataframe_value = [[var, ""] for var in input_variables]
             return selected_prompt, dataframe_value

-        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
-            dataset_db = get_datasets()
-            if not dataset_db.search(Query().name == dataset_name):
-                raise gr.Error("A dataset with this name already exists")
-            doc = [i for i in get_docs() if i.name == doc_state][0]
-            doc_files = doc.markdown_files
-            document_slice_list = [process_markdown_file(doc) for doc in doc_files]
+        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
+            doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
             prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
             prompt = PromptTemplate.from_template(prompt["content"])
             with Session(get_sql_engine()) as session:
@@ -133,42 +114,23 @@ def dataset_generate_page():
                 if var_name:
                     variables_dict[var_name] = var_value

-            # Inject every parameter except document_slice
             prompt = prompt.partial(**variables_dict)
-            dataset = Dataset(
-                name=dataset_name,
-                model_id=[api_provider.model_id],
-                source_doc=doc,
-                dataset_items=[]
-            )
-            for document_slice in document_slice_list:
-                request = LLMRequest(api_provider=api_provider,
-                                     prompt=prompt.format(document_slice=document_slice),
-                                     format=generate_json_example(DatasetItem))
-                call_openai_api(request, rounds)
-                for resp in request.response:
-                    try:
-                        content = json.loads(resp.content)
-                        dataset_item = DatasetItem(
-                            message=[Q_A(
-                                question=content.get("question", ""),
-                                answer=content.get("answer", "")
-                            )]
-                        )
-                        dataset.dataset_items.append(dataset_item)
-                    except json.JSONDecodeError as e:
-                        print(f"Failed to parse response: {e}")
-            # Save the dataset to TinyDB
-            dataset_db.insert(dataset.model_dump())
-            save_dataset(dataset_db, get_workdir(), dataset_name)
-            return f"Dataset {dataset_name} generated, {len(dataset.dataset_items)} items in total"
+            print(doc)
+            print(prompt.format(document_slice="test"))
+            print(variables_dict)
+
+            import time
+            total_steps = rounds
+            for i in range(total_steps):
+                # Simulate each step's workload
+                time.sleep(0.5)
+                current_progress = (i + 1) / total_steps
+                progress(current_progress, desc=f"Step {i + 1}/{total_steps}")
+
+            return "all done"

         doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
         prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
@@ -176,7 +138,7 @@ def dataset_generate_page():
         generate_button.click(
             on_generate_click,
-            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
+            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
             outputs=output_text
         )
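The replacement handler only simulates work while reporting progress. For reference, a minimal standalone sketch of the same gr.Progress pattern, assuming a current Gradio install; the component names below are illustrative and not taken from the page above.

# Standalone sketch of the progress-reporting pattern used by the new handler.
import time
import gradio as gr

def run(rounds, progress=gr.Progress()):
    total_steps = int(rounds)
    for i in range(total_steps):
        time.sleep(0.5)                                    # stand-in for real work
        progress((i + 1) / total_steps, desc=f"step {i + 1}/{total_steps}")
    return "all done"

with gr.Blocks() as demo:
    rounds = gr.Number(value=3, label="rounds")
    out = gr.Textbox(label="result")
    gr.Button("run").click(run, inputs=rounds, outputs=out)

# demo.launch()  # uncomment to try it locally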

View File

@@ -74,7 +74,7 @@ def train_page():
         # Dynamically generate the TensorBoard iframe
         tensorboard_url = f"http://localhost:{tensorboard_port}"
-        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
+        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
         yield "Training started...", tensorboard_iframe_value  # yield two values, one for the textbox and one for the html
         try:

View File

@@ -37,6 +37,10 @@ def get_docs():

 def get_datasets():
     return _datasets

+def set_datasets(new_datasets):
+    global _datasets
+    _datasets = new_datasets
+
 def get_model():
     return _model
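set_datasets follows the module-level accessor pattern global_var already uses for its other state; a tiny self-contained sketch of that pattern (module and variable names here are illustrative):

# Minimal sketch of the module-global getter/setter pattern.
_datasets = None

def get_datasets():
    return _datasets

def set_datasets(new_datasets):
    global _datasets          # rebind the module-level name, not a local
    _datasets = new_datasets

set_datasets(["demo"])
print(get_datasets())         # -> ['demo']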

View File

@@ -2,7 +2,7 @@ from typing import Optional
 from pydantic import BaseModel, Field
 from datetime import datetime, timezone

-class Doc(BaseModel):
+class doc(BaseModel):
     id: Optional[int] = Field(default=None, description="Document ID")
     name: str = Field(default="", description="Document name")
     path: str = Field(default="", description="Document path")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
     question: str = Field(default="", min_length=1, description="Question")
     answer: str = Field(default="", min_length=1, description="Answer")

-class DatasetItem(BaseModel):
+class dataset_item(BaseModel):
     id: Optional[int] = Field(default=None, description="Dataset item ID")
     message: list[Q_A] = Field(description="Dataset item content")

-class Dataset(BaseModel):
+class dataset(BaseModel):
     id: Optional[int] = Field(default=None, description="Dataset ID")
     name: str = Field(default="", description="Dataset name")
     model_id: Optional[list[str]] = Field(default=None, description="Model IDs used by the dataset")
-    source_doc: Optional[Doc] = Field(default=None, description="Source document of the dataset")
+    source_doc: Optional[doc] = Field(default=None, description="Source document of the dataset")
     description: Optional[str] = Field(default="", description="Dataset description")
     created_at: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="Record creation time"
     )
-    dataset_items: list[DatasetItem] = Field(default_factory=list, description="List of dataset items")
+    dataset_items: list[dataset_item] = Field(default_factory=list, description="List of dataset items")
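A small usage sketch of the renamed models, assuming the repo's schema package is importable; the field values are illustrative:

# Illustrative construction of the renamed (lower-case) Pydantic models.
from schema.dataset import dataset, dataset_item, Q_A

ds = dataset(
    name="demo",
    dataset_items=[
        dataset_item(id=1, message=[Q_A(question="What is X?", answer="X is ...")]),
    ],
)
print(ds.model_dump_json(indent=2))   # Pydantic V2 serialization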

View File

@@ -33,7 +33,7 @@ class LLMResponse(SQLModel):
     )
     response_id: str = Field(..., description="Unique ID of the response")
     tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="Token usage information")
-    content: str = Field(default_factory=dict, description="Content of the API response")
+    response_content: dict = Field(default_factory=dict, description="Content of the API response")
     total_duration: float = Field(default=0.0, description="Total request duration in seconds")
     llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM parameters")
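Downstream code now reads a dict instead of a string (see the reasoning.py hunks below); a minimal caller-side sketch, with FakeResponse and extract_text as illustrative stand-ins rather than code from this repo:

# Sketch: how a caller might read the new response_content field.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeResponse:                              # stand-in for LLMResponse
    response_content: dict = field(default_factory=dict)

def extract_text(resp: FakeResponse) -> Optional[str]:
    payload = resp.response_content or {}
    if "error" in payload:                       # error rounds store {"error": "..."}
        return None
    return payload.get("content")                # success rounds store {"content": "..."}

print(extract_text(FakeResponse({"content": "hello"})))   # -> hello
print(extract_text(FakeResponse({"error": "timeout"})))   # -> None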

View File

@@ -1,5 +1,5 @@
 from .parse_markdown import *
 from .document import *
-from .json_example import generate_json_example
+from .json_example import generate_example_json
 from .port import *
 from .reasoning import call_openai_api

View File

@@ -1,19 +1,19 @@
 from typing import List
-from schema.dataset import Dataset, DatasetItem, Q_A
+from schema.dataset import dataset, dataset_item, Q_A
 import json

-def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
+def convert_json_to_dataset(json_data: List[dict]) -> dataset:
     # Convert the JSON data into the dataset format
     dataset_items = []
     item_id = 1  # auto-incrementing ID counter
     for item in json_data:
         qa = Q_A(question=item["question"], answer=item["answer"])
-        dataset_item_obj = DatasetItem(id=item_id, message=[qa])
+        dataset_item_obj = dataset_item(id=item_id, message=[qa])
         dataset_items.append(dataset_item_obj)
         item_id += 1  # increment the ID

     # Create the dataset object
-    result_dataset = Dataset(
+    result_dataset = dataset(
         name="Converted Dataset",
         model_id=None,
         description="Dataset converted from JSON",

View File

@@ -4,7 +4,7 @@ from pathlib import Path
 # Add the project root directory to sys.path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import Doc
+from schema import doc

 def scan_docs_directory(workdir: str):
     docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
             for file in files:
                 if file.endswith(".md"):
                     markdown_files.append(os.path.join(root, file))
-        to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
+        to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
     return to_return

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 import json
 from datetime import datetime, date

-def generate_json_example(model: type[BaseModel], include_optional: bool = False, list_length=2) -> str:
+def generate_example_json(model: type[BaseModel], include_optional: bool = False, list_length=2) -> str:
     """
     Generate an example JSON data structure from a Pydantic V2 model
     """
@@ -37,14 +37,14 @@ def generate_json_example(model: type[BaseModel], include_optional: bool = False
         elif field_type is date:
             return date.today().isoformat()
         elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
-            return json.loads(generate_json_example(field_type, include_optional))
+            return json.loads(generate_example_json(field_type, include_optional))
         else:
             # Handle direct type annotations (non-generic)
             if field_type is type(None):
                 return None
             try:
                 if issubclass(field_type, BaseModel):
-                    return json.loads(generate_json_example(field_type, include_optional))
+                    return json.loads(generate_example_json(field_type, include_optional))
             except TypeError:
                 pass
             return "unknown"
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     from pathlib import Path
     # Add the project root directory to sys.path
     sys.path.append(str(Path(__file__).resolve().parent.parent))
-    from schema import Dataset
+    from schema import dataset

     print("Example JSON:")
-    print(generate_json_example(Dataset))
+    print(generate_example_json(dataset))
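For reference, a toy call of the renamed helper; the QA model below is illustrative, while the import path follows the tools package shown in this diff:

# Toy usage of the renamed generate_example_json helper.
from pydantic import BaseModel, Field
from tools.json_example import generate_example_json

class QA(BaseModel):                      # illustrative model, not from the repo
    question: str = Field(default="", description="question")
    answer: str = Field(default="", description="answer")

print(generate_example_json(QA))          # prints an example JSON string for QA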

View File

@@ -68,7 +68,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
             llm_request.response.append(LLMResponse(
                 response_id=response.id,
                 tokens_usage=tokens_usage,
-                content=response.choices[0].message.content,
+                response_content={"content": response.choices[0].message.content},
                 total_duration=duration,
                 llm_parameters=llm_parameters
             ))
@@ -79,7 +79,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
             llm_request.response.append(LLMResponse(
                 response_id=f"error-round-{i+1}",
-                content={"error": str(e)},
+                response_content={"error": str(e)},
                 total_duration=duration
             ))
         if llm_request.error is None:
@@ -93,10 +93,10 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
     return llm_request

 if __name__ == "__main__":
-    from json_example import generate_json_example
+    from json_example import generate_example_json
     from sqlmodel import Session, select
     from global_var import get_sql_engine, init_global_var
-    from schema import DatasetItem
+    from schema import dataset_item

     init_global_var("workdir")
     api_state = "1 deepseek-chat"
@@ -105,7 +105,7 @@ if __name__ == "__main__":
     llm_request = LLMRequest(
         prompt="Test, just say something",
         api_provider=api_provider,
-        format=generate_json_example(DatasetItem)
+        format=generate_example_json(dataset_item)
     )

     # # single-call example
@@ -120,4 +120,4 @@ if __name__ == "__main__":
     print(f"\nResults of 3 calls - total duration: {result.total_duration:.2f}s")
     print(f"Total token usage: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
     for i, resp in enumerate(result.response, 1):
-        print(f"Response {i}: {resp.content}")
+        print(f"Response {i}: {resp.response_content}")