Compare commits

54 Commits

Author SHA1 Message Date
carry
7a4388c928 featmodel): 添加保存模式选择功能
在模型管理页面中新增保存模式选择功能,用户可以通过下拉菜单选择不同的保存模式(如默认、合并16位、合并4位等)。同时,将保存模型的逻辑抽离到独立的`save_model.py`文件中,以提高代码的可维护性和复用性。
2025-04-23 14:09:02 +08:00
carry
6338706967 feat: 修改应用启动方式
- 将 app.launch() 修改为 app.launch(server_name="0.0.0.0")
- 此修改使应用能够监听所有网络接口,提高可用性
2025-04-22 19:08:23 +08:00
carry
3718c75cee feat(frontend): 添加数据集生成页面的处理进度显示
- 在处理文档片段时添加进度条,提升用户体验
- 优化代码格式,调整缩进和空行
2025-04-22 00:14:16 +08:00
carry
905658073a docs(README): 更新项目文档
- 添加项目概述、核心功能、技术架构等详细信息
- 插入系统架构图和技术栈说明
- 细化功能模块描述,包括模型管理、推理、微调等
- 增加QLoRA原理和参数配置说明
- 补充快速开始指南和许可证信息
- 优化文档结构,增强可读性和完整性
2025-04-21 14:28:20 +08:00
carry
9806334517 fix(train_page): 捕获训练过程中的异常并终止 TensorBoard 进程
- 在训练过程中添加异常捕获,将异常信息转换为 gr.Error 抛出
- 确保在发生异常时也能终止 TensorBoard 子进程
2025-04-20 21:40:46 +08:00
carry
0a4efa5641 feat(dataset): 添加数据集生成功能
- 新增数据集生成页面和相关逻辑
- 实现数据集名称重复性检查
- 添加数据集对象创建和保存功能
- 优化文档处理和提示模板应用
- 增加错误处理和数据解析
2025-04-20 21:25:51 +08:00
carry
994d600221 refactor(frontend): 调整 TensorBoard iframe 高度
- 将 TensorBoard iframe 的高度从 500px 修改为 1000px
- 此修改旨在提供更宽敞的显示区域,改善用户体验
2025-04-20 21:25:37 +08:00
carry
d5774eee0c feat(db): 添加数据集导出功能
- 新增 save_dataset 函数,用于将 TinyDB 中的数据集保存为单独的 JSON 文件
- 更新 db/__init__.py,添加 get_dataset_tinydb 函数的引用
- 修改 db/dataset_store.py,实现 save_dataset 函数并添加相关逻辑
2025-04-20 19:44:11 +08:00
carry
87501c9353 fix(global_var): 移除全局变量设置函数set_datasets
- 删除了 global_var.py 文件中的 set_datasets 函数
- 该函数用于设置全局变量 _datasets,但似乎已不再使用
2025-04-20 19:14:00 +08:00
carry
5fc3b4950b refactor(schema): 修改 LLMResponse 中 API 响应内容的字段名称
- 将 LLMResponse 类中的 response_content 字段重命名为 content
- 更新字段类型从 dict 改为 str,以更准确地表示响应内容
- 在 reasoning.py 中相应地修改了调用 LLMResponse 时的参数
2025-04-20 18:40:51 +08:00
carry
c28e4819d9 refactor(frontend/tools): 重命名生成示例 JSON 数据结构的函数
- 将 generate_example_json 函数重命名为 generate_json_example
- 更新相关文件中的函数调用和引用
- 此更改旨在使函数名称更具描述性和一致性
2025-04-20 16:11:36 +08:00
carry
e7cf51d662 refactor(frontend): 重构数据集生成页面
- 调整页面布局,优化用户交互流程
- 新增数据集名称输入框
- 使用 LLMRequest 和 LLMResponse 模型处理请求和响应
- 添加 generate_example_json 函数用于格式化生成数据
- 改进数据集生成逻辑,支持多轮次生成
2025-04-20 16:10:08 +08:00
carry
4c9caff668 refactor(schema): 重构数据集和文档类的命名
- 将 dataset、dataset_item 和 doc 类的首字母大写,以符合 Python 类命名惯例
- 更新相关模块中的导入和引用,以适应新的类名
- 此更改不影响功能,仅提高了代码的一致性和可读性
2025-04-20 01:46:15 +08:00
carry
9236f49b36 feat(frontend): 添加文档切片和并发数功能
- 新增并发数输入框
- 实现文档切片处理
- 更新生成数据集的逻辑,支持并发处理
2025-04-20 01:40:48 +08:00
carry
868fcd45ba refactor(project): 重构项目文件组织结构
- 修改模型管理和训练页面的导入路径
- 更新 main.py 中的导入模块
- 调整 tools 包的内容,移除 model 模块
- 新建 train 包,包含 model 模块
- 优化 __init__.py 文件,简化导入语句
2025-04-19 21:49:19 +08:00
carry
5a21c8598a feat(tools): 支持 OpenAI API 的 JSON 格式返回结果
- 在 call_openai_api 函数中添加对 JSON 格式返回结果的支持
- 增加 llm_request.format 参数处理,将用户 prompt 与格式要求合并
- 添加 response_format 参数到 OpenAI API 请求
- 更新示例,使用 JSON 格式返回结果
2025-04-19 21:10:22 +08:00
carry
1e829c9268 feat(tools): 优化 JSON 示例生成函数
- 增加 include_optional 参数,决定是否包含可选字段
- 添加 list_length 参数,用于控制列表字段的示例长度
- 在列表示例中添加省略标记,更直观展示多元素列表
- 优化字典字段的示例生成逻辑
2025-04-19 21:07:00 +08:00
carry
9fc3ab904b feat(frontend): 实现了固定参数的注入 2025-04-19 17:48:45 +08:00
carry
d827f9758f fix(frontend): 修复dataframe_value返回值只有一列的bug 2025-04-19 17:30:10 +08:00
carry
ff1e9731bc fix(tools): 修复call_openai_api的导出 2025-04-19 17:13:19 +08:00
carry
90fde639ff feat(tools): 增加 OpenAI API 多轮调用功能
- 在 call_openai_api 函数中添加 rounds 参数,支持多次调用
- 累加每次调用的耗时和 token 使用情况
- 将多次调用的结果存储在 LLMRequest 对象的 response 列表中
- 更新函数返回类型,返回包含多次调用信息的 LLMRequest 对象
- 优化错误处理,记录每轮调用的错误信息
2025-04-19 17:02:00 +08:00
carry
5fc90903fb feat(tools): 添加 reasoning.py 工具模块
- 新增 reasoning.py 文件,实现与 OpenAI API 的交互
- 添加 call_openai_api 函数,用于发送请求并处理响应
- 支持可选的 LLMParameters 参数,以定制化请求
- 处理 API 响应中的 tokens 使用情况
- 提供错误处理和缓存 token 字段的处理
2025-04-19 16:53:48 +08:00
carry
81c2ad4a2d refactor(schema): 重构数据模型以提高可维护性和可扩展性
- 新增 LLMParameters 类以统一处理 LLM 参数
- 新增 TokensUsage 类以统一处理 token 使用信息
- 更新 LLMResponse 和 LLMRequest 类,使用新的 LLMParameters 和 TokensUsage 类
- 优化数据模型结构,提高代码的可读性和可维护性
2025-04-19 16:39:18 +08:00
carry
314434951d feat(frontend): 实现了文档、提示和 API 提供商的获取逻辑 2025-04-19 14:47:01 +08:00
carry
e16882953d fix(tools): 修复了optional字段无法被解析的问题 2025-04-18 22:00:51 +08:00
carry
86bcf90c66 feat(frontend): 添加数据集生成轮次控制功能
- 在数据集生成页面添加"生成轮次"输入框,支持设置生成轮数
- 更新生成逻辑,根据设置的轮次进行多次生成
2025-04-18 15:47:37 +08:00
carry
961a017f19 refactor(frontend): 调整数据集生成页面布局并优化代码结构
- 使用 gr.Column(scale=1) 和 gr.Column(scale=2) 调整列宽比例
- 移除多余的空行和缩进,提高代码可读性
- 优化变量声明和组件创建的顺序,使页面结构更清晰
2025-04-18 15:40:15 +08:00
carry
5a386d6401 feat(dataset_generate_page): 添加 API 选择功能
- 在数据集生成页面添加 API 选择下拉框
- 实现 API 选择变更时的处理逻辑
- 更新数据集生成函数,增加 API 选择参数
- 优化页面布局和代码结构
2025-04-18 15:23:33 +08:00
carry
feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00
carry
7242a2ce03 feat(frontend): 添加生成数据集进度条功能并优化了界面布局 2025-04-18 15:07:46 +08:00
carry
db6e2271dc fix(frontend): 修复 prompt_dropdown 变化时,dataframe没有相应的变化
- 将 prompt_dropdown 变化时的输出从 prompt_state 修改为 [prompt_state, variables_dataframe]
- 这个改动可能会在 prompt 变化时同时更新变量数据框
2025-04-18 14:03:26 +08:00
carry
d764537143 feat(dataset_generate_page): 更新数据集生成页面功能
- 添加模板变量列表展示和编辑功能
- 实现模板选择后动态更新变量列表
- 增加生成数据集按钮和相关逻辑
- 优化页面布局和交互
2025-04-16 12:39:48 +08:00
carry
8c35a38c47 feat(frontend): 更新模板选择功能
- 在模板选择变更时,获取所选模板的详细信息
- 创建 PromptTemplate 对象并获取输入变量列表
- 此更新为后续的模板编辑功能做准备
2025-04-15 21:31:50 +08:00
carry
7ee751c88f fix(frontend): 移除文档生成页面的冗余事件绑定代码
- 删除了原有的简单事件绑定逻辑,这些逻辑仅将输入值赋给状态变量
- 为后续添加更复杂的文档选择更改事件处理函数做准备
2025-04-15 20:44:26 +08:00
carry
b715b36a5f feat(frontend): 更新数据集生成页面并添加独立运行功能
- 重构导入路径,使用绝对路径替换相对路径
- 新增文档和模板选择的事件处理函数
- 添加独立运行数据集生成页面的功能
- 优化代码结构,提高可读性和可维护性
2025-04-15 17:13:52 +08:00
carry
8023233bb2 feat(prompt): 增加模板变量有效性检查
- 在 promptTempleta 模型中添加字段验证器
- 验证模板内容是否包含必要的 document_slice 变量
- 如果缺少该变量,抛出 ValueError 异常
2025-04-15 16:54:17 +08:00
carry
2a86b3b5b0 fix(db): 初始化 prompt store 时插入第一条记录的 ID 从 0 改为 1
- 将初始化时插入的第一条记录的 ID 从 0 改为 1
- 修正了文档节选的变量名,从 {content} 改为 {document_slice}
2025-04-15 16:45:12 +08:00
carry
ca1505304e fix(tools): 更新 tools/__init__.py 中的导入语句
- 将 from .doc import * 改为 from .document import *
- 这个修改统一了文档处理模块的命名,提高了代码的一致性和可读性
2025-04-15 16:31:55 +08:00
carry
df9260e918 fix(db): 修复初始提示词的变量花括号的空格问题 2025-04-15 16:13:57 +08:00
carry
df9aba0c6e refactor(tools): 重命名模块并更新导入
- 将 scan_doc_dir.py 重命名为 document.py
- 将 socket.py 重命名为 port.py
- 更新 __init__.py 中的导入语句
- 在 port.py 中添加测试代码,用于查找可用端口
2025-04-15 15:47:44 +08:00
carry
6b87dcb58f refactor(frontend): 重构数据集生成页面的变量命名逻辑
- 将 prompt_choices 变量重命名为 prompt_list,以更准确地反映其内容
- 更新相关代码中对这两个变量的引用,以保持一致性
2025-04-15 15:40:24 +08:00
carry
d0aebd17fa refactor(global_var): 重构全局变量管理
- 移除了 _docs 全局变量
- 更新了 get_docs() 函数,使其在每次调用时重新扫描文档目录
- 优化了全局变量初始化逻辑
2025-04-15 15:25:44 +08:00
carry
d9abf08184 fix(frontend): 修复表格选择事件的行数据获取问题
- 在 prompt_manage_page 和 setting_page 中更新了 select_record 函数
- 使用 DataFrame.iloc 方法获取选中行的数据,并转换为列表
- 添加了将第一列数据转换为整数的逻辑
- 更新了表格选择事件的参数,增加了输入和输出参数
- 将 gradio 版本升级到 5.25.0
2025-04-15 15:10:15 +08:00
carry
a27a1ab079 refactor(frontend): 重构训练页面布局并优化用户界面
- 调整数据集下拉框布局位置
- 新增超参数输入组件
- 修改训练日志输出框标签为"训练状态"
- 添加 TensorBoard 可视化 iframe 显示框
2025-04-15 00:12:09 +08:00
carry
aa758e3c2a feat(train_page): 添加 TensorBoard 可视化
- 在训练页面添加 TensorBoard iframe 显示框
- 实现动态生成 TensorBoard iframe 功能
- 更新训练按钮点击事件,同时更新 TensorBoard iframe
2025-04-14 23:28:43 +08:00
carry
664944f0c5 feat(frontend): 优化 TensorBoard 端口占用问题
- 新增端口检测逻辑,动态分配可用端口
- 修改 TensorBoard 启动过程,使用动态分配的端口
- 添加 socket 模块,用于端口检测
2025-04-14 17:06:44 +08:00
carry
9298438f98 feat(train_page): 启动 TensorBoard 进程并确保训练结束后终止
- 在训练页面中添加 TensorBoard 进程启动代码
- 创建日志目录并启动 TensorBoard 子进程
- 在训练结束后终止 TensorBoard 进程
2025-04-14 17:00:33 +08:00
carry
4f7926aec6 feat(train_page): 实现训练目录自动递增功能
- 在 training 文件夹下创建递增的目录结构
- 确保 training 文件夹存在
- 扫描现有目录,生成下一个可用的目录编号
- 更新训练模型函数,使用新的训练目录
2025-04-14 16:46:29 +08:00
carry
148f4afb25 fix(main): 修复unsloth没有最先导入的问题,删除了重复的导入语句 2025-04-14 16:31:47 +08:00
carry
11a3039775 fix(train_page): 修正模型训练保存路径 2025-04-14 16:31:00 +08:00
carry
a4289815ba build: 添加 tensorboard 依赖
- 在 requirements.txt 中添加 tensorboard>=2.19.0
- 此改动增加了 tensorboard 作为项目的新依赖项
2025-04-14 16:30:39 +08:00
carry
088067d335 train: 更新模型训练功能和日志记录方式
- 修改训练目录结构,将检查点和日志分开保存
- 添加 TensorBoard 日志记录支持
- 移除自定义 LossCallback 类,简化训练流程
- 更新训练参数和回调机制,提高代码可读性
- 在 requirements.txt 中添加 tensorboardX 依赖
2025-04-14 16:19:37 +08:00
carry
9fb31c46c8 feat(train): 添加训练过程中的日志记录和 loss 可视化功能
- 新增 LossCallback 类,用于在训练过程中记录 loss 数据
- 在训练模型函数中添加回调,实现日志记录和 loss 可视化
- 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
2025-04-14 15:18:14 +08:00
carry
4f09823123 refactor(tools): 优化 train_model 函数定义
- 添加类型注解,提高代码可读性和维护性
- 使用多行格式定义函数参数,提升代码格式美观
2025-04-14 14:28:36 +08:00
26 changed files with 682 additions and 179 deletions

