Compare commits

106 Commits

Author SHA1 Message Date
carry
5a386d6401 feat(dataset_generate_page): 添加 API 选择功能
- 在数据集生成页面添加 API 选择下拉框
- 实现 API 选择变更时的处理逻辑
- 更新数据集生成函数,增加 API 选择参数
- 优化页面布局和代码结构
2025-04-18 15:23:33 +08:00
carry
feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00
carry
7242a2ce03 feat(frontend): 添加生成数据集进度条功能并优化了界面布局 2025-04-18 15:07:46 +08:00
carry
db6e2271dc fix(frontend): 修复 prompt_dropdown 变化时,dataframe没有相应的变化
- 将 prompt_dropdown 变化时的输出从 prompt_state 修改为 [prompt_state, variables_dataframe]
- 这个改动可能会在 prompt 变化时同时更新变量数据框
2025-04-18 14:03:26 +08:00
carry
d764537143 feat(dataset_generate_page): 更新数据集生成页面功能
- 添加模板变量列表展示和编辑功能
- 实现模板选择后动态更新变量列表
- 增加生成数据集按钮和相关逻辑
- 优化页面布局和交互
2025-04-16 12:39:48 +08:00
carry
8c35a38c47 feat(frontend): 更新模板选择功能
- 在模板选择变更时,获取所选模板的详细信息
- 创建 PromptTemplate 对象并获取输入变量列表
- 此更新为后续的模板编辑功能做准备
2025-04-15 21:31:50 +08:00
carry
7ee751c88f fix(frontend): 移除文档生成页面的冗余事件绑定代码
- 删除了原有的简单事件绑定逻辑,这些逻辑仅将输入值赋给状态变量
- 为后续添加更复杂的文档选择更改事件处理函数做准备
2025-04-15 20:44:26 +08:00
carry
b715b36a5f feat(frontend): 更新数据集生成页面并添加独立运行功能
- 重构导入路径,使用绝对路径替换相对路径
- 新增文档和模板选择的事件处理函数
- 添加独立运行数据集生成页面的功能
- 优化代码结构,提高可读性和可维护性
2025-04-15 17:13:52 +08:00
carry
8023233bb2 feat(prompt): 增加模板变量有效性检查
- 在 promptTempleta 模型中添加字段验证器
- 验证模板内容是否包含必要的 document_slice 变量
- 如果缺少该变量,抛出 ValueError 异常
2025-04-15 16:54:17 +08:00
carry
2a86b3b5b0 fix(db): 初始化 prompt store 时插入第一条记录的 ID 从 0 改为 1
- 将初始化时插入的第一条记录的 ID 从 0 改为 1
- 修正了文档节选的变量名,从 {content} 改为 {document_slice}
2025-04-15 16:45:12 +08:00
carry
ca1505304e fix(tools): 更新 tools/__init__.py 中的导入语句
- 将 from .doc import * 改为 from .document import *
- 这个修改统一了文档处理模块的命名,提高了代码的一致性和可读性
2025-04-15 16:31:55 +08:00
carry
df9260e918 fix(db): 修复初始提示词的变量花括号的空格问题 2025-04-15 16:13:57 +08:00
carry
df9aba0c6e refactor(tools): 重命名模块并更新导入
- 将 scan_doc_dir.py 重命名为 document.py
- 将 socket.py 重命名为 port.py
- 更新 __init__.py 中的导入语句
- 在 port.py 中添加测试代码,用于查找可用端口
2025-04-15 15:47:44 +08:00
carry
6b87dcb58f refactor(frontend): 重构数据集生成页面的变量命名逻辑
- 将 prompt_choices 变量重命名为 prompt_list,以更准确地反映其内容
- 更新相关代码中对这两个变量的引用,以保持一致性
2025-04-15 15:40:24 +08:00
carry
d0aebd17fa refactor(global_var): 重构全局变量管理
- 移除了 _docs 全局变量
- 更新了 get_docs() 函数,使其在每次调用时重新扫描文档目录
- 优化了全局变量初始化逻辑
2025-04-15 15:25:44 +08:00
carry
d9abf08184 fix(frontend): 修复表格选择事件的行数据获取问题
- 在 prompt_manage_page 和 setting_page 中更新了 select_record 函数
- 使用 DataFrame.iloc 方法获取选中行的数据,并转换为列表
- 添加了将第一列数据转换为整数的逻辑
- 更新了表格选择事件的参数,增加了输入和输出参数
- 将 gradio 版本升级到 5.25.0
2025-04-15 15:10:15 +08:00
carry
a27a1ab079 refactor(frontend): 重构训练页面布局并优化用户界面
- 调整数据集下拉框布局位置
- 新增超参数输入组件
- 修改训练日志输出框标签为"训练状态"
- 添加 TensorBoard 可视化 iframe 显示框
2025-04-15 00:12:09 +08:00
carry
aa758e3c2a feat(train_page): 添加 TensorBoard 可视化
- 在训练页面添加 TensorBoard iframe 显示框
- 实现动态生成 TensorBoard iframe 功能
- 更新训练按钮点击事件,同时更新 TensorBoard iframe
2025-04-14 23:28:43 +08:00
carry
664944f0c5 feat(frontend): 优化 TensorBoard 端口占用问题
- 新增端口检测逻辑,动态分配可用端口
- 修改 TensorBoard 启动过程,使用动态分配的端口
- 添加 socket 模块,用于端口检测
2025-04-14 17:06:44 +08:00
carry
9298438f98 feat(train_page): 启动 TensorBoard 进程并确保训练结束后终止
- 在训练页面中添加 TensorBoard 进程启动代码
- 创建日志目录并启动 TensorBoard 子进程
- 在训练结束后终止 TensorBoard 进程
2025-04-14 17:00:33 +08:00
carry
4f7926aec6 feat(train_page): 实现训练目录自动递增功能
- 在 training 文件夹下创建递增的目录结构
- 确保 training 文件夹存在
- 扫描现有目录,生成下一个可用的目录编号
- 更新训练模型函数,使用新的训练目录
2025-04-14 16:46:29 +08:00
carry
148f4afb25 fix(main): 修复unsloth没有最先导入的问题,删除了重复的导入语句 2025-04-14 16:31:47 +08:00
carry
11a3039775 fix(train_page): 修正模型训练保存路径 2025-04-14 16:31:00 +08:00
carry
a4289815ba build: 添加 tensorboard 依赖
- 在 requirements.txt 中添加 tensorboard>=2.19.0
- 此改动增加了 tensorboard 作为项目的新依赖项
2025-04-14 16:30:39 +08:00
carry
088067d335 train: 更新模型训练功能和日志记录方式
- 修改训练目录结构,将检查点和日志分开保存
- 添加 TensorBoard 日志记录支持
- 移除自定义 LossCallback 类,简化训练流程
- 更新训练参数和回调机制,提高代码可读性
- 在 requirements.txt 中添加 tensorboardX 依赖
2025-04-14 16:19:37 +08:00
carry
9fb31c46c8 feat(train): 添加训练过程中的日志记录和 loss 可视化功能
- 新增 LossCallback 类,用于在训练过程中记录 loss 数据
- 在训练模型函数中添加回调,实现日志记录和 loss 可视化
- 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
2025-04-14 15:18:14 +08:00
carry
4f09823123 refactor(tools): 优化 train_model 函数定义
- 添加类型注解,提高代码可读性和维护性
- 使用多行格式定义函数参数,提升代码格式美观
2025-04-14 14:28:36 +08:00
carry
1a2ca3e244 refactor(train): 重构训练功能并移至新模块
- 将训练逻辑从 train_page.py 移至 tools/model.py
- 新增 train_model 函数,包含完整的训练流程
- 更新 train_page.py 中的回调函数,使用新的训练函数
- 移除了 train_page.py 中未使用的导入
2025-04-14 14:17:04 +08:00
carry
bb1d8fbd38 feat(train_page): 添加训练 Loss 曲线显示功能
- 在训练页面添加了 Loss 曲线图表
- 实现了 GradioLossCallback 类用于记录训练过程中的 Loss 数据
- 修改了训练函数,通过回调函数收集 Loss 信息并更新图表
- 优化了训练函数的返回值结构,支持同时返回文本日志和 Loss 数据
2025-04-13 21:49:43 +08:00
carry
4558929c52 fix: 调整了import的顺序,让unsloth最先import以提高性能 2025-04-13 21:35:47 +08:00
carry
0722748997 feat(train_page): 添加 LoRA 秩动态输入功能
- 在训练页面新增 LoRA 秩输入框,使用户可以动态设置 LoRA 秩
- 更新训练模型函数,添加 LoRA 秩参数并将其用于模型配置
- 保留原有功能,仅增加 LoRA 秩相关配置
2025-04-13 21:12:02 +08:00
carry
e08f0059bb feat(train_page): 优化训练过程以专注于响应生成
- 引入 train_on_responses_only 函数,用于优化训练过程
- 设置 instruction_part 和 response_part 参数,以适应特定的对话格式
- 此修改旨在提高模型在生成响应方面的性能和效率
2025-04-13 21:05:14 +08:00
carry
6d1fecbdac build: 更新依赖版本并添加新依赖
- 将 unsloth 的最低版本要求从 2025.3.9 提升到 2025.3.19
- 新增 sqlmodel 依赖,版本不低于 0.0.24
- 新增 jinja2 依赖,版本不低于 3.1.0
2025-04-13 20:29:33 +08:00
carry
79d3eb153c refactor(train_page): 优化训练页面布局和功能
- 移除了 max_steps_input 组件,减少不必要的输入项
- 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch"
- 新增 save_steps_input 组件,用于设置保存步数
- 修改 train_model 函数,移除了 max_steps 参数
- 更新了 trainer.train() 方法的调用,设置 resume_from_checkpoint=False
2025-04-13 01:56:10 +08:00
carry
80dae7c6e2 fix(global_var):修复_workdir非全局变量的bug 2025-04-13 01:52:05 +08:00
carry
2d39b91764 feat(train_page): 添加模型训练超参数配置功能
- 新增学习率、批次大小、最大训练步数等超参数输入组件
- 实现超参数在训练过程中的动态应用
- 调整训练参数以适应不同硬件环境
- 优化训练过程,支持按步数保存模型
2025-04-13 01:04:27 +08:00
carry
5094febcb4 refactor(global_var): 重构全局变量管理并添加工作目录功能
- 添加 _workdir 全局变量以存储工作目录路径
- 在 init_global_var 函数中初始化 _workdir
- 新增 get_workdir 函数以获取工作目录路径
- 调整全局变量的定义和初始化顺序
2025-04-13 00:54:55 +08:00
carry
eeee68dbd1 chore: 更新 .gitignore 文件
- 在 .gitignore 中添加 test.py 文件,避免测试代码被版本控制
2025-04-12 18:42:36 +08:00
carry
539e14d39c feat(frontend): 完成了前端微调的代码逻辑 2025-04-12 18:42:22 +08:00
carry
9784f2aed3 fix(tools): 修正__init__.py使得model.py正确导入 2025-04-12 01:16:07 +08:00
carry
611904cef9 feat(frontend): 添加数据集选择功能到训练页面
- 在 train_page.py 中添加数据集选择下拉框
- 从全局变量中获取数据集列表并设置初始值
- 添加交互性和自定义值支持
2025-04-11 19:43:34 +08:00
carry
8a9a080745 refactor(tools): 移除未使用的导入语句
移除了 tools/model.py 文件中未使用的 get_chat_template 导入语句。这个修改提高了代码的可读性和维护性。
2025-04-11 19:43:19 +08:00
carry
a23ad88769 fix(frontend): 修复删除提示功能中的数据库连接错误
- 将 prompt_store 更改为 get_prompt_store(),以解决数据库连接未建立的问题
- 优化了删除提示功能的代码,提高了系统稳定性
2025-04-11 18:53:17 +08:00
carry
83427aaaba feat(frontend): 增加超参数设置并优化聊天页面布局
- 在聊天页面添加了超参数输入框,包括最大生成长度、温度、Top-p 采样和重复惩罚
- 优化了聊天框的布局,使用 gr.Row() 和 gr.Column() 实现了更合理的界面结构
- 更新了 bot 函数,支持根据用户输入的超参数进行文本生成
- 修复了一些代码格式问题,提高了代码的可读性
2025-04-11 18:48:13 +08:00
carry
61672021ef fix(frontend): 修复聊天页面并的流式回复
- 导入 Thread 和 TextIteratorStreamer 以支持流式生成
- 重新设计 user 和 bot 函数,优化对话历史处理
- 添加异常处理和错误信息显示
- 改进模型和分词器的加载逻辑
- 优化聊天页面布局和交互
2025-04-11 18:33:31 +08:00
carry
fb6157af05 feat(frontend): 初步实现聊天页面的智能回复功能 2025-04-11 18:08:38 +08:00
carry
f655936741 refactor(global_var): 重构全局变量初始化方法
- 新增 init_global_var 函数,用于统一初始化所有全局变量
- 修改 get_prompt_store、get_sql_engine、get_docs 和 get_datasets 函数,使用新的全局变量初始化逻辑
- 更新 main.py 中的代码,使用新的 init_global_var 函数替代原有的单独初始化方法
2025-04-11 18:08:16 +08:00
carry
ab7897351a fix(global_var): 修复全局变量多文件多副本的不统一问题 2025-04-11 18:04:42 +08:00
carry
216bfe39ae feat(tools): 添加格式化对话数据的函数
- 新增 formatting_prompts_func 函数,用于格式化对话数据
- 该函数将问题和答案组合成对话形式,并使用 tokenizer.apply_chat_template 进行格式化
- 更新 imports,添加了 unsloth.chat_templates 模块
2025-04-11 17:56:46 +08:00
carry
0fa2b51a79 refactor(frontend): 优化模型管理页面的交互和显示
- 将状态输出从 Textbox 改为 Label 组件,提高用户体验
- 添加 get_model_name 函数以获取模型名称,提高代码复用性
- 更新模型加载、卸载和保存后的状态显示,使信息更加准确
- 优化模型列表刷新功能,确保模型列表实时更新
2025-04-11 00:14:40 +08:00
carry
cbb3a09dd8 feat(tools): 添加模型名称获取函数
- 在 tools 目录下新增 model.py 文件
- 实现 get_model_name 函数,用于获取模型的名称
- 更新 tools/__init__.py,导入新的 get_model_name 函数
2025-04-10 22:05:04 +08:00
carry
2e552c186d refactor(frontend): 重构模型选择界面的变量命名
- 将模型选择的 Dropdown 组件从 dropdown 重命名为 model_select_dropdown,提高代码可读性
- 更新 load_button 和 refresh_button 的输出目标,以适应新的变量名
2025-04-10 21:19:58 +08:00
carry
1b3f546669 refactor(frontend): 重构前端页面并添加独立运行功能
- 在 chat_page 和 prompt_manage_page 中添加了独立运行的入口
- 引入 sys 和 pathlib 模块以支持路径操作
- 修改了模块导入方式,使其能够作为独立脚本运行
- 优化了代码结构,提高了可读性和可维护性
2025-04-10 21:18:05 +08:00
carry
402bc73dce feat(model_manage_page): 增加模型保存和刷新功能
- 新增保存模型功能,用户可以输入模型名称并保存当前加载的模型
- 添加刷新模型列表按钮,用户可以随时更新模型下拉菜单中的选项
- 优化页面布局,使按钮和输入框更加合理地排列
2025-04-10 20:18:03 +08:00
carry
bb5851f800 build: 添加 unsloth 依赖
- 在 requirements.txt 中添加 unsloth>=2025.3.9 依赖
2025-04-10 19:56:44 +08:00
carry
a407fa1f76 feat(model_manage_page): 实现模型加载和卸载功能
- 添加模型加载和卸载按钮
- 实现模型加载和卸载的逻辑
- 添加相关模块的导入
- 扫描模型目录并显示在下拉框中
2025-04-10 19:52:08 +08:00
carry
4b465ec917 chore: 更新 .gitignore 文件
- 修改测试代码注释,扩大至参考代码
- 新增 refer/ 目录到忽略列表
2025-04-10 17:38:29 +08:00
carry
e7cc03297b feat(frontend): 添加了简单聊天机器人页面 2025-04-10 17:38:02 +08:00
carry
051d1a7535 feat(frontend): 添加模型管理页面并初始化模型相关全局变量
- 在 frontend/__init__.py 中添加 model_manage_page 模块引用
- 新增 model_manage_page.py 文件,实现模型管理页面的基本框架
- 在 global_var.py 中添加 model 和 tokenizer 全局变量
- 在 main.py 中集成模型管理页面到主应用的 Tabs 组件中
2025-04-10 17:37:45 +08:00
carry
97172f9596 feat(dataset): 设置问答数据集展示页面的每页显示数量
- 在 dataset_manage_page 函数中添加 samples_per_page 参数
- 设置每页显示的样本数量为 20 条
2025-04-10 16:12:59 +08:00
carry
f582820443 feat(tools): 添加生成 Pydantic V2 模型示例 JSON 的工具脚本
- 新增 json_example.py 脚本,用于生成 Pydantic V2 模型的示例 JSON 数据结构
- 支持列表、字典、可选类型以及基本数据类型(字符串、整数、浮点数、布尔值、日期和时间)的示例生成
- 可递归生成嵌套模型的示例 JSON
- 示例使用了项目中的 Q_A 模型,生成了包含多个 Q_A 对象的列表 JSON 结构
2025-04-10 15:38:28 +08:00
carry
8fb9f785b9 feat(frontend): 展示数据集管理页面的问答数据
- 添加 QA 数据集展示组件
- 实现数据集选择时动态加载对应的问答数据
- 优化数据集管理页面布局
2025-04-09 22:23:55 +08:00
carry
2c8e54bb1e feat(dataset): 初步完成数据集管理页面和功能 2025-04-09 20:49:20 +08:00
carry
932d1e2687 refactor(schema): 修改数据集名称默认值
- 将 dataset 类中的 name 字段默认值从 None 改为 ""
- 这个改动确保了数据集名称始终有一个默认的空字符串值,而不是 None,提高了数据一致性和代码健壮性
2025-04-09 19:42:00 +08:00
carry
202d4c44df feat(db): 添加数据集存储和读取功能
- 新增 dataset_store.py 文件,实现数据集的存储和读取功能
- 添加 get_all_dataset 函数,用于获取所有数据集
- 使用 tinydb 和 json 进行数据持久化
- 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件
2025-04-09 18:21:27 +08:00
carry
4d77c429bd refactor(schema): 更新 dataset 模型并为 doc 模型添加版本字段
- 在 doc 模型中添加 version 字段,用于表示文档版本
- 将 dataset 模型中的 source_doc 字段类型从 list[doc] 改为 doc,简化数据结构
2025-04-09 18:18:29 +08:00
carry
41447c5ed4 feat(dataset): 添加数据集来源文档字段
- 在 dataset 模型中增加 source_doc 字段,用于记录数据集的来源文档
- 新增字段为可选列表,包含 doc 类型的元素
2025-04-09 17:37:24 +08:00
carry
84fe78243a feat(tools): 添加 JSON 数据转换为 dataset 的工具脚本
- 新增 convert_json_to_dataset 函数,用于将 JSON 数据转换为 dataset 对象
- 实现了从 JSON 文件读取数据、转换为 dataset 格式并输出到文件的功能
- 该工具可帮助用户将旧数据集快速转换为新的 dataset 格式
2025-04-09 17:31:53 +08:00
carry
4d8754aad2 feat(frontend): 实现数据集生成页面的文档和模板选择功能
- 添加文档和模板的下拉选择框
- 实现文档和模板选择后的状态更新
- 优化页面布局,分为文档和模板两个列
2025-04-09 17:19:40 +08:00
carry
541d37c674 feat(schema): 新增数据集相关模型并添加文档扫描功能
- 新增 dataset.py 文件,定义数据集相关模型
- 新增 tools 目录,包含解析 Markdown 和扫描文档的功能
- 修改 parse_markdown.py,增加处理 Markdown 文件的函数
- 新增 scan_doc_dir.py,实现文档目录扫描功能
2025-04-09 13:02:18 +08:00
carry
6a00699472 feat(frontend): 实现提示词模板管理页面
- 添加获取、添加、编辑和删除提示词功能
- 实现数据表格展示和操作
2025-04-09 11:08:18 +08:00
carry
ff8162890d refactor(db): 移除了提示词模板中冗余的 JSON 格式说明 2025-04-09 10:35:11 +08:00
carry
daddcd34da fix(db): 为 promptStore 添加空数据库初始化逻辑
- 在 initialize_prompt_store 函数中增加空数据库检查和初始化逻辑
- 为默认模板添加 id 字段,设置为 0
2025-04-09 10:28:31 +08:00
carry
5c7ced30df fix(db): 修复 prompt_store 初始化逻辑
- 在插入默认模板之前检查数据库是否为空,如果数据库已有数据,则跳过插入默认模板
2025-04-09 10:26:14 +08:00
carry
9741ce6b92 refactor(db): 优化了代码,调整了import顺序,删除了无用变量 2025-04-09 10:19:57 +08:00
carry
67281fe06a feat(db): 添加 prompt 存储功能
- 新增 prompt_store 模块,使用 TinyDB 存储 prompt 模板
- 在全局变量中添加 prompt_store 实例
- 更新 main.py,初始化 prompt 存储
- 新增 prompt 模板的 Pydantic 模型
- 更新 requirements.txt,添加 tinydb 依赖
2025-04-09 09:58:42 +08:00
carry
2d905a0270 refactor(db): 调整导入模块顺序
- 将 os 和 sys 模块导入提前到文件顶部
- 优化代码结构,遵循常见的 Python 导入模块顺序
2025-04-09 09:57:20 +08:00
carry
374b124cf8 feat(setting_page): 添加供应商后清空输入框
- 修改 add_provider 函数,返回清空后的输入框值
- 更新 add_button.click 事件处理,添加清空输入框的输出
2025-04-09 08:17:43 +08:00
carry
74ae5e1426 refactor(db): 重命名数据库引擎获取函数
将 get_engine 函数重命名为 get_sqlite_engine,以更清晰地表示其功能和用途。
- 更新了 db/__init__.py 中的导入和 __all__ 列表
- 修改了 db/init_db.py 中的函数定义
- 更新了前端设置页面和全局变量中的导入和函数调用

