Compare commits

125 Commits

Author SHA1 Message Date
carry
7a4388c928 featmodel): 添加保存模式选择功能
在模型管理页面中新增保存模式选择功能,用户可以通过下拉菜单选择不同的保存模式(如默认、合并16位、合并4位等)。同时,将保存模型的逻辑抽离到独立的`save_model.py`文件中,以提高代码的可维护性和复用性。
2025-04-23 14:09:02 +08:00
carry
6338706967 feat: 修改应用启动方式
- 将 app.launch() 修改为 app.launch(server_name="0.0.0.0")
- 此修改使应用能够监听所有网络接口,提高可用性
2025-04-22 19:08:23 +08:00
carry
3718c75cee feat(frontend): 添加数据集生成页面的处理进度显示
- 在处理文档片段时添加进度条,提升用户体验
- 优化代码格式,调整缩进和空行
2025-04-22 00:14:16 +08:00
carry
905658073a docs(README): 更新项目文档
- 添加项目概述、核心功能、技术架构等详细信息
- 插入系统架构图和技术栈说明
- 细化功能模块描述,包括模型管理、推理、微调等
- 增加QLoRA原理和参数配置说明
- 补充快速开始指南和许可证信息
- 优化文档结构,增强可读性和完整性
2025-04-21 14:28:20 +08:00
carry
9806334517 fix(train_page): 捕获训练过程中的异常并终止 TensorBoard 进程
- 在训练过程中添加异常捕获,将异常信息转换为 gr.Error 抛出
- 确保在发生异常时也能终止 TensorBoard 子进程
2025-04-20 21:40:46 +08:00
carry
0a4efa5641 feat(dataset): 添加数据集生成功能
- 新增数据集生成页面和相关逻辑
- 实现数据集名称重复性检查
- 添加数据集对象创建和保存功能
- 优化文档处理和提示模板应用
- 增加错误处理和数据解析
2025-04-20 21:25:51 +08:00
carry
994d600221 refactor(frontend): 调整 TensorBoard iframe 高度
- 将 TensorBoard iframe 的高度从 500px 修改为 1000px
- 此修改旨在提供更宽敞的显示区域,改善用户体验
2025-04-20 21:25:37 +08:00
carry
d5774eee0c feat(db): 添加数据集导出功能
- 新增 save_dataset 函数,用于将 TinyDB 中的数据集保存为单独的 JSON 文件
- 更新 db/__init__.py,添加 get_dataset_tinydb 函数的引用
- 修改 db/dataset_store.py,实现 save_dataset 函数并添加相关逻辑
2025-04-20 19:44:11 +08:00
carry
87501c9353 fix(global_var): 移除全局变量设置函数set_datasets
- 删除了 global_var.py 文件中的 set_datasets 函数
- 该函数用于设置全局变量 _datasets,但似乎已不再使用
2025-04-20 19:14:00 +08:00
carry
5fc3b4950b refactor(schema): 修改 LLMResponse 中 API 响应内容的字段名称
- 将 LLMResponse 类中的 response_content 字段重命名为 content
- 更新字段类型从 dict 改为 str,以更准确地表示响应内容
- 在 reasoning.py 中相应地修改了调用 LLMResponse 时的参数
2025-04-20 18:40:51 +08:00
carry
c28e4819d9 refactor(frontend/tools): 重命名生成示例 JSON 数据结构的函数
- 将 generate_example_json 函数重命名为 generate_json_example
- 更新相关文件中的函数调用和引用
- 此更改旨在使函数名称更具描述性和一致性
2025-04-20 16:11:36 +08:00
carry
e7cf51d662 refactor(frontend): 重构数据集生成页面
- 调整页面布局,优化用户交互流程
- 新增数据集名称输入框
- 使用 LLMRequest 和 LLMResponse 模型处理请求和响应
- 添加 generate_example_json 函数用于格式化生成数据
- 改进数据集生成逻辑,支持多轮次生成
2025-04-20 16:10:08 +08:00
carry
4c9caff668 refactor(schema): 重构数据集和文档类的命名
- 将 dataset、dataset_item 和 doc 类的首字母大写,以符合 Python 类命名惯例
- 更新相关模块中的导入和引用,以适应新的类名
- 此更改不影响功能,仅提高了代码的一致性和可读性
2025-04-20 01:46:15 +08:00
carry
9236f49b36 feat(frontend): 添加文档切片和并发数功能
- 新增并发数输入框
- 实现文档切片处理
- 更新生成数据集的逻辑,支持并发处理
2025-04-20 01:40:48 +08:00
carry
868fcd45ba refactor(project): 重构项目文件组织结构
- 修改模型管理和训练页面的导入路径
- 更新 main.py 中的导入模块
- 调整 tools 包的内容,移除 model 模块
- 新建 train 包,包含 model 模块
- 优化 __init__.py 文件,简化导入语句
2025-04-19 21:49:19 +08:00
carry
5a21c8598a feat(tools): 支持 OpenAI API 的 JSON 格式返回结果
- 在 call_openai_api 函数中添加对 JSON 格式返回结果的支持
- 增加 llm_request.format 参数处理,将用户 prompt 与格式要求合并
- 添加 response_format 参数到 OpenAI API 请求
- 更新示例,使用 JSON 格式返回结果
2025-04-19 21:10:22 +08:00
carry
1e829c9268 feat(tools): 优化 JSON 示例生成函数
- 增加 include_optional 参数,决定是否包含可选字段
- 添加 list_length 参数,用于控制列表字段的示例长度
- 在列表示例中添加省略标记,更直观展示多元素列表
- 优化字典字段的示例生成逻辑
2025-04-19 21:07:00 +08:00
carry
9fc3ab904b feat(frontend): 实现了固定参数的注入 2025-04-19 17:48:45 +08:00
carry
d827f9758f fix(frontend): 修复dataframe_value返回值只有一列的bug 2025-04-19 17:30:10 +08:00
carry
ff1e9731bc fix(tools): 修复call_openai_api的导出 2025-04-19 17:13:19 +08:00
carry
90fde639ff feat(tools): 增加 OpenAI API 多轮调用功能
- 在 call_openai_api 函数中添加 rounds 参数,支持多次调用
- 累加每次调用的耗时和 token 使用情况
- 将多次调用的结果存储在 LLMRequest 对象的 response 列表中
- 更新函数返回类型,返回包含多次调用信息的 LLMRequest 对象
- 优化错误处理,记录每轮调用的错误信息
2025-04-19 17:02:00 +08:00
carry
5fc90903fb feat(tools): 添加 reasoning.py 工具模块
- 新增 reasoning.py 文件,实现与 OpenAI API 的交互
- 添加 call_openai_api 函数,用于发送请求并处理响应
- 支持可选的 LLMParameters 参数,以定制化请求
- 处理 API 响应中的 tokens 使用情况
- 提供错误处理和缓存 token 字段的处理
2025-04-19 16:53:48 +08:00
carry
81c2ad4a2d refactor(schema): 重构数据模型以提高可维护性和可扩展性
- 新增 LLMParameters 类以统一处理 LLM 参数
- 新增 TokensUsage 类以统一处理 token 使用信息
- 更新 LLMResponse 和 LLMRequest 类,使用新的 LLMParameters 和 TokensUsage 类
- 优化数据模型结构,提高代码的可读性和可维护性
2025-04-19 16:39:18 +08:00
carry
314434951d feat(frontend): 实现了文档、提示和 API 提供商的获取逻辑 2025-04-19 14:47:01 +08:00
carry
e16882953d fix(tools): 修复了optional字段无法被解析的问题 2025-04-18 22:00:51 +08:00
carry
86bcf90c66 feat(frontend): 添加数据集生成轮次控制功能
- 在数据集生成页面添加"生成轮次"输入框,支持设置生成轮数
- 更新生成逻辑,根据设置的轮次进行多次生成
2025-04-18 15:47:37 +08:00
carry
961a017f19 refactor(frontend): 调整数据集生成页面布局并优化代码结构
- 使用 gr.Column(scale=1) 和 gr.Column(scale=2) 调整列宽比例
- 移除多余的空行和缩进,提高代码可读性
- 优化变量声明和组件创建的顺序,使页面结构更清晰
2025-04-18 15:40:15 +08:00
carry
5a386d6401 feat(dataset_generate_page): 添加 API 选择功能
- 在数据集生成页面添加 API 选择下拉框
- 实现 API 选择变更时的处理逻辑
- 更新数据集生成函数,增加 API 选择参数
- 优化页面布局和代码结构
2025-04-18 15:23:33 +08:00
carry
feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00
carry
7242a2ce03 feat(frontend): 添加生成数据集进度条功能并优化了界面布局 2025-04-18 15:07:46 +08:00
carry
db6e2271dc fix(frontend): 修复 prompt_dropdown 变化时,dataframe没有相应的变化
- 将 prompt_dropdown 变化时的输出从 prompt_state 修改为 [prompt_state, variables_dataframe]
- 这个改动可能会在 prompt 变化时同时更新变量数据框
2025-04-18 14:03:26 +08:00
carry
d764537143 feat(dataset_generate_page): 更新数据集生成页面功能
- 添加模板变量列表展示和编辑功能
- 实现模板选择后动态更新变量列表
- 增加生成数据集按钮和相关逻辑
- 优化页面布局和交互
2025-04-16 12:39:48 +08:00
carry
8c35a38c47 feat(frontend): 更新模板选择功能
- 在模板选择变更时,获取所选模板的详细信息
- 创建 PromptTemplate 对象并获取输入变量列表
- 此更新为后续的模板编辑功能做准备
2025-04-15 21:31:50 +08:00
carry
7ee751c88f fix(frontend): 移除文档生成页面的冗余事件绑定代码
- 删除了原有的简单事件绑定逻辑,这些逻辑仅将输入值赋给状态变量
- 为后续添加更复杂的文档选择更改事件处理函数做准备
2025-04-15 20:44:26 +08:00
carry
b715b36a5f feat(frontend): 更新数据集生成页面并添加独立运行功能
- 重构导入路径,使用绝对路径替换相对路径
- 新增文档和模板选择的事件处理函数
- 添加独立运行数据集生成页面的功能
- 优化代码结构,提高可读性和可维护性
2025-04-15 17:13:52 +08:00
carry
8023233bb2 feat(prompt): 增加模板变量有效性检查
- 在 promptTempleta 模型中添加字段验证器
- 验证模板内容是否包含必要的 document_slice 变量
- 如果缺少该变量,抛出 ValueError 异常
2025-04-15 16:54:17 +08:00
carry
2a86b3b5b0 fix(db): 初始化 prompt store 时插入第一条记录的 ID 从 0 改为 1
- 将初始化时插入的第一条记录的 ID 从 0 改为 1
- 修正了文档节选的变量名,从 {content} 改为 {document_slice}
2025-04-15 16:45:12 +08:00
carry
ca1505304e fix(tools): 更新 tools/__init__.py 中的导入语句
- 将 from .doc import * 改为 from .document import *
- 这个修改统一了文档处理模块的命名,提高了代码的一致性和可读性
2025-04-15 16:31:55 +08:00
carry
df9260e918 fix(db): 修复初始提示词的变量花括号的空格问题 2025-04-15 16:13:57 +08:00
carry
df9aba0c6e refactor(tools): 重命名模块并更新导入
- 将 scan_doc_dir.py 重命名为 document.py
- 将 socket.py 重命名为 port.py
- 更新 __init__.py 中的导入语句
- 在 port.py 中添加测试代码,用于查找可用端口
2025-04-15 15:47:44 +08:00
carry
6b87dcb58f refactor(frontend): 重构数据集生成页面的变量命名逻辑
- 将 prompt_choices 变量重命名为 prompt_list,以更准确地反映其内容
- 更新相关代码中对这两个变量的引用,以保持一致性
2025-04-15 15:40:24 +08:00
carry
d0aebd17fa refactor(global_var): 重构全局变量管理
- 移除了 _docs 全局变量
- 更新了 get_docs() 函数,使其在每次调用时重新扫描文档目录
- 优化了全局变量初始化逻辑
2025-04-15 15:25:44 +08:00
carry
d9abf08184 fix(frontend): 修复表格选择事件的行数据获取问题
- 在 prompt_manage_page 和 setting_page 中更新了 select_record 函数
- 使用 DataFrame.iloc 方法获取选中行的数据,并转换为列表
- 添加了将第一列数据转换为整数的逻辑
- 更新了表格选择事件的参数,增加了输入和输出参数
- 将 gradio 版本升级到 5.25.0
2025-04-15 15:10:15 +08:00
carry
a27a1ab079 refactor(frontend): 重构训练页面布局并优化用户界面
- 调整数据集下拉框布局位置
- 新增超参数输入组件
- 修改训练日志输出框标签为"训练状态"
- 添加 TensorBoard 可视化 iframe 显示框
2025-04-15 00:12:09 +08:00
carry
aa758e3c2a feat(train_page): 添加 TensorBoard 可视化
- 在训练页面添加 TensorBoard iframe 显示框
- 实现动态生成 TensorBoard iframe 功能
- 更新训练按钮点击事件,同时更新 TensorBoard iframe
2025-04-14 23:28:43 +08:00
carry
664944f0c5 feat(frontend): 优化 TensorBoard 端口占用问题
- 新增端口检测逻辑,动态分配可用端口
- 修改 TensorBoard 启动过程,使用动态分配的端口
- 添加 socket 模块,用于端口检测
2025-04-14 17:06:44 +08:00
carry
9298438f98 feat(train_page): 启动 TensorBoard 进程并确保训练结束后终止
- 在训练页面中添加 TensorBoard 进程启动代码
- 创建日志目录并启动 TensorBoard 子进程
- 在训练结束后终止 TensorBoard 进程
2025-04-14 17:00:33 +08:00
carry
4f7926aec6 feat(train_page): 实现训练目录自动递增功能
- 在 training 文件夹下创建递增的目录结构
- 确保 training 文件夹存在
- 扫描现有目录,生成下一个可用的目录编号
- 更新训练模型函数,使用新的训练目录
2025-04-14 16:46:29 +08:00
carry
148f4afb25 fix(main): 修复unsloth没有最先导入的问题,删除了重复的导入语句 2025-04-14 16:31:47 +08:00
carry
11a3039775 fix(train_page): 修正模型训练保存路径 2025-04-14 16:31:00 +08:00
carry
a4289815ba build: 添加 tensorboard 依赖
- 在 requirements.txt 中添加 tensorboard>=2.19.0
- 此改动增加了 tensorboard 作为项目的新依赖项
2025-04-14 16:30:39 +08:00
carry
088067d335 train: 更新模型训练功能和日志记录方式
- 修改训练目录结构,将检查点和日志分开保存
- 添加 TensorBoard 日志记录支持
- 移除自定义 LossCallback 类,简化训练流程
- 更新训练参数和回调机制,提高代码可读性
- 在 requirements.txt 中添加 tensorboardX 依赖
2025-04-14 16:19:37 +08:00
carry
9fb31c46c8 feat(train): 添加训练过程中的日志记录和 loss 可视化功能
- 新增 LossCallback 类,用于在训练过程中记录 loss 数据
- 在训练模型函数中添加回调,实现日志记录和 loss 可视化
- 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
2025-04-14 15:18:14 +08:00
carry
4f09823123 refactor(tools): 优化 train_model 函数定义
- 添加类型注解,提高代码可读性和维护性
- 使用多行格式定义函数参数,提升代码格式美观
2025-04-14 14:28:36 +08:00
carry
1a2ca3e244 refactor(train): 重构训练功能并移至新模块
- 将训练逻辑从 train_page.py 移至 tools/model.py
- 新增 train_model 函数,包含完整的训练流程
- 更新 train_page.py 中的回调函数,使用新的训练函数
- 移除了 train_page.py 中未使用的导入
2025-04-14 14:17:04 +08:00
carry
bb1d8fbd38 feat(train_page): 添加训练 Loss 曲线显示功能
- 在训练页面添加了 Loss 曲线图表
- 实现了 GradioLossCallback 类用于记录训练过程中的 Loss 数据
- 修改了训练函数,通过回调函数收集 Loss 信息并更新图表
- 优化了训练函数的返回值结构,支持同时返回文本日志和 Loss 数据
2025-04-13 21:49:43 +08:00
carry
4558929c52 fix: 调整了import的顺序,让unsloth最先import以提高性能 2025-04-13 21:35:47 +08:00
carry
0722748997 feat(train_page): 添加 LoRA 秩动态输入功能
- 在训练页面新增 LoRA 秩输入框,使用户可以动态设置 LoRA 秩
- 更新训练模型函数,添加 LoRA 秩参数并将其用于模型配置
- 保留原有功能,仅增加 LoRA 秩相关配置
2025-04-13 21:12:02 +08:00
carry
e08f0059bb feat(train_page): 优化训练过程以专注于响应生成
- 引入 train_on_responses_only 函数,用于优化训练过程
- 设置 instruction_part 和 response_part 参数,以适应特定的对话格式
- 此修改旨在提高模型在生成响应方面的性能和效率
2025-04-13 21:05:14 +08:00
carry
6d1fecbdac build: 更新依赖版本并添加新依赖
- 将 unsloth 的最低版本要求从 2025.3.9 提升到 2025.3.19
- 新增 sqlmodel 依赖,版本不低于 0.0.24
- 新增 jinja2 依赖,版本不低于 3.1.0
2025-04-13 20:29:33 +08:00
carry
79d3eb153c refactor(train_page): 优化训练页面布局和功能
- 移除了 max_steps_input 组件,减少不必要的输入项
- 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch"
- 新增 save_steps_input 组件,用于设置保存步数
- 修改 train_model 函数,移除了 max_steps 参数
- 更新了 trainer.train() 方法的调用,设置 resume_from_checkpoint=False
2025-04-13 01:56:10 +08:00
carry
80dae7c6e2 fix(global_var):修复_workdir非全局变量的bug 2025-04-13 01:52:05 +08:00
carry
2d39b91764 feat(train_page): 添加模型训练超参数配置功能
- 新增学习率、批次大小、最大训练步数等超参数输入组件
- 实现超参数在训练过程中的动态应用
- 调整训练参数以适应不同硬件环境
- 优化训练过程,支持按步数保存模型
2025-04-13 01:04:27 +08:00
carry
5094febcb4 refactor(global_var): 重构全局变量管理并添加工作目录功能
- 添加 _workdir 全局变量以存储工作目录路径
- 在 init_global_var 函数中初始化 _workdir
- 新增 get_workdir 函数以获取工作目录路径
- 调整全局变量的定义和初始化顺序
2025-04-13 00:54:55 +08:00
carry
eeee68dbd1 chore: 更新 .gitignore 文件
- 在 .gitignore 中添加 test.py 文件,避免测试代码被版本控制
2025-04-12 18:42:36 +08:00
carry
539e14d39c feat(frontend): 完成了前端微调的代码逻辑 2025-04-12 18:42:22 +08:00
carry
9784f2aed3 fix(tools): 修正__init__.py使得model.py正确导入 2025-04-12 01:16:07 +08:00
carry
611904cef9 feat(frontend): 添加数据集选择功能到训练页面
- 在 train_page.py 中添加数据集选择下拉框
- 从全局变量中获取数据集列表并设置初始值
- 添加交互性和自定义值支持
2025-04-11 19:43:34 +08:00
carry
8a9a080745 refactor(tools): 移除未使用的导入语句
移除了 tools/model.py 文件中未使用的 get_chat_template 导入语句。这个修改提高了代码的可读性和维护性。
2025-04-11 19:43:19 +08:00
carry
a23ad88769 fix(frontend): 修复删除提示功能中的数据库连接错误
- 将 prompt_store 更改为 get_prompt_store(),以解决数据库连接未建立的问题
- 优化了删除提示功能的代码,提高了系统稳定性
2025-04-11 18:53:17 +08:00
carry
83427aaaba feat(frontend): 增加超参数设置并优化聊天页面布局
- 在聊天页面添加了超参数输入框,包括最大生成长度、温度、Top-p 采样和重复惩罚
- 优化了聊天框的布局,使用 gr.Row() 和 gr.Column() 实现了更合理的界面结构
- 更新了 bot 函数,支持根据用户输入的超参数进行文本生成
- 修复了一些代码格式问题,提高了代码的可读性
2025-04-11 18:48:13 +08:00
carry
61672021ef fix(frontend): 修复聊天页面并的流式回复
- 导入 Thread 和 TextIteratorStreamer 以支持流式生成
- 重新设计 user 和 bot 函数,优化对话历史处理
- 添加异常处理和错误信息显示
- 改进模型和分词器的加载逻辑
- 优化聊天页面布局和交互
2025-04-11 18:33:31 +08:00
carry
fb6157af05 feat(frontend): 初步实现聊天页面的智能回复功能 2025-04-11 18:08:38 +08:00
carry
f655936741 refactor(global_var): 重构全局变量初始化方法
- 新增 init_global_var 函数,用于统一初始化所有全局变量
- 修改 get_prompt_store、get_sql_engine、get_docs 和 get_datasets 函数,使用新的全局变量初始化逻辑
- 更新 main.py 中的代码,使用新的 init_global_var 函数替代原有的单独初始化方法
2025-04-11 18:08:16 +08:00
carry
ab7897351a fix(global_var): 修复全局变量多文件多副本的不统一问题 2025-04-11 18:04:42 +08:00
carry
216bfe39ae feat(tools): 添加格式化对话数据的函数
- 新增 formatting_prompts_func 函数,用于格式化对话数据
- 该函数将问题和答案组合成对话形式,并使用 tokenizer.apply_chat_template 进行格式化
- 更新 imports,添加了 unsloth.chat_templates 模块
2025-04-11 17:56:46 +08:00
carry
0fa2b51a79 refactor(frontend): 优化模型管理页面的交互和显示
- 将状态输出从 Textbox 改为 Label 组件,提高用户体验
- 添加 get_model_name 函数以获取模型名称,提高代码复用性
- 更新模型加载、卸载和保存后的状态显示,使信息更加准确
- 优化模型列表刷新功能,确保模型列表实时更新
2025-04-11 00:14:40 +08:00
carry
cbb3a09dd8 feat(tools): 添加模型名称获取函数
- 在 tools 目录下新增 model.py 文件
- 实现 get_model_name 函数,用于获取模型的名称
- 更新 tools/__init__.py,导入新的 get_model_name 函数
2025-04-10 22:05:04 +08:00
carry
2e552c186d refactor(frontend): 重构模型选择界面的变量命名
- 将模型选择的 Dropdown 组件从 dropdown 重命名为 model_select_dropdown,提高代码可读性
- 更新 load_button 和 refresh_button 的输出目标,以适应新的变量名
2025-04-10 21:19:58 +08:00
carry
1b3f546669 refactor(frontend): 重构前端页面并添加独立运行功能
- 在 chat_page 和 prompt_manage_page 中添加了独立运行的入口
- 引入 sys 和 pathlib 模块以支持路径操作
- 修改了模块导入方式,使其能够作为独立脚本运行
- 优化了代码结构,提高了可读性和可维护性
2025-04-10 21:18:05 +08:00
carry
402bc73dce feat(model_manage_page): 增加模型保存和刷新功能
- 新增保存模型功能,用户可以输入模型名称并保存当前加载的模型
- 添加刷新模型列表按钮,用户可以随时更新模型下拉菜单中的选项
- 优化页面布局,使按钮和输入框更加合理地排列
2025-04-10 20:18:03 +08:00
carry
bb5851f800 build: 添加 unsloth 依赖
- 在 requirements.txt 中添加 unsloth>=2025.3.9 依赖
2025-04-10 19:56:44 +08:00
carry
a407fa1f76 feat(model_manage_page): 实现模型加载和卸载功能
- 添加模型加载和卸载按钮
- 实现模型加载和卸载的逻辑
- 添加相关模块的导入
- 扫描模型目录并显示在下拉框中
2025-04-10 19:52:08 +08:00
carry
4b465ec917 chore: 更新 .gitignore 文件
- 修改测试代码注释,扩大至参考代码
- 新增 refer/ 目录到忽略列表
2025-04-10 17:38:29 +08:00
carry
e7cc03297b feat(frontend): 添加了简单聊天机器人页面 2025-04-10 17:38:02 +08:00
carry
051d1a7535 feat(frontend): 添加模型管理页面并初始化模型相关全局变量
- 在 frontend/__init__.py 中添加 model_manage_page 模块引用
- 新增 model_manage_page.py 文件,实现模型管理页面的基本框架
- 在 global_var.py 中添加 model 和 tokenizer 全局变量
- 在 main.py 中集成模型管理页面到主应用的 Tabs 组件中
2025-04-10 17:37:45 +08:00
carry
97172f9596 feat(dataset): 设置问答数据集展示页面的每页显示数量
- 在 dataset_manage_page 函数中添加 samples_per_page 参数
- 设置每页显示的样本数量为 20 条
2025-04-10 16:12:59 +08:00
carry
f582820443 feat(tools): 添加生成 Pydantic V2 模型示例 JSON 的工具脚本
- 新增 json_example.py 脚本,用于生成 Pydantic V2 模型的示例 JSON 数据结构
- 支持列表、字典、可选类型以及基本数据类型(字符串、整数、浮点数、布尔值、日期和时间)的示例生成
- 可递归生成嵌套模型的示例 JSON
- 示例使用了项目中的 Q_A 模型,生成了包含多个 Q_A 对象的列表 JSON 结构
2025-04-10 15:38:28 +08:00
carry
8fb9f785b9 feat(frontend): 展示数据集管理页面的问答数据
- 添加 QA 数据集展示组件
- 实现数据集选择时动态加载对应的问答数据
- 优化数据集管理页面布局
2025-04-09 22:23:55 +08:00
carry
2c8e54bb1e feat(dataset): 初步完成数据集管理页面和功能 2025-04-09 20:49:20 +08:00
carry
932d1e2687 refactor(schema): 修改数据集名称默认值
- 将 dataset 类中的 name 字段默认值从 None 改为 ""
- 这个改动确保了数据集名称始终有一个默认的空字符串值,而不是 None,提高了数据一致性和代码健壮性
2025-04-09 19:42:00 +08:00
carry
202d4c44df feat(db): 添加数据集存储和读取功能
- 新增 dataset_store.py 文件,实现数据集的存储和读取功能
- 添加 get_all_dataset 函数,用于获取所有数据集
- 使用 tinydb 和 json 进行数据持久化
- 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件
2025-04-09 18:21:27 +08:00
carry
4d77c429bd refactor(schema): 更新 dataset 模型并为 doc 模型添加版本字段
- 在 doc 模型中添加 version 字段,用于表示文档版本
- 将 dataset 模型中的 source_doc 字段类型从 list[doc] 改为 doc,简化数据结构
2025-04-09 18:18:29 +08:00
carry
41447c5ed4 feat(dataset): 添加数据集来源文档字段
- 在 dataset 模型中增加 source_doc 字段,用于记录数据集的来源文档
- 新增字段为可选列表,包含 doc 类型的元素
2025-04-09 17:37:24 +08:00
carry
84fe78243a feat(tools): 添加 JSON 数据转换为 dataset 的工具脚本
- 新增 convert_json_to_dataset 函数,用于将 JSON 数据转换为 dataset 对象
- 实现了从 JSON 文件读取数据、转换为 dataset 格式并输出到文件的功能
- 该工具可帮助用户将旧数据集快速转换为新的 dataset 格式
2025-04-09 17:31:53 +08:00
carry
4d8754aad2 feat(frontend): 实现数据集生成页面的文档和模板选择功能
- 添加文档和模板的下拉选择框
- 实现文档和模板选择后的状态更新
- 优化页面布局,分为文档和模板两个列
2025-04-09 17:19:40 +08:00
carry
541d37c674 feat(schema): 新增数据集相关模型并添加文档扫描功能
- 新增 dataset.py 文件,定义数据集相关模型
- 新增 tools 目录,包含解析 Markdown 和扫描文档的功能
- 修改 parse_markdown.py,增加处理 Markdown 文件的函数
- 新增 scan_doc_dir.py,实现文档目录扫描功能
2025-04-09 13:02:18 +08:00
carry
6a00699472 feat(frontend): 实现提示词模板管理页面
- 添加获取、添加、编辑和删除提示词功能
- 实现数据表格展示和操作
2025-04-09 11:08:18 +08:00
carry
ff8162890d refactor(db): 移除了提示词模板中冗余的 JSON 格式说明 2025-04-09 10:35:11 +08:00
carry
daddcd34da fix(db): 为 promptStore 添加空数据库初始化逻辑
- 在 initialize_prompt_store 函数中增加空数据库检查和初始化逻辑
- 为默认模板添加 id 字段,设置为 0
2025-04-09 10:28:31 +08:00
carry
5c7ced30df fix(db): 修复 prompt_store 初始化逻辑
- 在插入默认模板之前检查数据库是否为空,如果数据库已有数据,则跳过插入默认模板
2025-04-09 10:26:14 +08:00
carry
9741ce6b92 refactor(db): 优化了代码,调整了import顺序,删除了无用变量 2025-04-09 10:19:57 +08:00
carry
67281fe06a feat(db): 添加 prompt 存储功能
- 新增 prompt_store 模块,使用 TinyDB 存储 prompt 模板
- 在全局变量中添加 prompt_store 实例
- 更新 main.py,初始化 prompt 存储
- 新增 prompt 模板的 Pydantic 模型
- 更新 requirements.txt,添加 tinydb 依赖
2025-04-09 09:58:42 +08:00
carry
2d905a0270 refactor(db): 调整导入模块顺序
- 将 os 和 sys 模块导入提前到文件顶部
- 优化代码结构,遵循常见的 Python 导入模块顺序
2025-04-09 09:57:20 +08:00
carry
374b124cf8 feat(setting_page): 添加供应商后清空输入框
- 修改 add_provider 函数,返回清空后的输入框值
- 更新 add_button.click 事件处理,添加清空输入框的输出
2025-04-09 08:17:43 +08:00
carry
74ae5e1426 refactor(db): 重命名数据库引擎获取函数
将 get_engine 函数重命名为 get_sqlite_engine,以更清晰地表示其功能和用途。
- 更新了 db/__init__.py 中的导入和 __all__ 列表
- 修改了 db/init_db.py 中的函数定义
- 更新了前端设置页面和全局变量中的导入和函数调用