View File

@@ -1,15 +1,91 @@
# 基于文档驱动的自适应编码大模型微调框架 # 基于文档驱动的自适应编码大模型微调框架
## 简介 ## 简介
本人的毕业设计
### 项目概述 ### 项目概述
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。 ### 核心功能
- 文档解析与语料生成
- 大语言模型高效微调
- 交互式训练与推理界面
- 训练过程可视化监控
### 项目技术 ## 技术架构
* 使用unsloth框架在GPU上实现大语言模型的qlora微调 ### 系统架构
* 使用langchain框架编写工作流实现批量生成微调语料 ```
* 使用tinydb和sqlite实现数据的持久化 [前端界面] -> [模型微调] -> [数据存储]
* 使用gradio框架实现前端展示 ↑ ↑ ↑
│ │ │
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
```
**施工中......** ### 技术栈
- **前端界面**: Gradio构建的交互式Web界面
- **模型微调**: 基于unsloth框架的QLoRA高效微调
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
- **工作流引擎**: LangChain实现文档解析与语料生成
- **训练监控**: TensorBoard集成
## 功能模块
### 1. 模型管理
- 支持多种格式的大语言模型加载
- 模型信息查看与状态管理
- 模型卸载与内存释放
### 2. 模型推理
- 对话式交互界面
- 流式响应输出
- 历史对话管理
### 3. 模型微调
- 训练参数配置(学习率、batch size等)
- LoRA参数配置(秩、alpha等)
- 训练过程实时监控
- 训练中断与恢复
### 4. 数据集生成
- 文档解析与清洗
- 指令-响应对生成
- 数据集质量评估
- 数据集版本管理
## 微调技术
### QLoRA原理
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
1. **4-bit量化**: 将预训练模型量化为4-bit表示大幅减少显存占用
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
### 参数配置
- **学习率**: 建议2e-5到2e-4
- **LoRA秩**: 控制适配器复杂度(建议16-64)
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
### 训练监控
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
- **日志记录**: 训练过程详细日志保存
- **模型检查点**: 定期保存中间权重
## 快速开始
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 启动应用:
```bash
python main.py
```
3. 访问Web界面:
```
http://localhost:7860
```
## 许可证
MIT License

