Compare commits

...

9 Commits

Author SHA1 Message Date
carry
0a4efa5641 feat(dataset): add dataset generation
- Add the dataset generation page and its supporting logic
- Check new dataset names for duplicates
- Create and save dataset objects
- Improve document processing and prompt-template application
- Add error handling and response parsing
2025-04-20 21:25:51 +08:00
carry
994d600221 refactor(frontend): increase the TensorBoard iframe height
- Change the TensorBoard iframe height from 500px to 1000px
- Gives the board a roomier display area and improves the user experience
2025-04-20 21:25:37 +08:00
carry
d5774eee0c feat(db): add dataset export
- Add a save_dataset function that saves datasets from TinyDB as individual JSON files
- Update db/__init__.py to export the new save_dataset function
- Implement save_dataset and its supporting logic in db/dataset_store.py
2025-04-20 19:44:11 +08:00
carry
87501c9353 fix(global_var): remove the set_datasets global setter
- Delete the set_datasets function from global_var.py
- It set the module-level _datasets variable but appears to be unused
2025-04-20 19:14:00 +08:00
carry
5fc3b4950b refactor(schema): rename the API response content field in LLMResponse
- Rename the response_content field of LLMResponse to content
- Change its type from dict to str to represent the response content more accurately
- Update the LLMResponse call sites in reasoning.py accordingly
2025-04-20 18:40:51 +08:00
carry
c28e4819d9 refactor(frontend/tools): rename the example-JSON generator
- Rename generate_example_json to generate_json_example
- Update calls and references in the affected files
- Makes the function name more descriptive and consistent
2025-04-20 16:11:36 +08:00
carry
e7cf51d662 refactor(frontend): rework the dataset generation page
- Adjust the page layout and streamline the interaction flow
- Add a dataset-name input box
- Use the LLMRequest and LLMResponse models for requests and responses
- Add the generate_example_json function for formatting generated data
- Improve the generation logic to support multiple rounds
2025-04-20 16:10:08 +08:00
carry
4c9caff668 refactor(schema): rename the dataset and document classes
- Capitalize the dataset, dataset_item, and doc class names to follow Python class-naming conventions
- Update imports and references in the affected modules
- No functional change; improves consistency and readability
2025-04-20 01:46:15 +08:00
carry
9236f49b36 feat(frontend): add document slicing and a concurrency setting
- Add a concurrency input box
- Slice documents before processing
- Update the dataset generation logic to support concurrent processing
2025-04-20 01:40:48 +08:00
12 changed files with 138 additions and 72 deletions

View File

@@ -1,11 +1,12 @@
from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset
from .dataset_store import get_all_dataset, save_dataset
__all__ = [
"load_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store",
"get_all_dataset"
"get_all_dataset",
"save_dataset"
]

View File

@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
# Add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
from schema.dataset import Dataset, DatasetItem, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
"""
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
return db
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
"""
Save the datasets stored in TinyDB as individual JSON files.
Args:
db (TinyDB): TinyDB instance holding the dataset records
workdir (str): path to the working directory
name (str, optional): name of the dataset to save; None saves all datasets
"""
dataset_dir = os.path.join(workdir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)
datasets = db.all() if name is None else db.search(Query().name == name)
for dataset in datasets:
try:
filename = f"{dataset['name']}.json"
filepath = os.path.join(dataset_dir, filename)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
if __name__ == "__main__":
# Define the working directory path
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -47,4 +71,11 @@ if __name__ == "__main__":
# Print the results
print(f"Found {len(datasets)} datasets:")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")
print(f"- {ds['name']} (ID: {ds['id']})")
# Ask which dataset to save
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
# Save the datasets to files
save_dataset(datasets, workdir, name)
print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -1,13 +1,18 @@
import gradio as gr
import sys
import json
import asyncio
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider
from tools import call_openai_api
from global_var import get_docs, get_prompt_store, get_sql_engine
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page():
with gr.Blocks() as demo:
@@ -16,13 +21,6 @@ def dataset_generate_page():
with gr.Column(scale=1):
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc,
label="选择文档",
interactive=True
)
doc_choice = gr.State(value=initial_doc)
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
@@ -37,14 +35,6 @@ def dataset_generate_page():
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
# Fetch the list of API providers from the database
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
@@ -57,8 +47,18 @@ def dataset_generate_page():
label="选择API",
interactive=True
)
api_choice = gr.State(value=initial_api)
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc,
label="选择文档",
interactive=True
)
prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
rounds_input = gr.Number(
value=1,
label="生成轮次",
@@ -67,10 +67,25 @@ def dataset_generate_page():
step=1,
interactive=True
)
concurrency_input = gr.Number(
value=1,
label="并发数",
minimum=1,
maximum=10,
step=1,
interactive=True,
visible=False
)
dataset_name_input = gr.Textbox(
label="数据集名称",
placeholder="输入数据集保存名称",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集",variant="primary")
doc_choice = gr.State(value=initial_doc)
output_text = gr.Textbox(label="生成结果", interactive=False)
api_choice = gr.State(value=initial_api)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
@@ -79,8 +94,6 @@ def dataset_generate_page():
label="变量列表",
value=initial_dataframe_value # initial table contents
)
def on_doc_change(selected_doc):
return selected_doc
@@ -100,8 +113,14 @@ def dataset_generate_page():
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
dataset_db = get_datasets()
if dataset_db.search(Query().name == dataset_name): # duplicate-name check
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
@@ -114,23 +133,42 @@ def dataset_generate_page():
if var_name:
variables_dict[var_name] = var_value
# Inject every parameter except document_slice
prompt = prompt.partial(**variables_dict)
print(doc)
print(prompt.format(document_slice="test"))
print(variables_dict)
import time
total_steps = rounds
for i in range(total_steps):
# Simulate each step's workload
time.sleep(0.5)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
for document_slice in document_slice_list:
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
asyncio.run(call_openai_api(request, rounds)) # call_openai_api is async and must be run to completion
current_progress = (i + 1) / total_steps
progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
return "all done"
# Persist the dataset to TinyDB
dataset_db.insert(dataset.model_dump(mode="json")) # mode="json" keeps created_at JSON-serializable
save_dataset(dataset_db, get_workdir(), dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
@@ -138,7 +176,7 @@ def dataset_generate_page():
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
outputs=output_text
)
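The concurrency input is passed into the click handler but not yet consumed there. One plausible wiring is sketched below with an asyncio.Semaphore; generate_bounded, run_one, and the requests list are hypothetical names for illustration, not part of this diff:

import asyncio
from tools import call_openai_api

async def generate_bounded(requests, rounds, concurrency):
    # At most `concurrency` call_openai_api coroutines run at once.
    sem = asyncio.Semaphore(concurrency)

    async def run_one(req):
        async with sem:
            return await call_openai_api(req, rounds)

    return await asyncio.gather(*(run_one(r) for r in requests))

# From the synchronous Gradio handler this could be driven with:
# results = asyncio.run(generate_bounded(requests, rounds, concurrency))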

View File

@@ -74,7 +74,7 @@ def train_page():
# Dynamically generate the TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try:

View File

@@ -37,10 +37,6 @@ def get_docs():
def get_datasets():
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
def get_model():
return _model

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime, timezone
class doc(BaseModel):
class Doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel):
class DatasetItem(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel):
class Dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")
dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
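A quick sketch of the renamed models in use (field values here are illustrative):

from schema.dataset import Dataset, DatasetItem, Q_A

item = DatasetItem(id=1, message=[Q_A(question="What is X?", answer="X is ...")])
ds = Dataset(name="demo", model_id=["deepseek-chat"], dataset_items=[item])
print(ds.model_dump_json(indent=2))  # created_at is filled in automatically (UTC)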

View File

@@ -33,7 +33,7 @@ class LLMResponse(SQLModel):
)
response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
content: str = Field(default="", description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
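Since content is now a plain string, callers that expect structured output must parse it themselves; a minimal sketch mirroring what dataset_generate_page does with resp.content (the raw value below is illustrative):

import json

raw = '{"question": "What is X?", "answer": "X is ..."}'  # e.g. resp.content
try:
    payload = json.loads(raw)
except json.JSONDecodeError:
    payload = None  # non-JSON responses are now the caller's responsibility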

View File

@@ -1,5 +1,5 @@
from .parse_markdown import *
from .document import *
from .json_example import generate_example_json
from .json_example import generate_json_example
from .port import *
from .reasoning import call_openai_api

View File

@@ -1,19 +1,19 @@
from typing import List
from schema.dataset import dataset, dataset_item, Q_A
from schema.dataset import Dataset, DatasetItem, Q_A
import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset:
def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
# Convert the JSON data into Dataset form
dataset_items = []
item_id = 1 # auto-incrementing ID counter
for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = dataset_item(id=item_id, message=[qa])
dataset_item_obj = DatasetItem(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj)
item_id += 1 # increment the ID
# Build the Dataset object
result_dataset = dataset(
result_dataset = Dataset(
name="Converted Dataset",
model_id=None,
description="Dataset converted from JSON",

View File

@@ -4,7 +4,7 @@ from pathlib import Path
# Add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc
from schema import Doc
def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
for file in files:
if file.endswith(".md"):
markdown_files.append(os.path.join(root, file))
to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json
from datetime import datetime, date
def generate_example_json(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
"""
Generate an example JSON data structure from a Pydantic V2 model
"""
@@ -37,14 +37,14 @@ def generate_example_json(model: type[BaseModel], include_optional: bool = False
elif field_type is date:
return date.today().isoformat()
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return json.loads(generate_example_json(field_type, include_optional))
return json.loads(generate_json_example(field_type, include_optional))
else:
# Handle direct (non-generic) type annotations
if field_type is type(None):
return None
try:
if issubclass(field_type, BaseModel):
return json.loads(generate_example_json(field_type, include_optional))
return json.loads(generate_json_example(field_type, include_optional))
except TypeError:
pass
return "unknown"
@@ -61,7 +61,7 @@ if __name__ == "__main__":
from pathlib import Path
# Add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import dataset
from schema import Dataset
print("示例 JSON:")
print(generate_example_json(dataset))
print(generate_json_example(Dataset))
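For orientation, generate_json_example(DatasetItem) returns a JSON string whose parsed form is roughly the dict below (list_length defaults to 2; optional fields such as id appear only with include_optional=True; the placeholder values are assumptions, not the generator's literal output):

example = {
    "message": [
        {"question": "string", "answer": "string"},
        {"question": "string", "answer": "string"},
    ]
}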

View File

@@ -68,7 +68,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
llm_request.response.append(LLMResponse(
response_id=response.id,
tokens_usage=tokens_usage,
response_content={"content": response.choices[0].message.content},
content=response.choices[0].message.content,
total_duration=duration,
llm_parameters=llm_parameters
))
@@ -79,7 +79,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
llm_request.response.append(LLMResponse(
response_id=f"error-round-{i+1}",
response_content={"error": str(e)},
content={"error": str(e)},
total_duration=duration
))
if llm_request.error is None:
@@ -93,10 +93,10 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
return llm_request
if __name__ == "__main__":
from json_example import generate_example_json
from json_example import generate_json_example
from sqlmodel import Session, select
from global_var import get_sql_engine, init_global_var
from schema import dataset_item
from schema import DatasetItem
init_global_var("workdir")
api_state = "1 deepseek-chat"
@@ -105,7 +105,7 @@ if __name__ == "__main__":
llm_request = LLMRequest(
prompt="测试,随便说点什么",
api_provider=api_provider,
format=generate_example_json(dataset_item)
format=generate_json_example(DatasetItem)
)
# # Single-call example
@@ -120,4 +120,4 @@ if __name__ == "__main__":
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
for i, resp in enumerate(result.response, 1):
print(f"响应{i}: {resp.response_content}")
print(f"响应{i}: {resp.content}")