此更改提高了代码的可读性和维护性,特别是在将来可能添加其他类型数据库引擎的情况下。
2025-04-09 08:12:59 +08:00
carry
0a6ae7a4ee feat(frontend): 重构前端页面并添加新功能
- 重命名 dataset_page 为 prompt_manage_page,支持提示词模板管理
- 新增 dataset_generate_page 和 dataset_manage_page 页面
- 更新 main.py 中的页面引用和标签名称
- 修改前端初始化文件,使用 * 导入所有页面模块
2025-04-09 08:11:40 +08:00
carry
faf72d1e99 feat(frontend): 完成了编辑 API Provider 功能 2025-04-09 08:04:40 +08:00
carry
cce5e4e114 feat(frontend): 完成了 API Provider 删除和添加了编辑功能的函数 2025-04-09 00:48:22 +08:00
carry
293f63017f feat(frontend): 添加 API Provider 表格选中行状态监听
- 新增选中行的全局变量 selected_row
- 实现 select_record 函数来保存选中行数据
- 在表格中添加选中行事件监听
- 优化代码结构,提高可读性和可维护性
2025-04-09 00:37:15 +08:00
carry
2e31f4f57c build: 升级 gradio 至 5.0.0 版本
- 将 requirements.txt 中 gradio 版本要求从 >=3.0.0 修改为 >=5.0.0
- 此次升级可能会影响项目的用户界面或功能,需要进行测试以确保兼容性
2025-04-08 16:17:21 +08:00
carry
967133162e refactor(schema): 在 APIProvider 模型中设置 id 字段为不可变
- 在 APIProvider 类中,将 id 字段的定义更新,添加 allow_mutation=False 参数
- 这个改动确保了主键字段在创建后不可更改,提高了数据的一致性和安全性
2025-04-08 16:02:46 +08:00
carry
dc28c25c65 feat(frontend): 更新设置页面按钮样式
- 为"添加新API"按钮添加 primary 样式
- 为"编辑选中行"按钮添加 primary 样式
- 为"删除选中行"按钮添加 stop 样式
- 保持"刷新数据"按钮的 secondary 样式
2025-04-08 14:23:31 +08:00
carry
70b64dc3d3 refactor(db): 重命名数据库初始化函数以明确其适用范围
- 将 initialize_db 函数重命名为 initialize_sqlite_db,以明确该函数专用于 SQLite 数据库
- 更新相关模块和文件中的引用,以确保代码一致性
- 此修改旨在提高代码的可读性和维护性,特别是未来可能接入多种数据库时
2025-04-08 14:16:12 +08:00
carry
b52ca9b1af docs: 添加项目基础文档
- 新增 LICENSE 文件,定义项目使用的 MIT 开源许可证
- 新增 README.md 文件,简要介绍项目内容和技术栈
2025-04-08 13:35:30 +08:00
carry
46b4453ccd refactor(frontend): 重构数据库连接方式
- 移除各前端页面中重复的数据库引擎初始化代码
- 在 global_var.py 中统一初始化和存储数据库引擎
- 更新 setting_page.py 和 main.py 中的数据库连接逻辑
- 优化代码结构,提高可维护性和可扩展性
2025-04-08 13:19:58 +08:00
carry
d5b528d375 chore: 更新 .gitignore 文件
- 保留 unsloth_compiled_cache 目录
- 添加 test.ipynb 到忽略列表,避免测试代码影响版本控制
2025-04-08 12:28:42 +08:00
carry
475cd033d9 build: 添加 langchain 依赖
- 在 requirements.txt 中添加 langchain>=0.3 版本的依赖
- 保持其他依赖版本不变
2025-04-08 11:53:58 +08:00
carry
3970a67df3 refactor(dataset_generation): 增加 APIProvider 模型字段的最小长度验证
- 为 base_url 和 model_id 字段添加 min_length=1 的验证
- 更新字段描述,明确这些字段不能为空
2025-04-07 23:37:14 +08:00
carry
286db405ca feat(frontend): 优化设置页面并添加数据刷新功能
- 为 get_providers 函数添加异常处理,提高数据获取的稳定性
- 在设置页面添加刷新按钮,用户可手动触发数据刷新
- 优化页面布局,调整组件间距和对齐方式
2025-04-07 23:17:43 +08:00
carry
d40f5b1f24 fix(frontend): 优化 API Provider 添加功能并处理异常
- 为 model_id、base_url 和 api_key 添加空值检查,避免无效输入
- 添加异常处理,确保在出现错误时能够及时响应并提示用户
- 优化 add_provider 函数,提高代码可读性和健壮性
2025-04-07 13:02:45 +08:00
carry
7a77f61ee6 feat(frontend): 添加 API Provider 的增加功能 2025-04-07 00:28:52 +08:00
carry
841e14a093 feat(frontend): 添加数据集页面并重构主页面布局
- 新增 dataset_page 模块,实现数据集页面的基本布局
- 重构 main.py 中的页面加载方式,使用列表收集所有页面
- 更新主页面布局,将聊天页面作为第一个选项卡
- 调整设置页面的加载方式,直接使用函数调用
2025-04-06 22:49:37 +08:00
carry
2ff077bb1c refactor(frontend): 重构前端页面导入方式
- 在 main.py 中使用更简洁的导入方式
- 新增 __init__.py 文件以简化前端页面的导入
2025-04-06 22:46:31 +08:00
carry
513b639bce feat(frontend): 添加了设置页面的api provider展示 2025-04-06 22:05:56 +08:00
32 changed files with 1732 additions and 70 deletions