View File

@@ -1,11 +1,12 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset from .dataset_store import get_all_dataset, save_dataset
__all__ = [ __all__ = [
"get_sqlite_engine", "load_sqlite_engine",
"initialize_sqlite_db", "initialize_sqlite_db",
"get_prompt_tinydb", "get_prompt_tinydb",
"initialize_prompt_store", "initialize_prompt_store",
"get_all_dataset" "get_all_dataset",
"save_dataset"
] ]

View File

@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 # 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A from schema.dataset import Dataset, DatasetItem, Q_A
def get_all_dataset(workdir: str) -> TinyDB: def get_all_dataset(workdir: str) -> TinyDB:
""" """
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
return db return db
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
"""
将TinyDB中的数据集保存为单独的json文件
Args:
db (TinyDB): 包含数据集对象的TinyDB实例
workdir (str): 工作目录路径
name (str, optional): 要保存的数据集名称None表示保存所有
"""
dataset_dir = os.path.join(workdir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)
datasets = db.all() if name is None else db.search(Query().name == name)
for dataset in datasets:
try:
filename = f"{dataset.get(dataset['name'])}.json"
filepath = os.path.join(dataset_dir, filename)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -48,3 +72,10 @@ if __name__ == "__main__":
print(f"Found {len(datasets)} datasets:") print(f"Found {len(datasets)} datasets:")
for ds in datasets.all(): for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})") print(f"- {ds['name']} (ID: {ds['id']})")
# 询问要保存的数据集名称
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
# 保存数据集到文件
save_dataset(datasets, workdir, name)
print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例 # 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None _engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine: def load_sqlite_engine(workdir: str) -> Engine:
""" """
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。 获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
@@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎 # 获取数据库引擎
engine = get_sqlite_engine(workdir) engine = load_sqlite_engine(workdir)
# 初始化数据库 # 初始化数据库
initialize_sqlite_db(engine) initialize_sqlite_db(engine)