此更改提高了代码的可读性和维护性,特别是在将来可能添加其他类型数据库引擎的情况下。
2025-04-09 08:12:59 +08:00
carry
0a6ae7a4ee feat(frontend): 重构前端页面并添加新功能
- 重命名 dataset_page 为 prompt_manage_page,支持提示词模板管理
- 新增 dataset_generate_page 和 dataset_manage_page 页面
- 更新 main.py 中的页面引用和标签名称
- 修改前端初始化文件,使用 * 导入所有页面模块
2025-04-09 08:11:40 +08:00
carry
faf72d1e99 feat(frontend): 完成了编辑 API Provider 功能 2025-04-09 08:04:40 +08:00
carry
cce5e4e114 feat(frontend): 完成了 API Provider 删除和添加了编辑功能的函数 2025-04-09 00:48:22 +08:00
carry
293f63017f feat(frontend): 添加 API Provider 表格选中行状态监听
- 新增选中行的全局变量 selected_row
- 实现 select_record 函数来保存选中行数据
- 在表格中添加选中行事件监听
- 优化代码结构,提高可读性和可维护性
2025-04-09 00:37:15 +08:00
carry
2e31f4f57c build: 升级 gradio 至 5.0.0 版本
- 将 requirements.txt 中 gradio 版本要求从 >=3.0.0 修改为 >=5.0.0
- 此次升级可能会影响项目的用户界面或功能,需要进行测试以确保兼容性
2025-04-08 16:17:21 +08:00
carry
967133162e refactor(schema): 在 APIProvider 模型中设置 id 字段为不可变
- 在 APIProvider 类中,将 id 字段的定义更新,添加 allow_mutation=False 参数
- 这个改动确保了主键字段在创建后不可更改,提高了数据的一致性和安全性
2025-04-08 16:02:46 +08:00
carry
dc28c25c65 feat(frontend): 更新设置页面按钮样式
- 为"添加新API"按钮添加 primary 样式
- 为"编辑选中行"按钮添加 primary 样式
- 为"删除选中行"按钮添加 stop 样式
- 保持"刷新数据"按钮的 secondary 样式
2025-04-08 14:23:31 +08:00
carry
70b64dc3d3 refactor(db): 重命名数据库初始化函数以明确其适用范围
- 将 initialize_db 函数重命名为 initialize_sqlite_db,以明确该函数专用于 SQLite 数据库
- 更新相关模块和文件中的引用,以确保代码一致性
- 此修改旨在提高代码的可读性和维护性,特别是未来可能接入多种数据库时
2025-04-08 14:16:12 +08:00
carry
b52ca9b1af docs: 添加项目基础文档
- 新增 LICENSE 文件,定义项目使用的 MIT 开源许可证
- 新增 README.md 文件,简要介绍项目内容和技术栈
2025-04-08 13:35:30 +08:00
carry
46b4453ccd refactor(frontend): 重构数据库连接方式
- 移除各前端页面中重复的数据库引擎初始化代码
- 在 global_var.py 中统一初始化和存储数据库引擎
- 更新 setting_page.py 和 main.py 中的数据库连接逻辑
- 优化代码结构,提高可维护性和可扩展性
2025-04-08 13:19:58 +08:00
carry
d5b528d375 chore: 更新 .gitignore 文件
- 保留 unsloth_compiled_cache 目录
- 添加 test.ipynb 到忽略列表,避免测试代码影响版本控制
2025-04-08 12:28:42 +08:00
carry
475cd033d9 build: 添加 langchain 依赖
- 在 requirements.txt 中添加 langchain>=0.3 版本的依赖
- 保持其他依赖版本不变
2025-04-08 11:53:58 +08:00
carry
3970a67df3 refactor(dataset_generation): 增加 APIProvider 模型字段的最小长度验证
- 为 base_url 和 model_id 字段添加 min_length=1 的验证
- 更新字段描述,明确这些字段不能为空
2025-04-07 23:37:14 +08:00
carry
286db405ca feat(frontend): 优化设置页面并添加数据刷新功能
- 为 get_providers 函数添加异常处理,提高数据获取的稳定性
- 在设置页面添加刷新按钮,用户可手动触发数据刷新
- 优化页面布局,调整组件间距和对齐方式
2025-04-07 23:17:43 +08:00
carry
d40f5b1f24 fix(frontend): 优化 API Provider 添加功能并处理异常
- 为 model_id、base_url 和 api_key 添加空值检查,避免无效输入
- 添加异常处理,确保在出现错误时能够及时响应并提示用户
- 优化 add_provider 函数,提高代码可读性和健壮性
2025-04-07 13:02:45 +08:00
carry
7a77f61ee6 feat(frontend): 添加 API Provider 的增加功能 2025-04-07 00:28:52 +08:00
carry
841e14a093 feat(frontend): 添加数据集页面并重构主页面布局
- 新增 dataset_page 模块,实现数据集页面的基本布局
- 重构 main.py 中的页面加载方式,使用列表收集所有页面
- 更新主页面布局,将聊天页面作为第一个选项卡
- 调整设置页面的加载方式,直接使用函数调用
2025-04-06 22:49:37 +08:00
carry
2ff077bb1c refactor(frontend): 重构前端页面导入方式
- 在 main.py 中使用更简洁的导入方式
- 新增 __init__.py 文件以简化前端页面的导入
2025-04-06 22:46:31 +08:00
carry
513b639bce feat(frontend): 添加了设置页面的api provider展示 2025-04-06 22:05:56 +08:00
carry
f93f213a31 feat(db): 添加数据库连接和初始化功能
- 新增 db/__init__.py 文件,提供数据库连接和初始化的接口
- 导入 get_engine 和 initialize_db 函数,方便外部使用
2025-04-06 21:27:25 +08:00
carry
10b4c29bda docs(db): 修改了代码注释 2025-04-06 21:26:53 +08:00
carry
b1e98ca913 feat(db): 初始化数据库并创建 APIProvider 表
- 新增 init_db.py 文件,实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件,定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时,如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID,会自动添加一条 APIProvider 记录
2025-04-06 19:59:23 +08:00
carry
2d5a5277ae refactor(schema): 更新 prompt 导入
- 将 prompt_templeta 重命名为 promptTempleta,以符合驼峰命名规范
- 优化导入语句格式
2025-04-06 19:39:43 +08:00
carry
519a5f3773 feat(frontend): 添加前端页面模块并实现基本布局
- 新增 chat_page.py、setting_page.py 和 train_page.py 文件,分别实现聊天、设置和微调页面的基本布局
- 添加 main.py 文件,集成所有页面并创建主应用
- 在 requirements.txt 中添加 gradio 依赖
2025-04-06 14:49:01 +08:00
carry
1f4d491694 build: 添加 pydantic 依赖 2025-04-05 01:00:33 +08:00
carry
8ce4f1e373 chore: 添加 .roo 到 .gitignore 文件
- 在 .gitignore 文件中添加 .roo 目录,以忽略相关文件
2025-04-05 00:59:42 +08:00
carry
3395b860e4 refactor(parse_markdown): 重构 Markdown 解析逻辑并使用 Pydantic 模型
将 MarkdownNode 类重构为使用 Pydantic 模型,提高代码的可维护性和类型安全性。同时,将解析逻辑与节点操作分离,简化代码结构。
2025-04-04 20:50:39 +08:00
30 changed files with 1564 additions and 31 deletions