7
.gitignore vendored
View File

@@ -28,4 +28,9 @@ Thumbs.db
workdir/ workdir/
# cache # cache
unsloth_compiled_cache unsloth_compiled_cache
# 测试和参考代码
test.ipynb
test.py
refer/

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 C-a-r-r-y
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

91
README.md Normal file
View File

@@ -0,0 +1,91 @@
# 基于文档驱动的自适应编码大模型微调框架
## 简介
### 项目概述
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
### 核心功能
- 文档解析与语料生成
- 大语言模型高效微调
- 交互式训练与推理界面
- 训练过程可视化监控
## 技术架构
### 系统架构
```
[前端界面] -> [模型微调] -> [数据存储]
↑ ↑ ↑
│ │ │
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
```
### 技术栈
- **前端界面**: Gradio构建的交互式Web界面
- **模型微调**: 基于unsloth框架的QLoRA高效微调
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
- **工作流引擎**: LangChain实现文档解析与语料生成
- **训练监控**: TensorBoard集成
## 功能模块
### 1. 模型管理
- 支持多种格式的大语言模型加载
- 模型信息查看与状态管理
- 模型卸载与内存释放
### 2. 模型推理
- 对话式交互界面
- 流式响应输出
- 历史对话管理
### 3. 模型微调
- 训练参数配置(学习率、batch size等)
- LoRA参数配置(秩、alpha等)
- 训练过程实时监控
- 训练中断与恢复
### 4. 数据集生成
- 文档解析与清洗
- 指令-响应对生成
- 数据集质量评估
- 数据集版本管理
## 微调技术
### QLoRA原理
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
1. **4-bit量化**: 将预训练模型量化为4-bit表示大幅减少显存占用
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
### 参数配置
- **学习率**: 建议2e-5到2e-4
- **LoRA秩**: 控制适配器复杂度(建议16-64)
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
### 训练监控
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
- **日志记录**: 训练过程详细日志保存
- **模型检查点**: 定期保存中间权重
## 快速开始
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 启动应用:
```bash
python main.py
```
3. 访问Web界面:
```
http://localhost:7860
```
## 许可证
MIT License

