Compare commits

34 Commits

Author SHA1 Message Date
carry
7a4388c928 featmodel): 添加保存模式选择功能
在模型管理页面中新增保存模式选择功能,用户可以通过下拉菜单选择不同的保存模式(如默认、合并16位、合并4位等)。同时,将保存模型的逻辑抽离到独立的`save_model.py`文件中,以提高代码的可维护性和复用性。
2025-04-23 14:09:02 +08:00
carry
6338706967 feat: 修改应用启动方式
- 将 app.launch() 修改为 app.launch(server_name="0.0.0.0")
- 此修改使应用能够监听所有网络接口,提高可用性
2025-04-22 19:08:23 +08:00
carry
3718c75cee feat(frontend): 添加数据集生成页面的处理进度显示
- 在处理文档片段时添加进度条,提升用户体验
- 优化代码格式,调整缩进和空行
2025-04-22 00:14:16 +08:00
carry
905658073a docs(README): 更新项目文档
- 添加项目概述、核心功能、技术架构等详细信息
- 插入系统架构图和技术栈说明
- 细化功能模块描述,包括模型管理、推理、微调等
- 增加QLoRA原理和参数配置说明
- 补充快速开始指南和许可证信息
- 优化文档结构,增强可读性和完整性
2025-04-21 14:28:20 +08:00
carry
9806334517 fix(train_page): 捕获训练过程中的异常并终止 TensorBoard 进程
- 在训练过程中添加异常捕获,将异常信息转换为 gr.Error 抛出
- 确保在发生异常时也能终止 TensorBoard 子进程
2025-04-20 21:40:46 +08:00
carry
0a4efa5641 feat(dataset): 添加数据集生成功能
- 新增数据集生成页面和相关逻辑
- 实现数据集名称重复性检查
- 添加数据集对象创建和保存功能
- 优化文档处理和提示模板应用
- 增加错误处理和数据解析
2025-04-20 21:25:51 +08:00
carry
994d600221 refactor(frontend): 调整 TensorBoard iframe 高度
- 将 TensorBoard iframe 的高度从 500px 修改为 1000px
- 此修改旨在提供更宽敞的显示区域,改善用户体验
2025-04-20 21:25:37 +08:00
carry
d5774eee0c feat(db): 添加数据集导出功能
- 新增 save_dataset 函数,用于将 TinyDB 中的数据集保存为单独的 JSON 文件
- 更新 db/__init__.py,添加 get_dataset_tinydb 函数的引用
- 修改 db/dataset_store.py,实现 save_dataset 函数并添加相关逻辑
2025-04-20 19:44:11 +08:00
carry
87501c9353 fix(global_var): 移除全局变量设置函数set_datasets
- 删除了 global_var.py 文件中的 set_datasets 函数
- 该函数用于设置全局变量 _datasets,但似乎已不再使用
2025-04-20 19:14:00 +08:00
carry
5fc3b4950b refactor(schema): 修改 LLMResponse 中 API 响应内容的字段名称
- 将 LLMResponse 类中的 response_content 字段重命名为 content
- 更新字段类型从 dict 改为 str,以更准确地表示响应内容
- 在 reasoning.py 中相应地修改了调用 LLMResponse 时的参数
2025-04-20 18:40:51 +08:00
carry
c28e4819d9 refactor(frontend/tools): 重命名生成示例 JSON 数据结构的函数
- 将 generate_example_json 函数重命名为 generate_json_example
- 更新相关文件中的函数调用和引用
- 此更改旨在使函数名称更具描述性和一致性
2025-04-20 16:11:36 +08:00
carry
e7cf51d662 refactor(frontend): 重构数据集生成页面
- 调整页面布局,优化用户交互流程
- 新增数据集名称输入框
- 使用 LLMRequest 和 LLMResponse 模型处理请求和响应
- 添加 generate_example_json 函数用于格式化生成数据
- 改进数据集生成逻辑,支持多轮次生成
2025-04-20 16:10:08 +08:00
carry
4c9caff668 refactor(schema): 重构数据集和文档类的命名
- 将 dataset、dataset_item 和 doc 类的首字母大写,以符合 Python 类命名惯例
- 更新相关模块中的导入和引用,以适应新的类名
- 此更改不影响功能,仅提高了代码的一致性和可读性
2025-04-20 01:46:15 +08:00
carry
9236f49b36 feat(frontend): 添加文档切片和并发数功能
- 新增并发数输入框
- 实现文档切片处理
- 更新生成数据集的逻辑,支持并发处理
2025-04-20 01:40:48 +08:00
carry
868fcd45ba refactor(project): 重构项目文件组织结构
- 修改模型管理和训练页面的导入路径
- 更新 main.py 中的导入模块
- 调整 tools 包的内容,移除 model 模块
- 新建 train 包,包含 model 模块
- 优化 __init__.py 文件,简化导入语句
2025-04-19 21:49:19 +08:00
carry
5a21c8598a feat(tools): 支持 OpenAI API 的 JSON 格式返回结果
- 在 call_openai_api 函数中添加对 JSON 格式返回结果的支持
- 增加 llm_request.format 参数处理,将用户 prompt 与格式要求合并
- 添加 response_format 参数到 OpenAI API 请求
- 更新示例,使用 JSON 格式返回结果
2025-04-19 21:10:22 +08:00
carry
1e829c9268 feat(tools): 优化 JSON 示例生成函数
- 增加 include_optional 参数,决定是否包含可选字段
- 添加 list_length 参数,用于控制列表字段的示例长度
- 在列表示例中添加省略标记,更直观展示多元素列表
- 优化字典字段的示例生成逻辑
2025-04-19 21:07:00 +08:00
carry
9fc3ab904b feat(frontend): 实现了固定参数的注入 2025-04-19 17:48:45 +08:00
carry
d827f9758f fix(frontend): 修复dataframe_value返回值只有一列的bug 2025-04-19 17:30:10 +08:00
carry
ff1e9731bc fix(tools): 修复call_openai_api的导出 2025-04-19 17:13:19 +08:00
carry
90fde639ff feat(tools): 增加 OpenAI API 多轮调用功能
- 在 call_openai_api 函数中添加 rounds 参数,支持多次调用
- 累加每次调用的耗时和 token 使用情况
- 将多次调用的结果存储在 LLMRequest 对象的 response 列表中
- 更新函数返回类型,返回包含多次调用信息的 LLMRequest 对象
- 优化错误处理,记录每轮调用的错误信息
2025-04-19 17:02:00 +08:00
carry
5fc90903fb feat(tools): 添加 reasoning.py 工具模块
- 新增 reasoning.py 文件,实现与 OpenAI API 的交互
- 添加 call_openai_api 函数,用于发送请求并处理响应
- 支持可选的 LLMParameters 参数,以定制化请求
- 处理 API 响应中的 tokens 使用情况
- 提供错误处理和缓存 token 字段的处理
2025-04-19 16:53:48 +08:00
carry
81c2ad4a2d refactor(schema): 重构数据模型以提高可维护性和可扩展性
- 新增 LLMParameters 类以统一处理 LLM 参数
- 新增 TokensUsage 类以统一处理 token 使用信息
- 更新 LLMResponse 和 LLMRequest 类,使用新的 LLMParameters 和 TokensUsage 类
- 优化数据模型结构,提高代码的可读性和可维护性
2025-04-19 16:39:18 +08:00
carry
314434951d feat(frontend): 实现了文档、提示和 API 提供商的获取逻辑 2025-04-19 14:47:01 +08:00
carry
e16882953d fix(tools): 修复了optional字段无法被解析的问题 2025-04-18 22:00:51 +08:00
carry
86bcf90c66 feat(frontend): 添加数据集生成轮次控制功能
- 在数据集生成页面添加"生成轮次"输入框,支持设置生成轮数
- 更新生成逻辑,根据设置的轮次进行多次生成
2025-04-18 15:47:37 +08:00
carry
961a017f19 refactor(frontend): 调整数据集生成页面布局并优化代码结构
- 使用 gr.Column(scale=1) 和 gr.Column(scale=2) 调整列宽比例
- 移除多余的空行和缩进,提高代码可读性
- 优化变量声明和组件创建的顺序,使页面结构更清晰
2025-04-18 15:40:15 +08:00
carry
5a386d6401 feat(dataset_generate_page): 添加 API 选择功能
- 在数据集生成页面添加 API 选择下拉框
- 实现 API 选择变更时的处理逻辑
- 更新数据集生成函数,增加 API 选择参数
- 优化页面布局和代码结构
2025-04-18 15:23:33 +08:00
carry
feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00
carry
7242a2ce03 feat(frontend): 添加生成数据集进度条功能并优化了界面布局 2025-04-18 15:07:46 +08:00
carry
db6e2271dc fix(frontend): 修复 prompt_dropdown 变化时,dataframe没有相应的变化
- 将 prompt_dropdown 变化时的输出从 prompt_state 修改为 [prompt_state, variables_dataframe]
- 这个改动可能会在 prompt 变化时同时更新变量数据框
2025-04-18 14:03:26 +08:00
carry
d764537143 feat(dataset_generate_page): 更新数据集生成页面功能
- 添加模板变量列表展示和编辑功能
- 实现模板选择后动态更新变量列表
- 增加生成数据集按钮和相关逻辑
- 优化页面布局和交互
2025-04-16 12:39:48 +08:00
carry
8c35a38c47 feat(frontend): 更新模板选择功能
- 在模板选择变更时,获取所选模板的详细信息
- 创建 PromptTemplate 对象并获取输入变量列表
- 此更新为后续的模板编辑功能做准备
2025-04-15 21:31:50 +08:00
carry
7ee751c88f fix(frontend): 移除文档生成页面的冗余事件绑定代码
- 删除了原有的简单事件绑定逻辑,这些逻辑仅将输入值赋给状态变量
- 为后续添加更复杂的文档选择更改事件处理函数做准备
2025-04-15 20:44:26 +08:00
20 changed files with 537 additions and 136 deletions