6
.gitignore vendored
View File

@@ -11,6 +11,7 @@ env/
# IDE
.vscode/
.idea/
.roo
# Environment files
.env
@@ -28,3 +29,8 @@ workdir/
# cache
unsloth_compiled_cache
# 测试和参考代码
test.ipynb
test.py
refer/

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 C-a-r-r-y
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

15
README.md Normal file
View File

@@ -0,0 +1,15 @@
# 基于文档驱动的自适应编码大模型微调框架
## 简介
本人的毕业设计
### 项目概述
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
### 项目技术
* 使用unsloth框架在GPU上实现大语言模型的qlora微调
* 使用langchain框架编写工作流实现批量生成微调语料
* 使用tinydb和sqlite实现数据的持久化
* 使用gradio框架实现前端展示
**施工中......**

11
db/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset
__all__ = [
"load_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store",
"get_all_dataset"
]

50
db/dataset_store.py Normal file
View File

@@ -0,0 +1,50 @@
import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
Args:
workdir (str): 工作目录路径
Returns:
TinyDB: 包含所有数据集对象的TinyDB对象
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return TinyDB(storage=MemoryStorage)
db = TinyDB(storage=MemoryStorage)
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
db.insert(data)
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return db
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取所有数据集
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")

79
db/init_db.py Normal file
View File