View File

@@ -44,12 +44,12 @@ def initialize_prompt_store(db: TinyDB) -> None:
# 检查数据库是否为空 # 检查数据库是否为空
if not db.all(): # 如果数据库中没有数据 if not db.all(): # 如果数据库中没有数据
db.insert(promptTempleta( db.insert(promptTempleta(
id=0, id=1,
name="default", name="default",
description="默认提示词模板", description="默认提示词模板",
content="""项目名为:{project_name} content="""项目名为:{project_name}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文 请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{ content }""").model_dump()) 文档节选:{document_slice}""").model_dump())
# 如果数据库中已有数据,则跳过插入 # 如果数据库中已有数据,则跳过插入

View File

@@ -1,42 +1,189 @@
import gradio as gr import gradio as gr
from tools import scan_docs_directory import sys
from global_var import get_docs, scan_docs_directory, get_prompt_store import json
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 数据集生成") gr.Markdown("## 数据集生成")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(scale=1):
# 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in get_docs()] docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None initial_doc = docs_list[0] if docs_list else None
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
# 初始化Dataframe的值
initial_dataframe_value = []
if initial_prompt:
selected_prompt_id = int(initial_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True
)
doc_dropdown = gr.Dropdown( doc_dropdown = gr.Dropdown(
choices=docs_list, choices=docs_list,
value=initial_doc, # 设置初始选中项 value=initial_doc,
label="选择文档", label="选择文档",
allow_custom_value=True,
interactive=True interactive=True
) )
doc_state = gr.State(value=initial_doc) # 用文档初始值初始化状态
with gr.Column():
# 获取模板列表并设置初始值
prompts = get_prompt_store().all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None
prompt_dropdown = gr.Dropdown( prompt_dropdown = gr.Dropdown(
choices=prompt_choices, choices=prompt_list,
value=initial_prompt, # 设置初始选中项 value=initial_prompt,
label="选择模板", label="选择模板",
allow_custom_value=True,
interactive=True interactive=True
) )
prompt_state = gr.State(value=initial_prompt) # 用模板初始值初始化状态 rounds_input = gr.Number(
value=1,
label="生成轮次",
minimum=1,
maximum=100,
step=1,
interactive=True
)
concurrency_input = gr.Number(
value=1,
label="并发数",
minimum=1,
maximum=10,
step=1,
interactive=True,
visible=False
)
dataset_name_input = gr.Textbox(
label="数据集名称",
placeholder="输入数据集保存名称",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集",variant="primary")
doc_choice = gr.State(value=initial_doc)
output_text = gr.Textbox(label="生成结果", interactive=False)
api_choice = gr.State(value=initial_api)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
interactive=True,
label="变量列表",
value=initial_dataframe_value # 设置初始化数据
)
def on_doc_change(selected_doc):
return selected_doc
# 绑定事件(保留原有逻辑,确保交互时更新) def on_api_change(selected_api):
doc_dropdown.change(lambda x: x, inputs=doc_dropdown, outputs=doc_state) return selected_api
prompt_dropdown.change(lambda x: x, inputs=prompt_dropdown, outputs=prompt_state)
def on_prompt_change(selected_prompt):
if not selected_prompt:
return None, []
selected_prompt_id = int(selected_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
dataset_db = get_datasets()
if not dataset_db.search(Query().name == dataset_name):
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
variables_dict = {}
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
total_slices = len(document_slice_list)
for i, document_slice in enumerate(document_slice_list):
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
call_openai_api(request, rounds)
for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
# 保存数据集到TinyDB
dataset_db.insert(dataset.model_dump())
save_dataset(dataset_db,get_workdir(),dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
outputs=output_text
)
return demo return demo
if __name__ == "__main__":
from global_var import init_global_var
init_global_var("workdir")
demo = dataset_generate_page()
demo.launch()

View File

@@ -7,7 +7,7 @@ import torch
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name from train import get_model_name
def model_manage_page(): def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹 workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
@@ -30,6 +30,11 @@ def model_manage_page():
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称") save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
save_method_dropdown = gr.Dropdown(
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
label="保存模式",
value="default"
)
with gr.Column(scale=1): with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary") save_button = gr.Button("保存模型", variant="secondary")
@@ -73,21 +78,12 @@ def model_manage_page():
unload_button.click(fn=unload_model, inputs=None, outputs=state_output) unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
def save_model(save_model_name): from train.save_model import save_model_to_dir
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name) def save_model(save_model_name, save_method):
os.makedirs(save_path, exist_ok=True) return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output) save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
def refresh_model_list(): def refresh_model_list():
try: try:

View File

@@ -61,9 +61,11 @@ def prompt_manage_page():
selected_row = None # 保存当前选中行的全局变量 selected_row = None # 保存当前选中行的全局变量
def select_record(evt: gr.SelectData): def select_record(dataFrame ,evt: gr.SelectData):
global selected_row global selected_row
selected_row = evt.row_value selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 提示词模板管理") gr.Markdown("## 提示词模板管理")
@@ -102,7 +104,10 @@ def prompt_manage_page():
outputs=[prompt_table, name_input, description_input, content_input] outputs=[prompt_table, name_input, description_input, content_input]
) )
prompt_table.select(select_record, [], [], show_progress="hidden") prompt_table.select(fn=select_record,
inputs=[prompt_table],
outputs=[],
show_progress="hidden")
edit_button.click( edit_button.click(
fn=edit_prompt, fn=edit_prompt,

View File

@@ -68,9 +68,11 @@ def setting_page():
selected_row = None # 保存当前选中行的全局变量 selected_row = None # 保存当前选中行的全局变量
def select_record(evt: gr.SelectData): def select_record(dataFrame ,evt: gr.SelectData):
global selected_row global selected_row
selected_row = evt.row_value selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理") gr.Markdown("## API Provider 管理")
@@ -109,7 +111,10 @@ def setting_page():
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出 outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
) )
provider_table.select(select_record, [], [], show_progress="hidden") provider_table.select(fn=select_record,
inputs=[provider_table],
outputs=[],
show_progress="hidden")
edit_button.click( edit_button.click(
fn=edit_provider, fn=edit_provider,

View File

@@ -1,12 +1,15 @@
import subprocess
import os
import gradio as gr import gradio as gr
import sys import sys
import torch
from tinydb import Query from tinydb import Query
from pathlib import Path from pathlib import Path
from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model from tools import find_available_port
from train import train_model
def train_page(): def train_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -14,7 +17,8 @@ def train_page():
# 获取数据集列表并设置初始值 # 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()] datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None initial_dataset = datasets_list[0] if datasets_list else None
with gr.Row():
with gr.Column(scale=1):
dataset_dropdown = gr.Dropdown( dataset_dropdown = gr.Dropdown(
choices=datasets_list, choices=datasets_list,
value=initial_dataset, # 设置初始选中项 value=initial_dataset, # 设置初始选中项
@@ -22,7 +26,6 @@ def train_page():
allow_custom_value=True, allow_custom_value=True,
interactive=True interactive=True
) )
# 新增超参数输入组件 # 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率") learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
@@ -33,7 +36,10 @@ def train_page():
train_button = gr.Button("开始微调") train_button = gr.Button("开始微调")
# 训练状态输出 # 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False) output = gr.Textbox(label="训练状态", interactive=False)
with gr.Column(scale=3):
# 新增 TensorBoard iframe 显示框
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数 # 使用动态传入的超参数
@@ -42,11 +48,46 @@ def train_page():
epoch = int(epoch) epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数 save_steps = int(save_steps) # 新增保存步数参数
lora_rank = int(lora_rank) # 新增LoRA秩参数 lora_rank = int(lora_rank) # 新增LoRA秩参数
# 加载数据集 # 加载数据集
dataset = get_datasets().get(Query().name == dataset_name) dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]] dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
train_model(get_model(), get_tokenizer(), dataset, get_workdir(), learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank)
# 扫描 training 文件夹并生成递增目录
training_dir = get_workdir() + "/training"
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
new_training_dir = os.path.join(training_dir, str(next_dir_number))
tensorboard_port = find_available_port(6006) # 从默认端口 6006 开始检测
print(f"TensorBoard 将使用端口: {tensorboard_port}")
tensorboard_logdir = os.path.join(new_training_dir, "logs")
os.makedirs(tensorboard_logdir, exist_ok=True) # 确保日志目录存在
tensorboard_process = subprocess.Popen(
["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
# 动态生成 TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try:
train_model(get_model(), get_tokenizer(),
dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank)
except Exception as e:
raise gr.Error(str(e))
finally:
# 确保训练结束后终止 TensorBoard 子进程
tensorboard_process.terminate()
print("TensorBoard 子进程已终止")
train_button.click( train_button.click(
fn=start_training, fn=start_training,
@@ -56,9 +97,9 @@ def train_page():
per_device_train_batch_size_input, per_device_train_batch_size_input,
epoch_input, epoch_input,
save_steps_input, save_steps_input,
lora_rank_input # 新增lora_rank_input lora_rank_input
], ],
outputs=output outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
) )
return demo return demo