View File

@@ -1,15 +1,91 @@
# 基于文档驱动的自适应编码大模型微调框架 # 基于文档驱动的自适应编码大模型微调框架
## 简介 ## 简介
本人的毕业设计
### 项目概述 ### 项目概述
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。 ### 核心功能
- 文档解析与语料生成
- 大语言模型高效微调
- 交互式训练与推理界面
- 训练过程可视化监控
### 项目技术 ## 技术架构
* 使用unsloth框架在GPU上实现大语言模型的qlora微调 ### 系统架构
* 使用langchain框架编写工作流实现批量生成微调语料 ```
* 使用tinydb和sqlite实现数据的持久化 [前端界面] -> [模型微调] -> [数据存储]
* 使用gradio框架实现前端展示 ↑ ↑ ↑
│ │ │
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
```
**施工中......** ### 技术栈
- **前端界面**: Gradio构建的交互式Web界面
- **模型微调**: 基于unsloth框架的QLoRA高效微调
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
- **工作流引擎**: LangChain实现文档解析与语料生成
- **训练监控**: TensorBoard集成
## 功能模块
### 1. 模型管理
- 支持多种格式的大语言模型加载
- 模型信息查看与状态管理
- 模型卸载与内存释放
### 2. 模型推理
- 对话式交互界面
- 流式响应输出
- 历史对话管理
### 3. 模型微调
- 训练参数配置(学习率、batch size等)
- LoRA参数配置(秩、alpha等)
- 训练过程实时监控
- 训练中断与恢复
### 4. 数据集生成
- 文档解析与清洗
- 指令-响应对生成
- 数据集质量评估
- 数据集版本管理
## 微调技术
### QLoRA原理
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
1. **4-bit量化**: 将预训练模型量化为4-bit表示大幅减少显存占用
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
### 参数配置
- **学习率**: 建议2e-5到2e-4
- **LoRA秩**: 控制适配器复杂度(建议16-64)
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
### 训练监控
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
- **日志记录**: 训练过程详细日志保存
- **模型检查点**: 定期保存中间权重
## 快速开始
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 启动应用:
```bash
python main.py
```
3. 访问Web界面:
```
http://localhost:7860
```
## 许可证
MIT License