@@ -0,0 +1,79 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def load_sqlite_engine(workdir: str) -> Engine:
"""
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
Args:
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
Returns:
Engine: SQLAlchemy 数据库引擎实例。
"""
global _engine
if not _engine:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "db.sqlite")
# 创建数据库URL
db_url = f"sqlite:///{db_path}"
# 创建数据库引擎
_engine = create_engine(db_url)
return _engine
def initialize_sqlite_db(engine: Engine) -> None:
"""
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
Args:
engine (Engine): SQLAlchemy 数据库引擎实例。
"""
# 创建所有表结构
SQLModel.metadata.create_all(engine)
# 加载环境变量
load_dotenv()
# 从环境变量中获取API相关配置
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
# 如果所有必要的环境变量都存在,则插入初始数据
if api_key and base_url and model_id:
with Session(engine) as session:
# 查询是否已存在APIProvider记录
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
# 如果不存在则插入新的APIProvider记录
if not existing_provider:
provider = APIProvider(
base_url=base_url,
model_id=model_id,
api_key=api_key
)
session.add(provider)
session.commit()
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎
engine = load_sqlite_engine(workdir)
# 初始化数据库
initialize_sqlite_db(engine)

62
db/prompt_store.py Normal file
View File

@@ -0,0 +1,62 @@
import os
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# 全局变量用于存储TinyDB实例
_db_instance: Optional[TinyDB] = None
# 自定义存储类用于格式化JSON数据
def get_prompt_tinydb(workdir: str) -> TinyDB:
"""
获取TinyDB实例。如果实例尚未创建则创建一个新的并返回。
Args:
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
Returns:
TinyDB: TinyDB数据库实例
"""
global _db_instance
if not _db_instance:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "prompts.json")
# 创建TinyDB实例
_db_instance = TinyDB(db_path)
return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
"""
初始化prompt模板存储
Args:
db (TinyDB): TinyDB数据库实例
"""
# 检查数据库是否为空
if not db.all(): # 如果数据库中没有数据
db.insert(promptTempleta(
id=1,
name="default",
description="默认提示词模板",
content="""项目名为:{project_name}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{document_slice}""").model_dump())
# 如果数据库中已有数据,则跳过插入
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库实例
db = get_prompt_tinydb(workdir)
# 初始化prompt存储
initialize_prompt_store(db)