View File

@@ -1,3 +1,12 @@
from .init_db import get_engine, initialize_db from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset, save_dataset
__all__ = ['get_engine', 'initialize_db'] __all__ = [
"load_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store",
"get_all_dataset",
"save_dataset"
]

81
db/dataset_store.py Normal file
View File

@@ -0,0 +1,81 @@
import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import Dataset, DatasetItem, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
Args:
workdir (str): 工作目录路径
Returns:
TinyDB: 包含所有数据集对象的TinyDB对象
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return TinyDB(storage=MemoryStorage)
db = TinyDB(storage=MemoryStorage)
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
db.insert(data)
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return db
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
"""
将TinyDB中的数据集保存为单独的json文件
Args:
db (TinyDB): 包含数据集对象的TinyDB实例
workdir (str): 工作目录路径
name (str, optional): 要保存的数据集名称None表示保存所有
"""
dataset_dir = os.path.join(workdir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)
datasets = db.all() if name is None else db.search(Query().name == name)
for dataset in datasets:
try:
filename = f"{dataset.get(dataset['name'])}.json"
filepath = os.path.join(dataset_dir, filename)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取所有数据集
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")
# 询问要保存的数据集名称
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
# 保存数据集到文件
save_dataset(datasets, workdir, name)
print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -1,9 +1,9 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select from sqlmodel import select
from typing import Optional from typing import Optional
import os
from pathlib import Path from pathlib import Path
import sys
from dotenv import load_dotenv from dotenv import load_dotenv
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例 # 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None _engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine: def load_sqlite_engine(workdir: str) -> Engine:
""" """
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。 获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
@@ -37,7 +37,7 @@ def get_engine(workdir: str) -> Engine:
_engine = create_engine(db_url) _engine = create_engine(db_url)
return _engine return _engine
def initialize_db(engine: Engine) -> None: def initialize_sqlite_db(engine: Engine) -> None:
""" """
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。 初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
@@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎 # 获取数据库引擎
engine = get_engine(workdir) engine = load_sqlite_engine(workdir)
# 初始化数据库 # 初始化数据库
initialize_db(engine) initialize_sqlite_db(engine)