View File

@@ -1,11 +1,12 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset from .dataset_store import get_all_dataset, save_dataset
__all__ = [ __all__ = [
"get_sqlite_engine", "load_sqlite_engine",
"initialize_sqlite_db", "initialize_sqlite_db",
"get_prompt_tinydb", "get_prompt_tinydb",
"initialize_prompt_store", "initialize_prompt_store",
"get_all_dataset" "get_all_dataset",
"save_dataset"
] ]

View File

@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 # 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A from schema.dataset import Dataset, DatasetItem, Q_A
def get_all_dataset(workdir: str) -> TinyDB: def get_all_dataset(workdir: str) -> TinyDB:
""" """
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
return db return db
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
"""
将TinyDB中的数据集保存为单独的json文件
Args:
db (TinyDB): 包含数据集对象的TinyDB实例
workdir (str): 工作目录路径
name (str, optional): 要保存的数据集名称None表示保存所有
"""
dataset_dir = os.path.join(workdir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)
datasets = db.all() if name is None else db.search(Query().name == name)
for dataset in datasets:
try:
filename = f"{dataset.get(dataset['name'])}.json"
filepath = os.path.join(dataset_dir, filename)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -48,3 +72,10 @@ if __name__ == "__main__":
print(f"Found {len(datasets)} datasets:") print(f"Found {len(datasets)} datasets:")
for ds in datasets.all(): for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})") print(f"- {ds['name']} (ID: {ds['id']})")
# 询问要保存的数据集名称
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
# 保存数据集到文件
save_dataset(datasets, workdir, name)
print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例 # 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None _engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine: def load_sqlite_engine(workdir: str) -> Engine:
""" """
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。 获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
@@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎 # 获取数据库引擎
engine = get_sqlite_engine(workdir) engine = load_sqlite_engine(workdir)
# 初始化数据库 # 初始化数据库
initialize_sqlite_db(engine) initialize_sqlite_db(engine)