7
frontend/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from .train_page import *
from .chat_page import *
from .setting_page import *
from .model_manage_page import *
from .dataset_manage_page import *
from .dataset_generate_page import *
from .prompt_manage_page import *

97
frontend/chat_page.py Normal file
View File

@@ -0,0 +1,97 @@
import gradio as gr
import sys
from pathlib import Path
from threading import Thread # 需要导入 Thread
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
# 假设 global_var.py 在父目录
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
def chat_page():
with gr.Blocks() as demo:
# 聊天框
gr.Markdown("## 对话")
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(type="messages", label="聊天机器人")
msg = gr.Textbox(label="输入消息")
with gr.Column(scale=1):
# 新增超参数输入框
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
clear = gr.Button("清除对话")
def user(user_message, history: list):
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
model = get_model()
tokenizer = get_tokenizer()
if not history:
yield history
return
if model is None or tokenizer is None:
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
yield history
return
try:
inputs = tokenizer.apply_chat_template(
history,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 将超参数转换为数值类型
generation_kwargs = dict(
input_ids=inputs,
streamer=streamer,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
repetition_penalty=float(repetition_penalty),
do_sample=True,
use_cache=False
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
history.append({"role": "assistant", "content": ""})
for new_text in streamer:
if new_text:
history[-1]["content"] += new_text
yield history
except Exception as e:
import traceback
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
history[-1]["content"] = error_message
else:
history.append({"role": "assistant", "content": error_message})
yield history
# 更新 .then() 调用,将超参数传递给 bot 函数
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
)
clear.click(lambda: [], None, chatbot, queue=False)
return demo
if __name__ == "__main__":
from model_manage_page import model_manage_page
# 装载两个页面
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

View File

@@ -0,0 +1,136 @@
import gradio as gr
import sys
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider
from global_var import get_docs, get_prompt_store, get_sql_engine
def dataset_generate_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集生成")
with gr.Row():
with gr.Column():
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc,
label="选择文档",
interactive=True
)
doc_choice = gr.State(value=initial_doc)
with gr.Column():
prompts = get_prompt_store().all()
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_list[0] if prompt_list else None
# 初始化Dataframe的值
initial_dataframe_value = []
if initial_prompt:
selected_prompt_id = int(initial_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
initial_dataframe_value = [[var, ""] for var in input_variables]
prompt_dropdown = gr.Dropdown(
choices=prompt_list,
value=initial_prompt,
label="选择模板",
interactive=True
)
prompt_choice = gr.State(value=initial_prompt)
with gr.Column():
# 从数据库获取API Provider列表
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
api_list = [f"{p.id} {p.model_id}" for p in providers]
initial_api = api_list[0] if api_list else None
api_dropdown = gr.Dropdown(
choices=api_list,
value=initial_api,
label="选择API",
interactive=True
)
api_choice = gr.State(value=initial_api)
generate_button = gr.Button("生成数据集",variant="primary")
output_text = gr.Textbox(label="生成结果", interactive=False)
variables_dataframe = gr.Dataframe(
headers=["变量名", "变量值"],
datatype=["str", "str"],
interactive=True,
label="变量列表",
value=initial_dataframe_value # 设置初始化数据
)
def on_doc_change(selected_doc):
return selected_doc
def on_api_change(selected_api):
return selected_api
def on_prompt_change(selected_prompt):
if not selected_prompt:
return None, []
selected_prompt_id = int(selected_prompt.split(" ")[0])
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, progress=gr.Progress()):
variables_dict = {}
# 正确遍历DataFrame的行数据
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
import time
total_steps = 10
for i in range(total_steps):
# 模拟每个步骤的工作负载
time.sleep(0.5)
# 更新进度条
# 第一个参数是当前的进度比例 (0.0 到 1.0)
# desc 参数可以动态更新进度条旁边的描述文字
current_progress = (i + 1) / total_steps
progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
return "all done"
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
generate_button.click(
on_generate_click,
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe],
outputs=output_text
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
init_global_var("workdir")
demo = dataset_generate_page()
demo.launch()