62
db/prompt_store.py Normal file
View File

@@ -0,0 +1,62 @@
import os
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# 全局变量用于存储TinyDB实例
_db_instance: Optional[TinyDB] = None
# 自定义存储类用于格式化JSON数据
def get_prompt_tinydb(workdir: str) -> TinyDB:
"""
获取TinyDB实例。如果实例尚未创建则创建一个新的并返回。
Args:
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
Returns:
TinyDB: TinyDB数据库实例
"""
global _db_instance
if not _db_instance:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "prompts.json")
# 创建TinyDB实例
_db_instance = TinyDB(db_path)
return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
"""
初始化prompt模板存储
Args:
db (TinyDB): TinyDB数据库实例
"""
# 检查数据库是否为空
if not db.all(): # 如果数据库中没有数据
db.insert(promptTempleta(
id=1,
name="default",
description="默认提示词模板",
content="""项目名为:{project_name}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{document_slice}""").model_dump())
# 如果数据库中已有数据,则跳过插入
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库实例
db = get_prompt_tinydb(workdir)
# 初始化prompt存储
initialize_prompt_store(db)

7
frontend/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from .train_page import *
from .chat_page import *
from .setting_page import *
from .model_manage_page import *
from .dataset_manage_page import *
from .dataset_generate_page import *
from .prompt_manage_page import *

View File