View File

@@ -1,58 +1,184 @@
import gradio as gr import gradio as gr
import sys import sys
import json
from tinydb import Query
from pathlib import Path from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_docs, get_prompt_store from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 数据集生成") gr.Markdown("## 数据集生成")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(scale=1):
# 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in get_docs()] docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc, # 设置初始选中项
label="选择文档",
allow_custom_value=True,
interactive=True
)
doc_state = gr.State(value=initial_doc) # 用文档初始值初始化状态
with gr.Column():
# 获取模板列表并设置初始值
prompts = get_prompt_store().all() prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts] prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None initial_prompt = prompt_list[0] if prompt_list else None
prompt_dropdown = gr.Dropdown( # 初始化Dataframe的值
choices=prompt_list, initial_dataframe_value = []
value=initial_prompt, # 设置初始选中项 if initial_prompt:
label="选择模板", selected_prompt_id = int(initial_prompt.split(" ")[0])
allow_custom_value=True, prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True interactive=True
) )
prompt_state = gr.State(value=initial_prompt) # 用模板初始值初始化状态 doc_dropdown = gr.Dropdown(
choices=docs_list,
# 绑定事件(保留原有逻辑,确保交互时更新) value=initial_doc,
doc_dropdown.change(lambda x: x, inputs=doc_dropdown, outputs=doc_state) label="选择文档",
prompt_dropdown.change(lambda x: x, inputs=prompt_dropdown, outputs=prompt_state) interactive=True
)
# 新增事件绑定 prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
rounds_input = gr.Number(
value=1,
label="生成轮次",
minimum=1,
maximum=100,
step=1,
interactive=True
)
concurrency_input = gr.Number(
value=1,
label="并发数",
minimum=1,
maximum=10,
step=1,
interactive=True,
visible=False
)
dataset_name_input = gr.Textbox(
label="数据集名称",
placeholder="输入数据集保存名称",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集",variant="primary")
doc_choice = gr.State(value=initial_doc)
output_text = gr.Textbox(label="生成结果", interactive=False)
api_choice = gr.State(value=initial_api)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
interactive=True,
label="变量列表",
value=initial_dataframe_value # 设置初始化数据
)
def on_doc_change(selected_doc): def on_doc_change(selected_doc):
print(f"文档选择已更改为: {selected_doc}")
return selected_doc return selected_doc
def on_prompt_change(selected_prompt): def on_api_change(selected_api):
print(f"模板选择已更改为: {selected_prompt}") return selected_api
return selected_prompt
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_state) def on_prompt_change(selected_prompt):
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=prompt_state) if not selected_prompt:
return None, []
selected_prompt_id = int(selected_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
dataset_db = get_datasets()
if not dataset_db.search(Query().name == dataset_name):
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
variables_dict = {}
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
total_slices = len(document_slice_list)
for i, document_slice in enumerate(document_slice_list):
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
call_openai_api(request, rounds)
for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
# 保存数据集到TinyDB
dataset_db.insert(dataset.model_dump())
save_dataset(dataset_db,get_workdir(),dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
outputs=output_text
)
return demo return demo

View File

@@ -7,7 +7,7 @@ import torch
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name from train import get_model_name
def model_manage_page(): def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹 workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
@@ -30,6 +30,11 @@ def model_manage_page():
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称") save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
save_method_dropdown = gr.Dropdown(
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
label="保存模式",
value="default"
)
with gr.Column(scale=1): with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary") save_button = gr.Button("保存模型", variant="secondary")
@@ -73,21 +78,12 @@ def model_manage_page():
unload_button.click(fn=unload_model, inputs=None, outputs=state_output) unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
def save_model(save_model_name): from train.save_model import save_model_to_dir
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name) def save_model(save_model_name, save_method):
os.makedirs(save_path, exist_ok=True) return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output) save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
def refresh_model_list(): def refresh_model_list():
try: try:

View File

@@ -8,7 +8,8 @@ from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model, find_available_port from tools import find_available_port
from train import train_model
def train_page(): def train_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -73,7 +74,7 @@ def train_page():
# 动态生成 TensorBoard iframe # 动态生成 TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}" tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>' tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try: try:
@@ -81,6 +82,8 @@ def train_page():
dataset, new_training_dir, dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch, learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank) save_steps, lora_rank)
except Exception as e:
raise gr.Error(str(e))
finally: finally:
# 确保训练结束后终止 TensorBoard 子进程 # 确保训练结束后终止 TensorBoard 子进程
tensorboard_process.terminate() tensorboard_process.terminate()

View File

@@ -1,4 +1,4 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
_prompt_store = None _prompt_store = None
@@ -10,7 +10,7 @@ _workdir = None
def init_global_var(workdir="workdir"): def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _datasets, _workdir global _prompt_store, _sql_engine, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir) _prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir) _sql_engine = load_sqlite_engine(workdir)
_datasets = get_all_dataset(workdir) _datasets = get_all_dataset(workdir)
_workdir = workdir _workdir = workdir
@@ -37,10 +37,6 @@ def get_docs():
def get_datasets(): def get_datasets():
return _datasets return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
def get_model(): def get_model():
return _model return _model

View File

@@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
import unsloth import train
from frontend import * from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store from global_var import init_global_var, get_sql_engine, get_prompt_store
@@ -26,4 +26,4 @@ if __name__ == "__main__":
with gr.TabItem("设置"): with gr.TabItem("设置"):
setting_page() setting_page()
app.launch() app.launch(server_name="0.0.0.0")

View File

@@ -1,4 +1,4 @@
from .dataset import * from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest from .dataset_generation import *
from .md_doc import MarkdownNode from .md_doc import MarkdownNode
from .prompt import promptTempleta from .prompt import promptTempleta

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from datetime import datetime, timezone from datetime import datetime, timezone
class doc(BaseModel): class Doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID") id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称") name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径") path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题") question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案") answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel): class DatasetItem(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID") id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容") message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel): class Dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID") id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称") name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID") model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档") source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述") description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field( created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间" description="记录创建时间"
) )
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表") dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")

View File

@@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True):
description="记录创建时间" description="记录创建时间"
) )
class LLMParameters(SQLModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
class TokensUsage(SQLModel):
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
class LLMResponse(SQLModel): class LLMResponse(SQLModel):
timestamp: datetime = Field( timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳" description="响应的时间戳"
) )
response_id: str = Field(..., description="响应的唯一ID") response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: { tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0, content: str = Field(default_factory=dict, description="API响应的内容")
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: { llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel): class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词") prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id") api_provider: APIProvider = Field(..., description="API提供者的信息")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式") format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表") response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息") error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: { total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

View File