View File

@@ -0,0 +1,55 @@
import gradio as gr
from global_var import get_datasets
from tinydb import Query
def dataset_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集管理")
with gr.Row():
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 添加数据集展示组件
qa_dataset = gr.Dataset(
components=["text", "text"],
label="问答数据",
headers=["问题", "答案"],
samples=[["示例问题", "示例答案"]],
samples_per_page=20,
)
def update_qa_display(dataset_name):
if not dataset_name:
return {"samples": [], "__type__": "update"}
# 从数据库获取数据集
Dataset = Query()
ds = get_datasets().get(Dataset.name == dataset_name)
if not ds:
return {"samples": [], "__type__": "update"}
# 提取所有Q_A数据
qa_list = []
for item in ds["dataset_items"]:
for qa in item["message"]:
qa_list.append([qa["question"], qa["answer"]])
return {"samples": qa_list, "__type__": "update"}
# 绑定事件更新QA数据显示
dataset_dropdown.change(
update_qa_display,
inputs=dataset_dropdown,
outputs=qa_dataset
)
return demo

View File

@@ -0,0 +1,107 @@
import gradio as gr
import os # 导入os模块以便扫描文件夹
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name
def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
models_dir = os.path.join(workdir, "models")
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # 扫描models文件夹下的所有子文件夹
with gr.Blocks() as demo:
gr.Markdown("## 模型管理")
state_output = gr.Label(label="当前状态",value="当前未加载模型") # 将 Textbox 改为 Label
with gr.Row():
with gr.Column(scale=3):
model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中并设置为可选
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
with gr.Column(scale=1):
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
# 判空操作,如果模型已加载,则先卸载
if get_model() is not None:
unload_model()
model_path = os.path.join(models_dir, selected_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载"
except Exception as e:
return f"加载模型时出错: {str(e)}"
load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)
def unload_model():
try:
# 将模型移动到 CPU
model = get_model()
if model is not None:
model.cpu()
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 将模型和tokenizer设置为 None
set_model(None)
set_tokenizer(None)
return "当前未加载模型"
except Exception as e:
return f"卸载模型时出错: {str(e)}"
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
def save_model(save_model_name):
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
def refresh_model_list():
try:
nonlocal model_folders
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
return gr.Dropdown(choices=model_folders)
except Exception as e:
return f"刷新模型列表时出错: {str(e)}"
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)
return demo
if __name__ == "__main__":
demo = model_manage_page()
demo.queue()
demo.launch()

View File