View File

@@ -1,18 +1,16 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
_prompt_store = None _prompt_store = None
_sql_engine = None _sql_engine = None
_docs = None
_datasets = None _datasets = None
_model = None _model = None
_tokenizer = None _tokenizer = None
_workdir = None _workdir = None
def init_global_var(workdir="workdir"): def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _docs, _datasets, _workdir global _prompt_store, _sql_engine, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir) _prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir) _sql_engine = load_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)
_datasets = get_all_dataset(workdir) _datasets = get_all_dataset(workdir)
_workdir = workdir _workdir = workdir
@@ -34,19 +32,11 @@ def set_sql_engine(new_sql_engine):
_sql_engine = new_sql_engine _sql_engine = new_sql_engine
def get_docs(): def get_docs():
return _docs global _workdir
return scan_docs_directory(_workdir)
def set_docs(new_docs):
global _docs
_docs = new_docs
def get_datasets(): def get_datasets():
return _datasets return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
def get_model(): def get_model():
return _model return _model

View File

@@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page import train
from frontend import * from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store from global_var import init_global_var, get_sql_engine, get_prompt_store
@@ -26,4 +26,4 @@ if __name__ == "__main__":
with gr.TabItem("设置"): with gr.TabItem("设置"):
setting_page() setting_page()
app.launch() app.launch(server_name="0.0.0.0")

View File

@@ -1,9 +1,11 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=5.0.0 gradio>=5.25.0
langchain>=0.3 langchain>=0.3
tinydb>=4.0.0 tinydb>=4.0.0
unsloth>=2025.3.19 unsloth>=2025.3.19
sqlmodel>=0.0.24 sqlmodel>=0.0.24
jinja2>=3.1.0 jinja2>=3.1.0
tensorboardX>=2.6.2.2
tensorboard>=2.19.0

View File

@@ -1,4 +1,4 @@
from .dataset import * from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest from .dataset_generation import *
from .md_doc import MarkdownNode from .md_doc import MarkdownNode
from .prompt import promptTempleta from .prompt import promptTempleta

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from datetime import datetime, timezone from datetime import datetime, timezone
class doc(BaseModel): class Doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID") id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称") name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径") path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题") question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案") answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel): class DatasetItem(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID") id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容") message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel): class Dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID") id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称") name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID") model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档") source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述") description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field( created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间" description="记录创建时间"
) )
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表") dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")

View File

@@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True):
description="记录创建时间" description="记录创建时间"
) )
class LLMParameters(SQLModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
class TokensUsage(SQLModel):
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
class LLMResponse(SQLModel): class LLMResponse(SQLModel):
timestamp: datetime = Field( timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳" description="响应的时间戳"
) )
response_id: str = Field(..., description="响应的唯一ID") response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: { tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0, content: str = Field(default_factory=dict, description="API响应的内容")
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: { llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel): class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词") prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id") api_provider: APIProvider = Field(..., description="API提供者的信息")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式") format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表") response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息") error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: { total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from typing import Optional from typing import Optional
from datetime import datetime, timezone from datetime import datetime, timezone
from langchain.prompts import PromptTemplate
class promptTempleta(BaseModel): class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID") id: Optional[int] = Field(default=None, description="模板ID")
@@ -11,3 +12,9 @@ class promptTempleta(BaseModel):
default_factory=lambda: datetime.now(timezone.utc).isoformat(), default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间" description="记录创建时间"
) )
@field_validator('content')
def validate_content(cls, value):
if not "document_slice" in PromptTemplate.from_template(value).input_variables:
raise ValueError("模板变量缺少 document_slice")
return value

View File

@@ -1,4 +1,5 @@
from .parse_markdown import parse_markdown from .parse_markdown import *
from .scan_doc_dir import * from .document import *
from .json_example import generate_example_json from .json_example import generate_json_example
from .model import * from .port import *
from .reasoning import call_openai_api