@@ -1,5 +1,5 @@
from .parse_markdown import parse_markdown from .parse_markdown import *
from .document import * from .document import *
from .json_example import generate_example_json from .json_example import generate_json_example
from .model import *
from .port import * from .port import *
from .reasoning import call_openai_api

View File

@@ -1,19 +1,19 @@
from typing import List from typing import List
from schema.dataset import dataset, dataset_item, Q_A from schema.dataset import Dataset, DatasetItem, Q_A
import json import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset: def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
# 将JSON数据转换为dataset格式 # 将JSON数据转换为dataset格式
dataset_items = [] dataset_items = []
item_id = 1 # 自增ID计数器 item_id = 1 # 自增ID计数器
for item in json_data: for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"]) qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = dataset_item(id=item_id, message=[qa]) dataset_item_obj = DatasetItem(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj) dataset_items.append(dataset_item_obj)
item_id += 1 # ID自增 item_id += 1 # ID自增
# 创建dataset对象 # 创建dataset对象
result_dataset = dataset( result_dataset = Dataset(
name="Converted Dataset", name="Converted Dataset",
model_id=None, model_id=None,
description="Dataset converted from JSON", description="Dataset converted from JSON",

View File

@@ -4,7 +4,7 @@ from pathlib import Path
# 添加项目根目录到sys.path # 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc from schema import Doc
def scan_docs_directory(workdir: str): def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs") docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
for file in files: for file in files:
if file.endswith(".md"): if file.endswith(".md"):
markdown_files.append(os.path.join(root, file)) markdown_files.append(os.path.join(root, file))
to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files)) to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return return to_return

View File

@@ -1,32 +1,29 @@
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, get_args, get_origin from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json import json
from datetime import datetime, date from datetime import datetime, date
def generate_example_json(model: type[BaseModel]) -> str: def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
""" """
根据 Pydantic V2 模型生成示例 JSON 数据结构。 根据 Pydantic V2 模型生成示例 JSON 数据结构。
""" """
def _generate_example(field_type: Any) -> Any: def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type) origin = get_origin(field_type)
args = get_args(field_type) args = get_args(field_type)
if origin is list or origin is List: if origin is list or origin is List:
if args: # 生成多个元素,这里生成 3 个
return [_generate_example(args[0])] result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
else: result.append("......")
return [] return result
elif origin is dict or origin is Dict: elif origin is dict or origin is Dict:
if len(args) == 2 and args[0] is str: if len(args) == 2:
return {"key": _generate_example(args[1])} return {"key": _generate_example(args[1])}
else:
return {} return {}
elif origin is Optional or origin is type(None): elif origin is Union:
if args: # 处理 Optional 类型Union[T, None]
return _generate_example(args[0]) non_none_args = [arg for arg in args if arg is not type(None)]
else: return _generate_example(non_none_args[0]) if non_none_args else None
return None
elif field_type is str: elif field_type is str:
return "string" return "string"
elif field_type is int: elif field_type is int:
@@ -39,13 +36,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
return datetime.now().isoformat() return datetime.now().isoformat()
elif field_type is date: elif field_type is date:
return date.today().isoformat() return date.today().isoformat()
elif issubclass(field_type, BaseModel): elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_example_json(field_type) return json.loads(generate_json_example(field_type, include_optional))
else: else:
return "unknown" # 对于未知类型返回 "unknown" # 处理直接类型注解(非泛型)
if field_type is type(None):
return None
try:
if issubclass(field_type, BaseModel):
return json.loads(generate_json_example(field_type, include_optional))
except TypeError:
pass
return "unknown"
example_data = {} example_data = {}
for field_name, field in model.model_fields.items(): for field_name, field in model.model_fields.items():
if include_optional or not isinstance(field.default, type(None)):
example_data[field_name] = _generate_example(field.annotation) example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str) return json.dumps(example_data, indent=2, default=str)
@@ -55,9 +61,7 @@ if __name__ == "__main__":
from pathlib import Path from pathlib import Path
# 添加项目根目录到sys.path # 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Q_A from schema import Dataset
class Q_A_list(BaseModel):
Q_As: List[Q_A]
print("示例 JSON:") print("示例 JSON:")
print(generate_example_json(Q_A_list)) print(generate_json_example(Dataset))