@@ -0,0 +1,130 @@
import gradio as gr
import sys
from pathlib import Path
from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_prompt_store
from schema.prompt import promptTempleta
def prompt_manage_page():
def get_prompts() -> List[List[str]]:
selected_row = None
try:
db = get_prompt_store()
prompts = db.all()
return [
[p["id"], p["name"], p["description"], p["content"]]
for p in prompts
]
except Exception as e:
raise gr.Error(f"获取提示词失败: {str(e)}")
def add_prompt(name, description, content):
try:
db = get_prompt_store()
new_prompt = promptTempleta(
name=name if name else "",
description=description if description else "",
content=content if content else ""
)
prompt_id = db.insert(new_prompt.model_dump())
# 更新ID
db.update({"id": prompt_id}, doc_ids=[prompt_id])
return get_prompts(), "", "", "" # 返回清空后的输入框值
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_prompt():
global selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
db = get_prompt_store()
db.update({
"name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "",
"content": selected_row[3] if selected_row[3] else ""
}, doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_prompt():
global selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
db = get_prompt_store()
db.remove(doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # 保存当前选中行的全局变量
def select_record(dataFrame ,evt: gr.SelectData):
global selected_row
selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo:
gr.Markdown("## 提示词模板管理")
with gr.Row():
with gr.Column(scale=1):
name_input = gr.Textbox(label="模板名称")
description_input = gr.Textbox(label="模板描述")
content_input = gr.Textbox(label="模板内容", lines=10)
add_button = gr.Button("添加新模板", variant="primary")
with gr.Column(scale=3):
prompt_table = gr.DataFrame(
headers=["id", "名称", "描述", "内容"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_prompts(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_prompts,
outputs=[prompt_table],
queue=False
)
add_button.click(
fn=add_prompt,
inputs=[name_input, description_input, content_input],
outputs=[prompt_table, name_input, description_input, content_input]
)
prompt_table.select(fn=select_record,
inputs=[prompt_table],
outputs=[],
show_progress="hidden")
edit_button.click(
fn=edit_prompt,
inputs=[],
outputs=[prompt_table]
)
delete_button.click(
fn=delete_prompt,
inputs=[],
outputs=[prompt_table]
)
return demo
if __name__ == "__main__":
demo = prompt_manage_page()
demo.queue()
demo.launch()

131
frontend/setting_page.py Normal file
View File

@@ -0,0 +1,131 @@
import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import get_sql_engine
def setting_page():
def get_providers() -> List[List[str]]:
selected_row = None
try: # 添加异常处理
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
return [
[p.id, p.model_id, p.base_url, p.api_key or ""]
for p in providers
]
except Exception as e:
raise gr.Error(f"获取数据失败: {str(e)}")
def add_provider(model_id, base_url, api_key):
try:
with Session(get_sql_engine()) as session:
new_provider = APIProvider(
model_id=model_id if model_id else None,
base_url=base_url if base_url else None,
api_key=api_key if api_key else None
)
session.add(new_provider)
session.commit()
session.refresh(new_provider)
return get_providers(), "", "", "" # 返回清空后的输入框值
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_provider():
global selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
provider.model_id = selected_row[1] if selected_row[1] else None
provider.base_url = selected_row[2] if selected_row[2] else None
provider.api_key = selected_row[3] if selected_row[3] else None
session.add(provider)
session.commit()
session.refresh(provider)
return get_providers()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_provider():
global selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
session.delete(provider)
session.commit()
return get_providers()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # 保存当前选中行的全局变量
def select_record(dataFrame ,evt: gr.SelectData):
global selected_row
selected_row = dataFrame.iloc[evt.index[0]].tolist()
selected_row[0] = int(selected_row[0])
print(selected_row)
with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理")
with gr.Row():
with gr.Column(scale=1):
model_id_input = gr.Textbox(label="Model ID")
base_url_input = gr.Textbox(label="Base URL")
api_key_input = gr.Textbox(label="API Key")
add_button = gr.Button("添加新API", variant="primary")
with gr.Column(scale=3):
provider_table = gr.DataFrame(
headers=["id", "model id", "base URL", "API Key"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_providers(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_providers,
outputs=[provider_table],
queue=False # 立即刷新不需要排队
)
add_button.click(
fn=add_provider,
inputs=[model_id_input, base_url_input, api_key_input],
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
)
provider_table.select(fn=select_record,
inputs=[provider_table],
outputs=[],
show_progress="hidden")
edit_button.click(
fn=edit_provider,
inputs=[],
outputs=[provider_table]
)
delete_button.click(
fn=delete_provider,
inputs=[],
outputs=[provider_table]
)
return demo

110
frontend/train_page.py Normal file
View File

@@ -0,0 +1,110 @@
import subprocess
import os
import gradio as gr
import sys
from tinydb import Query
from pathlib import Path
from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model, find_available_port
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
with gr.Row():
with gr.Column(scale=1):
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练状态", interactive=False)
with gr.Column(scale=3):
# 新增 TensorBoard iframe 显示框
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
lora_rank = int(lora_rank) # 新增LoRA秩参数
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
# 扫描 training 文件夹并生成递增目录
training_dir = get_workdir() + "/training"
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
new_training_dir = os.path.join(training_dir, str(next_dir_number))
tensorboard_port = find_available_port(6006) # 从默认端口 6006 开始检测
print(f"TensorBoard 将使用端口: {tensorboard_port}")
tensorboard_logdir = os.path.join(new_training_dir, "logs")
os.makedirs(tensorboard_logdir, exist_ok=True) # 确保日志目录存在
tensorboard_process = subprocess.Popen(
["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
# 动态生成 TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try:
train_model(get_model(), get_tokenizer(),
dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank)
finally:
# 确保训练结束后终止 TensorBoard 子进程
tensorboard_process.terminate()
print("TensorBoard 子进程已终止")
train_button.click(
fn=start_training,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input,
lora_rank_input
],
outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
from model_manage_page import model_manage_page
init_global_var("workdir")
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

56
global_var.py Normal file
View File

@@ -0,0 +1,56 @@
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
_prompt_store = None
_sql_engine = None
_datasets = None
_model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = load_sqlite_engine(workdir)
_datasets = get_all_dataset(workdir)
_workdir = workdir
def get_workdir():
return _workdir
def get_prompt_store():
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
global _workdir
return scan_docs_directory(_workdir)
def get_datasets():
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer

29
main.py Normal file
View File

@@ -0,0 +1,29 @@
import gradio as gr
import unsloth
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store
if __name__ == "__main__":
init_global_var()
initialize_sqlite_db(get_sql_engine())
initialize_prompt_store(get_prompt_store())
with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs():
with gr.TabItem("模型管理"):
model_manage_page()
with gr.TabItem("模型推理"):
chat_page()
with gr.TabItem("模型微调"):
train_page()
with gr.TabItem("数据集生成"):
dataset_generate_page()
with gr.TabItem("数据集管理"):
dataset_manage_page()
with gr.TabItem("提示词模板管理"):
prompt_manage_page()
with gr.TabItem("设置"):
setting_page()
app.launch()

View File

@@ -1,2 +1,11 @@
openai>=1.0.0
python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=5.25.0
langchain>=0.3
tinydb>=4.0.0
unsloth>=2025.3.19
sqlmodel>=0.0.24
jinja2>=3.1.0
tensorboardX>=2.6.2.2
tensorboard>=2.19.0

4
schema/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta

30
schema/dataset.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime, timezone
class doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径")
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
version: Optional[str] = Field(default="", description="文档版本")
class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")

View File

@@ -0,0 +1,51 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
base_url: str = Field(...,min_length=1,description="API的基础URL不能为空")
model_id: str = Field(...,min_length=1,description="API使用的模型ID不能为空")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
class LLMResponse(SQLModel):
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳"
)
response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: {
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

13
schema/md_doc.py Normal file
View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class MarkdownNode(BaseModel):
level: int = Field(default=0, description="节点层级")
title: str = Field(default="Root", description="节点标题")
content: Optional[str] = Field(default=None, description="节点内容")
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
class Config:
arbitrary_types_allowed = True
MarkdownNode.model_rebuild()

20
schema/prompt.py Normal file
View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel, Field, field_validator
from typing import Optional
from datetime import datetime, timezone
from langchain.prompts import PromptTemplate
class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID")
name: Optional[str] = Field(default="", description="模板名称")
description: Optional[str] = Field(default="", description="模板描述")
content: str = Field(default="", min_length=1, description="模板内容")
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间"
)
@field_validator('content')
def validate_content(cls, value):
if not "document_slice" in PromptTemplate.from_template(value).input_variables:
raise ValueError("模板变量缺少 document_slice")
return value

5
tools/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .parse_markdown import parse_markdown
from .document import *
from .json_example import generate_example_json
from .model import *
from .port import *

View File

@@ -0,0 +1,35 @@
from typing import List
from schema.dataset import dataset, dataset_item, Q_A
import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset:
# 将JSON数据转换为dataset格式
dataset_items = []
item_id = 1 # 自增ID计数器
for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = dataset_item(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj)
item_id += 1 # ID自增
# 创建dataset对象
result_dataset = dataset(
name="Converted Dataset",
model_id=None,
description="Dataset converted from JSON",
dataset_items=dataset_items
)
return result_dataset
# 示例从文件读取JSON并转换
if __name__ == "__main__":
# 假设JSON数据存储在文件中
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
json_data = json.load(file)
# 转换为dataset格式
converted_dataset = convert_json_to_dataset(json_data)
# 输出结果到文件
with open("output.json", "w", encoding="utf-8") as file:
file.write(converted_dataset.model_dump_json(indent=4))

32
tools/document.py Normal file
View File

@@ -0,0 +1,32 @@
import sys
import os
from pathlib import Path
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc
def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs")
doc_list = os.listdir(docs_dir)
to_return = []
for doc_name in doc_list:
doc_path = os.path.join(docs_dir, doc_name)
if os.path.isdir(doc_path):
markdown_files = []
for root, dirs, files in os.walk(doc_path):
for file in files:
if file.endswith(".md"):
markdown_files.append(os.path.join(root, file))
to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return
# 添加测试代码
if __name__ == "__main__":
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
docs = scan_docs_directory(workdir)
print(docs)

63
tools/json_example.py Normal file
View File

@@ -0,0 +1,63 @@
from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, get_args, get_origin
import json
from datetime import datetime, date
def generate_example_json(model: type[BaseModel]) -> str:
"""
根据 Pydantic V2 模型生成示例 JSON 数据结构。
"""
def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type)
args = get_args(field_type)
if origin is list or origin is List:
if args:
return [_generate_example(args[0])]
else:
return []
elif origin is dict or origin is Dict:
if len(args) == 2 and args[0] is str:
return {"key": _generate_example(args[1])}
else:
return {}
elif origin is Optional or origin is type(None):
if args:
return _generate_example(args[0])
else:
return None
elif field_type is str:
return "string"
elif field_type is int:
return 0
elif field_type is float:
return 0.0
elif field_type is bool:
return True
elif field_type is datetime:
return datetime.now().isoformat()
elif field_type is date:
return date.today().isoformat()
elif issubclass(field_type, BaseModel):
return generate_example_json(field_type)
else:
return "unknown" # 对于未知类型返回 "unknown"
example_data = {}
for field_name, field in model.model_fields.items():
example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str)
if __name__ == "__main__":
import sys
from pathlib import Path
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Q_A
class Q_A_list(BaseModel):
Q_As: List[Q_A]
print("示例 JSON:")
print(generate_example_json(Q_A_list))

134
tools/model.py Normal file
View File

@@ -0,0 +1,134 @@
import os
from datasets import Dataset as HFDataset
from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
from unsloth.chat_templates import get_chat_template, train_on_responses_only
def get_model_name(model):
return os.path.basename(model.name_or_path)
def formatting_prompts(examples,tokenizer):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# 使用tokenizer.apply_chat_template格式化对话
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
def train_model(
model,
tokenizer,
dataset: list,
train_dir: str,
learning_rate: float,
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int,
trainer_callback=None
) -> None:
# 模型配置参数
dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存
model = FastLanguageModel.get_peft_model(
# 原始模型
model,
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
# 建议: 8-32之间
r=lora_rank, # 使用动态传入的LoRA秩
# 需要应用LoRA的目标模块列表
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
"gate_proj", "up_proj", "down_proj", # FFN相关层
],
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大LoRA的更新影响越大。
lora_alpha=16,
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
# 如果数据集较小建议设置0.1左右。
lora_dropout=0,
# 是否对bias参数进行微调,none表示不微调bias
# none: 不微调偏置参数;
# all: 微调所有参数;
# lora_only: 只微调LoRA参数。
bias="none",
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
# 会略微降低训练速度,但可以显著减少显存使用
use_gradient_checkpointing="unsloth",
# 随机数种子,用于结果复现
random_state=3407,
# 是否使用rank-stabilized LoRA,这里不使用
# 会略微降低训练速度,但可以显著减少显存使用
use_rslora=False,
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
loftq_config=None,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
train_dataset = HFDataset.from_list(dataset)
train_dataset = train_dataset.map(formatting_prompts,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器
train_dataset=train_dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称
max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # 数据集处理的并行进程数
packing=False,
args=TrainingArguments(
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=114514, # 随机数种子
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数
logging_dir=train_dir + "/logs", # 日志文件存储路径
report_to="tensorboard", # 使用TensorBoard记录日志
),
)
if trainer_callback is not None:
trainer.add_callback(trainer_callback)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",
response_part = "<|im_start|>assistant\n",
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)

View File

@@ -1,28 +1,45 @@
import re
import sys
from pathlib import Path
class MarkdownNode:
def __init__(self, level=0, title="Root"):
self.level = level
self.title = title
self.content = "" # 使用字符串存储合并后的内容
self.children = []
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import MarkdownNode
def __repr__(self):
return f"({self.level}) {self.title}"
def process_markdown_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
def add_child(self, child):
self.children.append(child)
root = parse_markdown(content)
results = []
def print_tree(self, indent=0):
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # 叶子节点
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def add_child(parent, child):
parent.children.append(child)
def print_tree(node, indent=0):
prefix = "" * (indent - 1) + "└─ " if indent > 0 else ""
print(f"{prefix}{self.title}")
if self.content:
print(f"{prefix}{node.title}")
if node.content:
content_prefix = "" * indent + "├─ [内容]"
print(content_prefix)
for line in self.content.split('\n'):
for line in node.content.split('\n'):
print("" * indent + "" + line)
for child in self.children:
child.print_tree(indent + 1)
for child in node.children:
print_tree(child, indent + 1)
def parse_markdown(markdown):
lines = markdown.split('\n')
@@ -51,10 +68,10 @@ def parse_markdown(markdown):
if match:
level = len(match.group(1))
title = match.group(2)
node = MarkdownNode(level, title)
node = MarkdownNode(level=level, title=title, content="", children=[])
while stack[-1].level >= level:
stack.pop()
stack[-1].add_child(node)
add_child(stack[-1], node)
stack.append(node)
else:
if stack[-1].content:
@@ -64,10 +81,13 @@ def parse_markdown(markdown):
return root
if __name__=="__main__":
# 从文件读取 Markdown 内容
with open("example.md", "r", encoding="utf-8") as f:
markdown = f.read()
# # 从文件读取 Markdown 内容
# with open("workdir/example.md", "r", encoding="utf-8") as f:
# markdown = f.read()
# 解析 Markdown 并打印树结构
root = parse_markdown(markdown)
root.print_tree()
# # 解析 Markdown 并打印树结构
# root = parse_markdown(markdown)
# print_tree(root)
for i in process_markdown_file("workdir/example.md"):
print("~"*20)
print(i)

15
tools/port.py Normal file
View File

@@ -0,0 +1,15 @@
import socket
# 启动 TensorBoard 子进程前添加端口检测逻辑
def find_available_port(start_port):
port = start_port
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', port)) != 0: # 端口未被占用
return port
port += 1 # 如果端口被占用,尝试下一个端口
if __name__ == "__main__":
start_port = 6006 # 起始端口号
available_port = find_available_port(start_port)
print(f"Available port: {available_port}")