View File

@@ -1,19 +1,19 @@
from typing import List from typing import List
from schema.dataset import dataset, dataset_item, Q_A from schema.dataset import Dataset, DatasetItem, Q_A
import json import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset: def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
# 将JSON数据转换为dataset格式 # 将JSON数据转换为dataset格式
dataset_items = [] dataset_items = []
item_id = 1 # 自增ID计数器 item_id = 1 # 自增ID计数器
for item in json_data: for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"]) qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = dataset_item(id=item_id, message=[qa]) dataset_item_obj = DatasetItem(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj) dataset_items.append(dataset_item_obj)
item_id += 1 # ID自增 item_id += 1 # ID自增
# 创建dataset对象 # 创建dataset对象
result_dataset = dataset( result_dataset = Dataset(
name="Converted Dataset", name="Converted Dataset",
model_id=None, model_id=None,
description="Dataset converted from JSON", description="Dataset converted from JSON",

View File

@@ -4,7 +4,7 @@ from pathlib import Path
# 添加项目根目录到sys.path # 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc from schema import Doc
def scan_docs_directory(workdir: str): def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs") docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
for file in files: for file in files:
if file.endswith(".md"): if file.endswith(".md"):
markdown_files.append(os.path.join(root, file)) markdown_files.append(os.path.join(root, file))
to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files)) to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return return to_return

View File

@@ -1,32 +1,29 @@
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, get_args, get_origin from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json import json
from datetime import datetime, date from datetime import datetime, date
def generate_example_json(model: type[BaseModel]) -> str: def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
""" """
根据 Pydantic V2 模型生成示例 JSON 数据结构。 根据 Pydantic V2 模型生成示例 JSON 数据结构。
""" """
def _generate_example(field_type: Any) -> Any: def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type) origin = get_origin(field_type)
args = get_args(field_type) args = get_args(field_type)
if origin is list or origin is List: if origin is list or origin is List:
if args: # 生成多个元素,这里生成 3 个
return [_generate_example(args[0])] result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
else: result.append("......")
return [] return result
elif origin is dict or origin is Dict: elif origin is dict or origin is Dict:
if len(args) == 2 and args[0] is str: if len(args) == 2:
return {"key": _generate_example(args[1])} return {"key": _generate_example(args[1])}
else:
return {} return {}
elif origin is Optional or origin is type(None): elif origin is Union:
if args: # 处理 Optional 类型Union[T, None]
return _generate_example(args[0]) non_none_args = [arg for arg in args if arg is not type(None)]
else: return _generate_example(non_none_args[0]) if non_none_args else None
return None
elif field_type is str: elif field_type is str:
return "string" return "string"
elif field_type is int: elif field_type is int:
@@ -39,13 +36,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
return datetime.now().isoformat() return datetime.now().isoformat()
elif field_type is date: elif field_type is date:
return date.today().isoformat() return date.today().isoformat()
elif issubclass(field_type, BaseModel): elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_example_json(field_type) return json.loads(generate_json_example(field_type, include_optional))
else: else:
return "unknown" # 对于未知类型返回 "unknown" # 处理直接类型注解(非泛型)
if field_type is type(None):
return None
try:
if issubclass(field_type, BaseModel):
return json.loads(generate_json_example(field_type, include_optional))
except TypeError:
pass
return "unknown"
example_data = {} example_data = {}
for field_name, field in model.model_fields.items(): for field_name, field in model.model_fields.items():
if include_optional or not isinstance(field.default, type(None)):
example_data[field_name] = _generate_example(field.annotation) example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str) return json.dumps(example_data, indent=2, default=str)
@@ -55,9 +61,7 @@ if __name__ == "__main__":
from pathlib import Path from pathlib import Path
# 添加项目根目录到sys.path # 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Q_A from schema import Dataset
class Q_A_list(BaseModel):
Q_As: List[Q_A]
print("示例 JSON:") print("示例 JSON:")
print(generate_example_json(Q_A_list)) print(generate_json_example(Dataset))

15
tools/port.py Normal file
View File

@@ -0,0 +1,15 @@
import socket
# 启动 TensorBoard 子进程前添加端口检测逻辑
def find_available_port(start_port):
port = start_port
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', port)) != 0: # 端口未被占用
return port
port += 1 # 如果端口被占用,尝试下一个端口
if __name__ == "__main__":
start_port = 6006 # 起始端口号
available_port = find_available_port(start_port)
print(f"Available port: {available_port}")

123
tools/reasoning.py Normal file
View File