@@ -1,9 +1,97 @@
import gradio as gr import gradio as gr
import sys
from pathlib import Path
from threading import Thread # 需要导入 Thread
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
# 假设 global_var.py 在父目录
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
def chat_page(): def chat_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 聊天") # 聊天框
gr.Markdown("## 对话")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(scale=4):
pass chatbot = gr.Chatbot(type="messages", label="聊天机器人")
return demo msg = gr.Textbox(label="输入消息")
with gr.Column(scale=1):
# 新增超参数输入框
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
clear = gr.Button("清除对话")
def user(user_message, history: list):
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
model = get_model()
tokenizer = get_tokenizer()
if not history:
yield history
return
if model is None or tokenizer is None:
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
yield history
return
try:
inputs = tokenizer.apply_chat_template(
history,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 将超参数转换为数值类型
generation_kwargs = dict(
input_ids=inputs,
streamer=streamer,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
repetition_penalty=float(repetition_penalty),
do_sample=True,
use_cache=False
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
history.append({"role": "assistant", "content": ""})
for new_text in streamer:
if new_text:
history[-1]["content"] += new_text
yield history
except Exception as e:
import traceback
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
history[-1]["content"] = error_message
else:
history.append({"role": "assistant", "content": error_message})
yield history
# 更新 .then() 调用,将超参数传递给 bot 函数
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
)
clear.click(lambda: [], None, chatbot, queue=False)
return demo
if __name__ == "__main__":
from model_manage_page import model_manage_page
# 装载两个页面
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

View File

@@ -0,0 +1,189 @@
import gradio as gr
import sys
import json
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
def dataset_generate_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集生成")
with gr.Row():
with gr.Column(scale=1):
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
# 初始化Dataframe的值
initial_dataframe_value = []
if initial_prompt:
selected_prompt_id = int(initial_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True
)
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc,
label="选择文档",
interactive=True
)
prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
rounds_input = gr.Number(
value=1,
label="生成轮次",
minimum=1,
maximum=100,
step=1,
interactive=True
)
concurrency_input = gr.Number(
value=1,
label="并发数",
minimum=1,
maximum=10,
step=1,
interactive=True,
visible=False
)
dataset_name_input = gr.Textbox(
label="数据集名称",
placeholder="输入数据集保存名称",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
generate_button = gr.Button("生成数据集",variant="primary")
doc_choice = gr.State(value=initial_doc)
output_text = gr.Textbox(label="生成结果", interactive=False)
api_choice = gr.State(value=initial_api)
with gr.Column(scale=2):
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
interactive=True,
label="变量列表",
value=initial_dataframe_value # 设置初始化数据
)
def on_doc_change(selected_doc):
return selected_doc
def on_api_change(selected_api):
return selected_api
def on_prompt_change(selected_prompt):
if not selected_prompt:
return None, []
selected_prompt_id = int(selected_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
dataset_db = get_datasets()
if not dataset_db.search(Query().name == dataset_name):
raise gr.Error("数据集名称已存在")
doc = [i for i in get_docs() if i.name == doc_state][0]
doc_files = doc.markdown_files
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
variables_dict = {}
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
source_doc=doc,
dataset_items=[]
)
total_slices = len(document_slice_list)
for i, document_slice in enumerate(document_slice_list):
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
call_openai_api(request, rounds)
for resp in request.response:
try:
content = json.loads(resp.content)
dataset_item = DatasetItem(
message=[Q_A(
question=content.get("question", ""),
answer=content.get("answer", "")
)]
)
dataset.dataset_items.append(dataset_item)
except json.JSONDecodeError as e:
print(f"Failed to parse response: {e}")
# 保存数据集到TinyDB
dataset_db.insert(dataset.model_dump())
save_dataset(dataset_db,get_workdir(),dataset_name)
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
outputs=output_text
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
init_global_var("workdir")
demo = dataset_generate_page()
demo.launch()

View File

@@ -0,0 +1,55 @@
import gradio as gr
from global_var import get_datasets
from tinydb import Query
def dataset_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集管理")
with gr.Row():
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 添加数据集展示组件
qa_dataset = gr.Dataset(
components=["text", "text"],
label="问答数据",
headers=["问题", "答案"],
samples=[["示例问题", "示例答案"]],
samples_per_page=20,
)
def update_qa_display(dataset_name):
if not dataset_name:
return {"samples": [], "__type__": "update"}
# 从数据库获取数据集
Dataset = Query()
ds = get_datasets().get(Dataset.name == dataset_name)
if not ds:
return {"samples": [], "__type__": "update"}
# 提取所有Q_A数据
qa_list = []
for item in ds["dataset_items"]:
for qa in item["message"]:
qa_list.append([qa["question"], qa["answer"]])
return {"samples": qa_list, "__type__": "update"}
# 绑定事件更新QA数据显示
dataset_dropdown.change(
update_qa_display,
inputs=dataset_dropdown,
outputs=qa_dataset
)
return demo

View File

@@ -0,0 +1,103 @@
import gradio as gr
import os # 导入os模块以便扫描文件夹
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from train import get_model_name
def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
models_dir = os.path.join(workdir, "models")
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # 扫描models文件夹下的所有子文件夹
with gr.Blocks() as demo:
gr.Markdown("## 模型管理")
state_output = gr.Label(label="当前状态",value="当前未加载模型") # 将 Textbox 改为 Label
with gr.Row():
with gr.Column(scale=3):
model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中并设置为可选
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
with gr.Column(scale=1):
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
save_method_dropdown = gr.Dropdown(
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
label="保存模式",
value="default"
)
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
# 判空操作,如果模型已加载,则先卸载
if get_model() is not None:
unload_model()
model_path = os.path.join(models_dir, selected_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载"
except Exception as e:
return f"加载模型时出错: {str(e)}"
load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)
def unload_model():
try:
# 将模型移动到 CPU
model = get_model()
if model is not None:
model.cpu()
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 将模型和tokenizer设置为 None
set_model(None)
set_tokenizer(None)
return "当前未加载模型"
except Exception as e:
return f"卸载模型时出错: {str(e)}"
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
from train.save_model import save_model_to_dir
def save_model(save_model_name, save_method):
return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
def refresh_model_list():
try:
nonlocal model_folders
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
return gr.Dropdown(choices=model_folders)
except Exception as e:
return f"刷新模型列表时出错: {str(e)}"
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)
return demo
if __name__ == "__main__":
demo = model_manage_page()
demo.queue()
demo.launch()

View File

