Compare commits

5 Commits

Author SHA1 Message Date
carry
7a4388c928 featmodel): 添加保存模式选择功能
在模型管理页面中新增保存模式选择功能,用户可以通过下拉菜单选择不同的保存模式(如默认、合并16位、合并4位等)。同时,将保存模型的逻辑抽离到独立的`save_model.py`文件中,以提高代码的可维护性和复用性。
2025-04-23 14:09:02 +08:00
carry
6338706967 feat: 修改应用启动方式
- 将 app.launch() 修改为 app.launch(server_name="0.0.0.0")
- 此修改使应用能够监听所有网络接口,提高可用性
2025-04-22 19:08:23 +08:00
carry
3718c75cee feat(frontend): 添加数据集生成页面的处理进度显示
- 在处理文档片段时添加进度条,提升用户体验
- 优化代码格式,调整缩进和空行
2025-04-22 00:14:16 +08:00
carry
905658073a docs(README): 更新项目文档
- 添加项目概述、核心功能、技术架构等详细信息
- 插入系统架构图和技术栈说明
- 细化功能模块描述,包括模型管理、推理、微调等
- 增加QLoRA原理和参数配置说明
- 补充快速开始指南和许可证信息
- 优化文档结构,增强可读性和完整性
2025-04-21 14:28:20 +08:00
carry
9806334517 fix(train_page): 捕获训练过程中的异常并终止 TensorBoard 进程
- 在训练过程中添加异常捕获,将异常信息转换为 gr.Error 抛出
- 确保在发生异常时也能终止 TensorBoard 子进程
2025-04-20 21:40:46 +08:00
6 changed files with 150 additions and 28 deletions

View File

@@ -1,15 +1,91 @@
# 基于文档驱动的自适应编码大模型微调框架
## 简介
本人的毕业设计
### 项目概述
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
### 核心功能
- 文档解析与语料生成
- 大语言模型高效微调
- 交互式训练与推理界面
- 训练过程可视化监控
### 项目技术
## 技术架构
* 使用unsloth框架在GPU上实现大语言模型的qlora微调
* 使用langchain框架编写工作流实现批量生成微调语料
* 使用tinydb和sqlite实现数据的持久化
* 使用gradio框架实现前端展示
### 系统架构
```
[前端界面] -> [模型微调] -> [数据存储]
↑ ↑ ↑
│ │ │
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
```
**施工中......**
### 技术栈
- **前端界面**: Gradio构建的交互式Web界面
- **模型微调**: 基于unsloth框架的QLoRA高效微调
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
- **工作流引擎**: LangChain实现文档解析与语料生成
- **训练监控**: TensorBoard集成
## 功能模块
### 1. 模型管理
- 支持多种格式的大语言模型加载
- 模型信息查看与状态管理
- 模型卸载与内存释放
### 2. 模型推理
- 对话式交互界面
- 流式响应输出
- 历史对话管理
### 3. 模型微调
- 训练参数配置(学习率、batch size等)
- LoRA参数配置(秩、alpha等)
- 训练过程实时监控
- 训练中断与恢复
### 4. 数据集生成
- 文档解析与清洗
- 指令-响应对生成
- 数据集质量评估
- 数据集版本管理
## 微调技术
### QLoRA原理
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
1. **4-bit量化**: 将预训练模型量化为4-bit表示大幅减少显存占用
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
### 参数配置
- **学习率**: 建议2e-5到2e-4
- **LoRA秩**: 控制适配器复杂度(建议16-64)
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
### 训练监控
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
- **日志记录**: 训练过程详细日志保存
- **模型检查点**: 定期保存中间权重
## 快速开始
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 启动应用:
```bash
python main.py
```
3. 访问Web界面:
```
http://localhost:7860
```
## 许可证
MIT License

View File

@@ -135,8 +135,6 @@ def dataset_generate_page():
prompt = prompt.partial(**variables_dict)
dataset = Dataset(
name=dataset_name,
model_id=[api_provider.model_id],
@@ -144,10 +142,12 @@ def dataset_generate_page():
dataset_items=[]
)
for document_slice in document_slice_list:
total_slices = len(document_slice_list)
for i, document_slice in enumerate(document_slice_list):
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
request = LLMRequest(api_provider=api_provider,
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
prompt=prompt.format(document_slice=document_slice),
format=generate_json_example(DatasetItem))
call_openai_api(request, rounds)
for resp in request.response:

View File

@@ -30,6 +30,11 @@ def model_manage_page():
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
save_method_dropdown = gr.Dropdown(
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
label="保存模式",
value="default"
)
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
@@ -73,21 +78,12 @@ def model_manage_page():
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
def save_model(save_model_name):
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
from train.save_model import save_model_to_dir
def save_model(save_model_name, save_method):
return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
def refresh_model_list():
try:

View File

@@ -82,6 +82,8 @@ def train_page():
dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank)
except Exception as e:
raise gr.Error(str(e))
finally:
# 确保训练结束后终止 TensorBoard 子进程
tensorboard_process.terminate()

View File

@@ -26,4 +26,4 @@ if __name__ == "__main__":
with gr.TabItem("设置"):
setting_page()
app.launch()
app.launch(server_name="0.0.0.0")

48
train/save_model.py Normal file
View File

@@ -0,0 +1,48 @@
import os
from global_var import get_model, get_tokenizer
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
"""
保存模型到指定目录
:param save_model_name: 要保存的模型名称
:param models_dir: 模型保存的基础目录
:param model: 要保存的模型
:param tokenizer: 要保存的tokenizer
:param save_method: 保存模式选项
- "default": 默认保存方式
- "merged_16bit": 合并为16位
- "merged_4bit": 合并为4位
- "lora": 仅LoRA适配器
- "gguf": 保存为GGUF格式
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
- "gguf_f16": 保存为16位GGUF格式
:return: 保存结果消息或错误信息
"""
try:
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
if save_method == "default":
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
elif save_method == "merged_16bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
elif save_method == "merged_4bit":
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
elif save_method == "lora":
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
elif save_method == "gguf":
model.save_pretrained_gguf(save_path, tokenizer)
elif save_method == "gguf_q4_k_m":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
elif save_method == "gguf_f16":
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
else:
return f"不支持的保存模式: {save_method}"
return f"模型已保存到 {save_path} (模式: {save_method})"
except Exception as e:
return f"保存模型时出错: {str(e)}"