@@ -0,0 +1,123 @@
import sys
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
start_time = datetime.now(timezone.utc)
client = openai.AsyncOpenAI(
api_key=llm_request.api_provider.api_key,
base_url=llm_request.api_provider.base_url
)
total_duration = 0.0
total_tokens = TokensUsage()
prompt = llm_request.prompt
round_start = datetime.now(timezone.utc)
if llm_request.format:
prompt += "\n请以JSON格式返回结果" + llm_request.format
for i in range(rounds):
round_start = datetime.now(timezone.utc)
try:
messages = [{"role": "user", "content": prompt}]
create_args = {
"model": llm_request.api_provider.model_id,
"messages": messages,
"temperature": llm_parameters.temperature if llm_parameters else None,
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
"top_p": llm_parameters.top_p if llm_parameters else None,
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
"seed": llm_parameters.seed if llm_parameters else None
} # 处理format参数
if llm_request.format:
create_args["response_format"] = {"type": "json_object"}
response = await client.chat.completions.create(**create_args)
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
# 处理可能不存在的缓存token字段
usage = response.usage
cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
tokens_usage = TokensUsage(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
prompt_cache_hit_tokens=cache_hit,
prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
)
# 累加总token使用量
total_tokens.prompt_tokens += tokens_usage.prompt_tokens
total_tokens.completion_tokens += tokens_usage.completion_tokens
if tokens_usage.prompt_cache_hit_tokens:
total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
if tokens_usage.prompt_cache_miss_tokens:
total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
llm_request.response.append(LLMResponse(
response_id=response.id,
tokens_usage=tokens_usage,
content = response.choices[0].message.content,
total_duration=duration,
llm_parameters=llm_parameters
))
except Exception as e:
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
llm_request.response.append(LLMResponse(
response_id=f"error-round-{i+1}",
content={"error": str(e)},
total_duration=duration
))
if llm_request.error is None:
llm_request.error = []
llm_request.error.append(str(e))
# 更新总耗时和总token使用量
llm_request.total_duration = total_duration
llm_request.total_tokens_usage = total_tokens
return llm_request
if __name__ == "__main__":
from json_example import generate_json_example
from sqlmodel import Session, select
from global_var import get_sql_engine, init_global_var
from schema import DatasetItem
init_global_var("workdir")
api_state = "1 deepseek-chat"
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
llm_request = LLMRequest(
prompt="测试,随便说点什么",
api_provider=api_provider,
format=generate_json_example(DatasetItem)
)
# # 单次调用示例
# result = asyncio.run(call_openai_api(llm_request))
# print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
# for i, resp in enumerate(result.response, 1):
# print(f"响应{i}: {resp.response_content}")
# 多次调用示例
params = LLMParameters(temperature=0.7, max_tokens=100)
result = asyncio.run(call_openai_api(llm_request, 3,params))
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
for i, resp in enumerate(result.response, 1):
print(f"响应{i}: {resp.content}")

1
train/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .model import *

View File

@@ -31,8 +31,18 @@ def formatting_prompts(examples,tokenizer):
return {"text": texts} return {"text": texts}
def train_model(model, tokenizer, dataset, output_dir, learning_rate, def train_model(
per_device_train_batch_size, epoch, save_steps, lora_rank): model,
tokenizer,
dataset: list,
train_dir: str,
learning_rate: float,
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int,
trainer_callback=None
) -> None:
# 模型配置参数 # 模型配置参数
dtype = None # 数据类型None表示自动选择 dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存 load_in_4bit = False # 使用4bit量化加载模型以节省显存
@@ -75,8 +85,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
chat_template="qwen-2.5", chat_template="qwen-2.5",
) )
dataset = HFDataset.from_list(dataset) train_dataset = HFDataset.from_list(dataset)
dataset = dataset.map(formatting_prompts, train_dataset = train_dataset.map(formatting_prompts,
fn_kwargs={"tokenizer": tokenizer}, fn_kwargs={"tokenizer": tokenizer},
batched=True) batched=True)
@@ -84,7 +94,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
trainer = SFTTrainer( trainer = SFTTrainer(
model=model, # 待训练的模型 model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器 tokenizer=tokenizer, # 分词器
train_dataset=dataset, # 训练数据集 train_dataset=train_dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称 dataset_text_field="text", # 数据集字段的名称
max_seq_length=model.max_seq_length, # 最大序列长度 max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
@@ -96,20 +106,24 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率 warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率 learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器 lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据 max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16 fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16 bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志 logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果 optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合 weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=114514, # 随机数种子 seed=114514, # 随机数种子
output_dir=output_dir, # 保存模型检查点和训练日志 output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重 save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数 save_steps=save_steps, # 使用动态传入的保存步数
# report_to="tensorboard", # 将信息输出到tensorboard logging_dir=train_dir + "/logs", # 日志文件存储路径
report_to="tensorboard", # 使用TensorBoard记录日志
), ),
) )
if trainer_callback is not None:
trainer.add_callback(trainer_callback)
trainer = train_on_responses_only( trainer = train_on_responses_only(
trainer, trainer,
instruction_part = "<|im_start|>user\n", instruction_part = "<|im_start|>user\n",

48
train/save_model.py Normal file
View File

@@ -0,0 +1,48 @@
import os
from global_var import get_model, get_tokenizer
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
"""
保存模型到指定目录
:param save_model_name: 要保存的模型名称
:param models_dir: 模型保存的基础目录
:param model: 要保存的模型
:param tokenizer: 要保存的tokenizer
:param save_method: 保存模式选项
- "default": 默认保存方式
- "merged_16bit": 合并为16位
- "merged_4bit": 合并为4位
- "lora": 仅LoRA适配器
- "gguf": 保存为GGUF格式
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
- "gguf_f16": 保存为16位GGUF格式
:return: 保存结果消息或错误信息
"""
try:
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
if save_method == "default":
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
elif save_method == "merged_16bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
elif save_method == "merged_4bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
elif save_method == "lora":
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
elif save_method == "gguf":
model.save_pretrained_gguf(save_path, tokenizer)
elif save_method == "gguf_q4_k_m":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
elif save_method == "gguf_f16":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
else:
return f"不支持的保存模式: {save_method}"
return f"模型已保存到 {save_path} (模式: {save_method})"
except Exception as e:
return f"保存模型时出错: {str(e)}"