@@ -0,0 +1,130 @@
import gradio as gr
import sys
from pathlib import Path
from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_prompt_store
from schema.prompt import promptTempleta
def prompt_manage_page():
def get_prompts() -> List[List[str]]:
selected_row = None
try:
db = get_prompt_store()
prompts = db.all()
return [
[p["id"], p["name"], p["description"], p["content"]]
for p in prompts
]
except Exception as e:
raise gr.Error(f"获取提示词失败: {str(e)}")
def add_prompt(name, description, content):
try:
db = get_prompt_store()
new_prompt = promptTempleta(
name=name if name else "",
description=description if description else "",
content=content if content else ""
)
prompt_id = db.insert(new_prompt.model_dump())
# 更新ID
db.update({"id": prompt_id}, doc_ids=[prompt_id])
return get_prompts(), "", "", "" # 返回清空后的输入框值
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_prompt():
global selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
db = get_prompt_store()
db.update({
"name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "",
"content": selected_row[3] if selected_row[3] else ""
}, doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_prompt():
global selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
db = get_prompt_store()
db.remove(doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # 保存当前选中行的全局变量
def select_record(dataFrame ,evt: gr.SelectData):
global selected_row
selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo:
gr.Markdown("## 提示词模板管理")
with gr.Row():
with gr.Column(scale=1):
name_input = gr.Textbox(label="模板名称")
description_input = gr.Textbox(label="模板描述")
content_input = gr.Textbox(label="模板内容", lines=10)
add_button = gr.Button("添加新模板", variant="primary")
with gr.Column(scale=3):
prompt_table = gr.DataFrame(
headers=["id", "名称", "描述", "内容"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_prompts(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_prompts,
outputs=[prompt_table],
queue=False
)
add_button.click(
fn=add_prompt,
inputs=[name_input, description_input, content_input],
outputs=[prompt_table, name_input, description_input, content_input]
)
prompt_table.select(fn=select_record,
inputs=[prompt_table],
outputs=[],
show_progress="hidden")
edit_button.click(
fn=edit_prompt,
inputs=[],
outputs=[prompt_table]
)
delete_button.click(
fn=delete_prompt,
inputs=[],
outputs=[prompt_table]
)
return demo
if __name__ == "__main__":
demo = prompt_manage_page()
demo.queue()
demo.launch()

View File

@@ -1,9 +1,131 @@
import gradio as gr import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import get_sql_engine
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]:
selected_row = None
try: # 添加异常处理
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
return [
[p.id, p.model_id, p.base_url, p.api_key or ""]
for p in providers
]
except Exception as e:
raise gr.Error(f"获取数据失败: {str(e)}")
def add_provider(model_id, base_url, api_key):
try:
with Session(get_sql_engine()) as session:
new_provider = APIProvider(
model_id=model_id if model_id else None,
base_url=base_url if base_url else None,
api_key=api_key if api_key else None
)
session.add(new_provider)
session.commit()
session.refresh(new_provider)
return get_providers(), "", "", "" # 返回清空后的输入框值
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_provider():
global selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
provider.model_id = selected_row[1] if selected_row[1] else None
provider.base_url = selected_row[2] if selected_row[2] else None
provider.api_key = selected_row[3] if selected_row[3] else None
session.add(provider)
session.commit()
session.refresh(provider)
return get_providers()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_provider():
global selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
session.delete(provider)
session.commit()
return get_providers()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # 保存当前选中行的全局变量
def select_record(dataFrame ,evt: gr.SelectData):
global selected_row
selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 设置") gr.Markdown("## API Provider 管理")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(scale=1):
pass model_id_input = gr.Textbox(label="Model ID")
return demo base_url_input = gr.Textbox(label="Base URL")
api_key_input = gr.Textbox(label="API Key")
add_button = gr.Button("添加新API", variant="primary")
with gr.Column(scale=3):
provider_table = gr.DataFrame(
headers=["id", "model id", "base URL", "API Key"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_providers(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_providers,
outputs=[provider_table],
queue=False # 立即刷新不需要排队
)
add_button.click(
fn=add_provider,
inputs=[model_id_input, base_url_input, api_key_input],
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
)
provider_table.select(fn=select_record,
inputs=[provider_table],
outputs=[],
show_progress="hidden")
edit_button.click(
fn=edit_provider,
inputs=[],
outputs=[provider_table]
)
delete_button.click(
fn=delete_provider,
inputs=[],
outputs=[provider_table]
)
return demo

View File

@@ -1,9 +1,113 @@
import subprocess
import os
import gradio as gr import gradio as gr
import sys
from tinydb import Query
from pathlib import Path
from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import find_available_port
from train import train_model
def train_page(): def train_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("## 微调") gr.Markdown("## 微调")
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column(scale=1):
pass dataset_dropdown = gr.Dropdown(
return demo choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练状态", interactive=False)
with gr.Column(scale=3):
# 新增 TensorBoard iframe 显示框
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
lora_rank = int(lora_rank) # 新增LoRA秩参数
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
# 扫描 training 文件夹并生成递增目录
training_dir = get_workdir() + "/training"
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
new_training_dir = os.path.join(training_dir, str(next_dir_number))
tensorboard_port = find_available_port(6006) # 从默认端口 6006 开始检测
print(f"TensorBoard 将使用端口: {tensorboard_port}")
tensorboard_logdir = os.path.join(new_training_dir, "logs")
os.makedirs(tensorboard_logdir, exist_ok=True) # 确保日志目录存在
tensorboard_process = subprocess.Popen(
["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
# 动态生成 TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try:
train_model(get_model(), get_tokenizer(),
dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank)
except Exception as e:
raise gr.Error(str(e))
finally:
# 确保训练结束后终止 TensorBoard 子进程
tensorboard_process.terminate()
print("TensorBoard 子进程已终止")
train_button.click(
fn=start_training,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input,
lora_rank_input
],
outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
from model_manage_page import model_manage_page
init_global_var("workdir")
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

52
global_var.py Normal file
View File

@@ -0,0 +1,52 @@
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
_prompt_store = None
_sql_engine = None
_datasets = None
_model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = load_sqlite_engine(workdir)
_datasets = get_all_dataset(workdir)
_workdir = workdir
def get_workdir():
return _workdir
def get_prompt_store():
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
global _workdir
return scan_docs_directory(_workdir)
def get_datasets():
return _datasets
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer

38
main.py
View File

@@ -1,23 +1,29 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page import train
from frontend.chat_page import chat_page from frontend import *
from frontend.train_page import train_page from db import initialize_sqlite_db, initialize_prompt_store
def main(): from global_var import init_global_var, get_sql_engine, get_prompt_store
setting_demo = setting_page()
chat_demo = chat_page()
train_demo = train_page()
if __name__ == "__main__":
init_global_var()
initialize_sqlite_db(get_sql_engine())
initialize_prompt_store(get_prompt_store())
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs(): with gr.Tabs():
with gr.TabItem("微调"): with gr.TabItem("模型管理"):
train_demo.render() model_manage_page()
with gr.TabItem("聊天"): with gr.TabItem("模型推理"):
chat_demo.render() chat_page()
with gr.TabItem("模型微调"):
train_page()
with gr.TabItem("数据集生成"):
dataset_generate_page()
with gr.TabItem("数据集管理"):
dataset_manage_page()
with gr.TabItem("提示词模板管理"):
prompt_manage_page()
with gr.TabItem("设置"): with gr.TabItem("设置"):
setting_demo.render() setting_page()
app.launch() app.launch(server_name="0.0.0.0")
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,11 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=3.0.0 gradio>=5.25.0
langchain>=0.3
tinydb>=4.0.0
unsloth>=2025.3.19
sqlmodel>=0.0.24
jinja2>=3.1.0
tensorboardX>=2.6.2.2
tensorboard>=2.19.0

View File

@@ -1,4 +1,4 @@
from .dataset import * from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest from .dataset_generation import *
from .md_doc import MarkdownNode from .md_doc import MarkdownNode
from .prompt import promptTempleta from .prompt import promptTempleta

30
schema/dataset.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime, timezone
class Doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径")
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
version: Optional[str] = Field(default="", description="文档版本")
class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案")
class DatasetItem(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容")
class Dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")

View File

@@ -3,49 +3,45 @@ from typing import Optional
from sqlmodel import SQLModel, Relationship, Field from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True): class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
base_url: str = Field(..., description="API的基础URL") base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
model_id: str = Field(..., description="API使用的模型ID") model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥") api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field( created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间" description="记录创建时间"
) )
class LLMParameters(SQLModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
class TokensUsage(SQLModel):
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
class LLMResponse(SQLModel): class LLMResponse(SQLModel):
timestamp: datetime = Field( timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳" description="响应的时间戳"
) )
response_id: str = Field(..., description="响应的唯一ID") response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: { tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0, content: str = Field(default_factory=dict, description="API响应的内容")
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: { llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel): class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词") prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id") api_provider: APIProvider = Field(..., description="API提供者的信息")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式") format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表") response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息") error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: { total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

20
schema/prompt.py Normal file
View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel, Field, field_validator
from typing import Optional
from datetime import datetime, timezone
from langchain.prompts import PromptTemplate
class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID")
name: Optional[str] = Field(default="", description="模板名称")
description: Optional[str] = Field(default="", description="模板描述")
content: str = Field(default="", min_length=1, description="模板内容")
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间"
)
@field_validator('content')
def validate_content(cls, value):
if not "document_slice" in PromptTemplate.from_template(value).input_variables:
raise ValueError("模板变量缺少 document_slice")
return value

5
tools/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .parse_markdown import *
from .document import *
from .json_example import generate_json_example
from .port import *
from .reasoning import call_openai_api

View File

@@ -0,0 +1,35 @@
from typing import List
from schema.dataset import Dataset, DatasetItem, Q_A
import json
def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
# 将JSON数据转换为dataset格式
dataset_items = []
item_id = 1 # 自增ID计数器
for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = DatasetItem(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj)
item_id += 1 # ID自增
# 创建dataset对象
result_dataset = Dataset(
name="Converted Dataset",
model_id=None,
description="Dataset converted from JSON",
dataset_items=dataset_items
)
return result_dataset
# 示例从文件读取JSON并转换
if __name__ == "__main__":
# 假设JSON数据存储在文件中
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
json_data = json.load(file)
# 转换为dataset格式
converted_dataset = convert_json_to_dataset(json_data)
# 输出结果到文件
with open("output.json", "w", encoding="utf-8") as file:
file.write(converted_dataset.model_dump_json(indent=4))

32
tools/document.py Normal file
View File

@@ -0,0 +1,32 @@
import sys
import os
from pathlib import Path
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Doc
def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs")
doc_list = os.listdir(docs_dir)
to_return = []
for doc_name in doc_list:
doc_path = os.path.join(docs_dir, doc_name)
if os.path.isdir(doc_path):
markdown_files = []
for root, dirs, files in os.walk(doc_path):
for file in files:
if file.endswith(".md"):
markdown_files.append(os.path.join(root, file))
to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return
# 添加测试代码
if __name__ == "__main__":
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
docs = scan_docs_directory(workdir)
print(docs)

67
tools/json_example.py Normal file
View File

@@ -0,0 +1,67 @@
from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json
from datetime import datetime, date
def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
"""
根据 Pydantic V2 模型生成示例 JSON 数据结构。
"""
def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type)
args = get_args(field_type)
if origin is list or origin is List:
# 生成多个元素,这里生成 3 个
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
result.append("......")
return result
elif origin is dict or origin is Dict:
if len(args) == 2:
return {"key": _generate_example(args[1])}
return {}
elif origin is Union:
# 处理 Optional 类型Union[T, None]
non_none_args = [arg for arg in args if arg is not type(None)]
return _generate_example(non_none_args[0]) if non_none_args else None
elif field_type is str:
return "string"
elif field_type is int:
return 0
elif field_type is float:
return 0.0
elif field_type is bool:
return True
elif field_type is datetime:
return datetime.now().isoformat()
elif field_type is date:
return date.today().isoformat()
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return json.loads(generate_json_example(field_type, include_optional))
else:
# 处理直接类型注解(非泛型)
if field_type is type(None):
return None
try:
if issubclass(field_type, BaseModel):
return json.loads(generate_json_example(field_type, include_optional))
except TypeError:
pass
return "unknown"
example_data = {}
for field_name, field in model.model_fields.items():
if include_optional or not isinstance(field.default, type(None)):
example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str)
if __name__ == "__main__":
import sys
from pathlib import Path
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Dataset
print("示例 JSON:")
print(generate_json_example(Dataset))

View File

@@ -6,6 +6,27 @@ from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import MarkdownNode from schema import MarkdownNode
def process_markdown_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
root = parse_markdown(content)
results = []
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # 叶子节点
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def add_child(parent, child): def add_child(parent, child):
parent.children.append(child) parent.children.append(child)
@@ -60,10 +81,13 @@ def parse_markdown(markdown):
return root return root
if __name__=="__main__": if __name__=="__main__":
# 从文件读取 Markdown 内容 # # 从文件读取 Markdown 内容
with open("workdir/example.md", "r", encoding="utf-8") as f: # with open("workdir/example.md", "r", encoding="utf-8") as f:
markdown = f.read() # markdown = f.read()
# 解析 Markdown 并打印树结构 # # 解析 Markdown 并打印树结构
root = parse_markdown(markdown) # root = parse_markdown(markdown)
print_tree(root) # print_tree(root)
for i in process_markdown_file("workdir/example.md"):
print("~"*20)
print(i)

15
tools/port.py Normal file
View File

@@ -0,0 +1,15 @@
import socket
# 启动 TensorBoard 子进程前添加端口检测逻辑
def find_available_port(start_port):
port = start_port
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', port)) != 0: # 端口未被占用
return port
port += 1 # 如果端口被占用,尝试下一个端口
if __name__ == "__main__":
start_port = 6006 # 起始端口号
available_port = find_available_port(start_port)
print(f"Available port: {available_port}")

123
tools/reasoning.py Normal file
View File

@@ -0,0 +1,123 @@
import sys
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
start_time = datetime.now(timezone.utc)
client = openai.AsyncOpenAI(
api_key=llm_request.api_provider.api_key,
base_url=llm_request.api_provider.base_url
)
total_duration = 0.0
total_tokens = TokensUsage()
prompt = llm_request.prompt
round_start = datetime.now(timezone.utc)
if llm_request.format:
prompt += "\n请以JSON格式返回结果" + llm_request.format
for i in range(rounds):
round_start = datetime.now(timezone.utc)
try:
messages = [{"role": "user", "content": prompt}]
create_args = {
"model": llm_request.api_provider.model_id,
"messages": messages,
"temperature": llm_parameters.temperature if llm_parameters else None,
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
"top_p": llm_parameters.top_p if llm_parameters else None,
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
"seed": llm_parameters.seed if llm_parameters else None
} # 处理format参数
if llm_request.format:
create_args["response_format"] = {"type": "json_object"}
response = await client.chat.completions.create(**create_args)
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
# 处理可能不存在的缓存token字段
usage = response.usage
cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
tokens_usage = TokensUsage(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
prompt_cache_hit_tokens=cache_hit,
prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
)
# 累加总token使用量
total_tokens.prompt_tokens += tokens_usage.prompt_tokens
total_tokens.completion_tokens += tokens_usage.completion_tokens
if tokens_usage.prompt_cache_hit_tokens:
total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
if tokens_usage.prompt_cache_miss_tokens:
total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
llm_request.response.append(LLMResponse(
response_id=response.id,
tokens_usage=tokens_usage,
content = response.choices[0].message.content,
total_duration=duration,
llm_parameters=llm_parameters
))
except Exception as e:
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
total_duration += duration
llm_request.response.append(LLMResponse(
response_id=f"error-round-{i+1}",
content={"error": str(e)},
total_duration=duration
))
if llm_request.error is None:
llm_request.error = []
llm_request.error.append(str(e))
# 更新总耗时和总token使用量
llm_request.total_duration = total_duration
llm_request.total_tokens_usage = total_tokens
return llm_request
if __name__ == "__main__":
from json_example import generate_json_example
from sqlmodel import Session, select
from global_var import get_sql_engine, init_global_var
from schema import DatasetItem
init_global_var("workdir")
api_state = "1 deepseek-chat"
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
llm_request = LLMRequest(
prompt="测试,随便说点什么",
api_provider=api_provider,
format=generate_json_example(DatasetItem)
)
# # 单次调用示例
# result = asyncio.run(call_openai_api(llm_request))
# print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
# for i, resp in enumerate(result.response, 1):
# print(f"响应{i}: {resp.response_content}")
# 多次调用示例
params = LLMParameters(temperature=0.7, max_tokens=100)
result = asyncio.run(call_openai_api(llm_request, 3,params))
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
for i, resp in enumerate(result.response, 1):
print(f"响应{i}: {resp.content}")

1
train/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .model import *

134
train/model.py Normal file
View File

@@ -0,0 +1,134 @@
import os
from datasets import Dataset as HFDataset
from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
from unsloth.chat_templates import get_chat_template, train_on_responses_only
def get_model_name(model):
return os.path.basename(model.name_or_path)
def formatting_prompts(examples,tokenizer):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# 使用tokenizer.apply_chat_template格式化对话
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
def train_model(
model,
tokenizer,
dataset: list,
train_dir: str,
learning_rate: float,
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int,
trainer_callback=None
) -> None:
# 模型配置参数
dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存
model = FastLanguageModel.get_peft_model(
# 原始模型
model,
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
# 建议: 8-32之间
r=lora_rank, # 使用动态传入的LoRA秩
# 需要应用LoRA的目标模块列表
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
"gate_proj", "up_proj", "down_proj", # FFN相关层
],
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大LoRA的更新影响越大。
lora_alpha=16,
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
# 如果数据集较小建议设置0.1左右。
lora_dropout=0,
# 是否对bias参数进行微调,none表示不微调bias
# none: 不微调偏置参数;
# all: 微调所有参数;
# lora_only: 只微调LoRA参数。
bias="none",
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
# 会略微降低训练速度,但可以显著减少显存使用
use_gradient_checkpointing="unsloth",
# 随机数种子,用于结果复现
random_state=3407,
# 是否使用rank-stabilized LoRA,这里不使用
# 会略微降低训练速度,但可以显著减少显存使用
use_rslora=False,
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
loftq_config=None,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
train_dataset = HFDataset.from_list(dataset)
train_dataset = train_dataset.map(formatting_prompts,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器
train_dataset=train_dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称
max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # 数据集处理的并行进程数
packing=False,
args=TrainingArguments(
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=114514, # 随机数种子
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数
logging_dir=train_dir + "/logs", # 日志文件存储路径
report_to="tensorboard", # 使用TensorBoard记录日志
),
)
if trainer_callback is not None:
trainer.add_callback(trainer_callback)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",
response_part = "<|im_start|>assistant\n",
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)

48
train/save_model.py Normal file
View File

@@ -0,0 +1,48 @@
import os
from global_var import get_model, get_tokenizer
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
"""
保存模型到指定目录
:param save_model_name: 要保存的模型名称
:param models_dir: 模型保存的基础目录
:param model: 要保存的模型
:param tokenizer: 要保存的tokenizer
:param save_method: 保存模式选项
- "default": 默认保存方式
- "merged_16bit": 合并为16位
- "merged_4bit": 合并为4位
- "lora": 仅LoRA适配器
- "gguf": 保存为GGUF格式
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
- "gguf_f16": 保存为16位GGUF格式
:return: 保存结果消息或错误信息
"""
try:
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
if save_method == "default":
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
elif save_method == "merged_16bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
elif save_method == "merged_4bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
elif save_method == "lora":
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
elif save_method == "gguf":
model.save_pretrained_gguf(save_path, tokenizer)
elif save_method == "gguf_q4_k_m":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
elif save_method == "gguf_f16":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
else:
return f"不支持的保存模式: {save_method}"
return f"模型已保存到 {save_path} (模式: {save_method})"
except Exception as e:
return f"保存模型时出错: {str(e)}"