123
tools/reasoning.py Normal file
View File

@@ -0,0 +1,123 @@
import sys
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
start_time = datetime.now(timezone.utc)
client = openai.AsyncOpenAI(
api_key=llm_request.api_provider.api_key,
base_url=llm_request.api_provider.base_url
)
total_duration = 0.0
total_tokens = TokensUsage()
prompt = llm_request.prompt
round_start = datetime.now(timezone.utc)
if llm_request.format:
prompt += "\n请以JSON格式返回结果" + llm_request.format
for i in range(rounds):
round_start = datetime.now(timezone.utc)
try:
messages = [{"role": "user", "content": prompt}]
create_args = {
"model": llm_request.api_provider.model_id,
"messages": messages,
"temperature": llm_parameters.temperature if llm_parameters else None,
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
"top_p": llm_parameters.top_p if llm_parameters else None,
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
"seed": llm_parameters.seed if llm_parameters else None
} # 处理format参数
if llm_request.format:
create_args["response_format"] = {"type": "json_object"}
response = await client.chat.completions.create(**create_args)
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
# 处理可能不存在的缓存token字段
usage = response.usage
cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
tokens_usage = TokensUsage(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
prompt_cache_hit_tokens=cache_hit,
prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
)
# 累加总token使用量
total_tokens.prompt_tokens += tokens_usage.prompt_tokens
total_tokens.completion_tokens += tokens_usage.completion_tokens
if tokens_usage.prompt_cache_hit_tokens:
total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
if tokens_usage.prompt_cache_miss_tokens:
total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
llm_request.response.append(LLMResponse(
response_id=response.id,
tokens_usage=tokens_usage,
content = response.choices[0].message.content,
total_duration=duration,
llm_parameters=llm_parameters
))
except Exception as e:
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
llm_request.response.append(LLMResponse(
response_id=f"error-round-{i+1}",
content={"error": str(e)},
total_duration=duration
))
if llm_request.error is None:
llm_request.error = []
llm_request.error.append(str(e))
# 更新总耗时和总token使用量
llm_request.total_duration = total_duration
llm_request.total_tokens_usage = total_tokens
return llm_request
if __name__ == "__main__":
from json_example import generate_json_example
from sqlmodel import Session, select
from global_var import get_sql_engine, init_global_var
from schema import DatasetItem
init_global_var("workdir")
api_state = "1 deepseek-chat"
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
llm_request = LLMRequest(
prompt="测试,随便说点什么",
api_provider=api_provider,
format=generate_json_example(DatasetItem)
)
# # 单次调用示例
# result = asyncio.run(call_openai_api(llm_request))
# print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
# for i, resp in enumerate(result.response, 1):
# print(f"响应{i}: {resp.response_content}")
# 多次调用示例
params = LLMParameters(temperature=0.7, max_tokens=100)
result = asyncio.run(call_openai_api(llm_request, 3,params))
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
for i, resp in enumerate(result.response, 1):
print(f"响应{i}: {resp.content}")

1
train/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .model import *

48
train/save_model.py Normal file
View File

@@ -0,0 +1,48 @@
import os
from global_var import get_model, get_tokenizer
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
"""
保存模型到指定目录
:param save_model_name: 要保存的模型名称
:param models_dir: 模型保存的基础目录
:param model: 要保存的模型
:param tokenizer: 要保存的tokenizer
:param save_method: 保存模式选项
- "default": 默认保存方式
- "merged_16bit": 合并为16位
- "merged_4bit": 合并为4位
- "lora": 仅LoRA适配器
- "gguf": 保存为GGUF格式
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
- "gguf_f16": 保存为16位GGUF格式
:return: 保存结果消息或错误信息
"""
try:
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
if save_method == "default":
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
elif save_method == "merged_16bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
elif save_method == "merged_4bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
elif save_method == "lora":
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
elif save_method == "gguf":
model.save_pretrained_gguf(save_path, tokenizer)
elif save_method == "gguf_q4_k_m":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
elif save_method == "gguf_f16":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
else:
return f"不支持的保存模式: {save_method}"
return f"模型已保存到 {save_path} (模式: {save_method})"
except Exception as e:
return f"保存模型时出错: {str(e)}"