Compare commits
5 Commits
0a4efa5641
...
improve
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7a4388c928 | ||
![]() |
6338706967 | ||
![]() |
3718c75cee | ||
![]() |
905658073a | ||
![]() |
9806334517 |
92
README.md
92
README.md
@@ -1,15 +1,91 @@
|
||||
# 基于文档驱动的自适应编码大模型微调框架
|
||||
|
||||
## 简介
|
||||
本人的毕业设计
|
||||
|
||||
### 项目概述
|
||||
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
|
||||
|
||||
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
|
||||
### 核心功能
|
||||
- 文档解析与语料生成
|
||||
- 大语言模型高效微调
|
||||
- 交互式训练与推理界面
|
||||
- 训练过程可视化监控
|
||||
|
||||
### 项目技术
|
||||
## 技术架构
|
||||
|
||||
* 使用unsloth框架在GPU上实现大语言模型的qlora微调
|
||||
* 使用langchain框架编写工作流实现批量生成微调语料
|
||||
* 使用tinydb和sqlite实现数据的持久化
|
||||
* 使用gradio框架实现前端展示
|
||||
### 系统架构
|
||||
```
|
||||
[前端界面] -> [模型微调] -> [数据存储]
|
||||
↑ ↑ ↑
|
||||
│ │ │
|
||||
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
|
||||
```
|
||||
|
||||
**施工中......**
|
||||
### 技术栈
|
||||
- **前端界面**: Gradio构建的交互式Web界面
|
||||
- **模型微调**: 基于unsloth框架的QLoRA高效微调
|
||||
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
|
||||
- **工作流引擎**: LangChain实现文档解析与语料生成
|
||||
- **训练监控**: TensorBoard集成
|
||||
|
||||
## 功能模块
|
||||
|
||||
### 1. 模型管理
|
||||
- 支持多种格式的大语言模型加载
|
||||
- 模型信息查看与状态管理
|
||||
- 模型卸载与内存释放
|
||||
|
||||
### 2. 模型推理
|
||||
- 对话式交互界面
|
||||
- 流式响应输出
|
||||
- 历史对话管理
|
||||
|
||||
### 3. 模型微调
|
||||
- 训练参数配置(学习率、batch size等)
|
||||
- LoRA参数配置(秩、alpha等)
|
||||
- 训练过程实时监控
|
||||
- 训练中断与恢复
|
||||
|
||||
### 4. 数据集生成
|
||||
- 文档解析与清洗
|
||||
- 指令-响应对生成
|
||||
- 数据集质量评估
|
||||
- 数据集版本管理
|
||||
|
||||
## 微调技术
|
||||
|
||||
### QLoRA原理
|
||||
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
|
||||
1. **4-bit量化**: 将预训练模型量化为4-bit表示,大幅减少显存占用
|
||||
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
|
||||
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
|
||||
|
||||
### 参数配置
|
||||
- **学习率**: 建议2e-5到2e-4
|
||||
- **LoRA秩**: 控制适配器复杂度(建议16-64)
|
||||
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
|
||||
|
||||
### 训练监控
|
||||
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
|
||||
- **日志记录**: 训练过程详细日志保存
|
||||
- **模型检查点**: 定期保存中间权重
|
||||
|
||||
## 快速开始
|
||||
|
||||
1. 安装依赖:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 启动应用:
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
3. 访问Web界面:
|
||||
```
|
||||
http://localhost:7860
|
||||
```
|
||||
|
||||
## 许可证
|
||||
MIT License
|
||||
|
@@ -135,8 +135,6 @@ def dataset_generate_page():
|
||||
|
||||
prompt = prompt.partial(**variables_dict)
|
||||
|
||||
|
||||
|
||||
dataset = Dataset(
|
||||
name=dataset_name,
|
||||
model_id=[api_provider.model_id],
|
||||
@@ -144,10 +142,12 @@ def dataset_generate_page():
|
||||
dataset_items=[]
|
||||
)
|
||||
|
||||
for document_slice in document_slice_list:
|
||||
total_slices = len(document_slice_list)
|
||||
for i, document_slice in enumerate(document_slice_list):
|
||||
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
|
||||
request = LLMRequest(api_provider=api_provider,
|
||||
prompt=prompt.format(document_slice=document_slice),
|
||||
format=generate_json_example(DatasetItem))
|
||||
prompt=prompt.format(document_slice=document_slice),
|
||||
format=generate_json_example(DatasetItem))
|
||||
call_openai_api(request, rounds)
|
||||
|
||||
for resp in request.response:
|
||||
|
@@ -30,6 +30,11 @@ def model_manage_page():
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
|
||||
save_method_dropdown = gr.Dropdown(
|
||||
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
|
||||
label="保存模式",
|
||||
value="default"
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_button = gr.Button("保存模型", variant="secondary")
|
||||
|
||||
@@ -73,21 +78,12 @@ def model_manage_page():
|
||||
|
||||
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
|
||||
|
||||
def save_model(save_model_name):
|
||||
try:
|
||||
global model, tokenizer
|
||||
if model is None:
|
||||
return "没有加载的模型可保存"
|
||||
|
||||
save_path = os.path.join(models_dir, save_model_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model.save_pretrained(save_path)
|
||||
tokenizer.save_pretrained(save_path)
|
||||
return f"模型已保存到 {save_path}"
|
||||
except Exception as e:
|
||||
return f"保存模型时出错: {str(e)}"
|
||||
from train.save_model import save_model_to_dir
|
||||
|
||||
def save_model(save_model_name, save_method):
|
||||
return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
|
||||
|
||||
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
|
||||
save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
|
||||
|
||||
def refresh_model_list():
|
||||
try:
|
||||
|
@@ -82,6 +82,8 @@ def train_page():
|
||||
dataset, new_training_dir,
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank)
|
||||
except Exception as e:
|
||||
raise gr.Error(str(e))
|
||||
finally:
|
||||
# 确保训练结束后终止 TensorBoard 子进程
|
||||
tensorboard_process.terminate()
|
||||
|
2
main.py
2
main.py
@@ -26,4 +26,4 @@ if __name__ == "__main__":
|
||||
with gr.TabItem("设置"):
|
||||
setting_page()
|
||||
|
||||
app.launch()
|
||||
app.launch(server_name="0.0.0.0")
|
48
train/save_model.py
Normal file
48
train/save_model.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
from global_var import get_model, get_tokenizer
|
||||
|
||||
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
|
||||
"""
|
||||
保存模型到指定目录
|
||||
:param save_model_name: 要保存的模型名称
|
||||
:param models_dir: 模型保存的基础目录
|
||||
:param model: 要保存的模型
|
||||
:param tokenizer: 要保存的tokenizer
|
||||
:param save_method: 保存模式选项
|
||||
- "default": 默认保存方式
|
||||
- "merged_16bit": 合并为16位
|
||||
- "merged_4bit": 合并为4位
|
||||
- "lora": 仅LoRA适配器
|
||||
- "gguf": 保存为GGUF格式
|
||||
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
|
||||
- "gguf_f16": 保存为16位GGUF格式
|
||||
:return: 保存结果消息或错误信息
|
||||
"""
|
||||
try:
|
||||
if model is None:
|
||||
return "没有加载的模型可保存"
|
||||
|
||||
save_path = os.path.join(models_dir, save_model_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
if save_method == "default":
|
||||
model.save_pretrained(save_path)
|
||||
tokenizer.save_pretrained(save_path)
|
||||
elif save_method == "merged_16bit":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
|
||||
elif save_method == "merged_4bit":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
|
||||
elif save_method == "lora":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
|
||||
elif save_method == "gguf":
|
||||
model.save_pretrained_gguf(save_path, tokenizer)
|
||||
elif save_method == "gguf_q4_k_m":
|
||||
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
|
||||
elif save_method == "gguf_f16":
|
||||
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
|
||||
else:
|
||||
return f"不支持的保存模式: {save_method}"
|
||||
|
||||
return f"模型已保存到 {save_path} (模式: {save_method})"
|
||||
except Exception as e:
|
||||
return f"保存模型时出错: {str(e)}"
|
Reference in New Issue
Block a user