gzhu-biyesheji/frontend/train_page.py
carry 79d3eb153c refactor(train_page): 优化训练页面布局和功能
- 移除了 max_steps_input 组件,减少不必要的输入项
- 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch"
- 新增 save_steps_input 组件,用于设置保存步数
- 修改 train_model 函数,移除了 max_steps 参数
- 更新了 trainer.train() 方法的调用,设置 resume_from_checkpoint=False
2025-04-13 01:56:10 +08:00

160 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import sys
import torch
from tinydb import Query
from pathlib import Path
from datasets import Dataset as HFDataset
from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
from unsloth.chat_templates import get_chat_template
from tools import formatting_prompts_func
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import formatting_prompts_func
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
# 模型配置参数
dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存
# 加载预训练模型和分词器
model = get_model()
tokenizer = get_tokenizer()
model = FastLanguageModel.get_peft_model(
# 原始模型
model,
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
# 建议: 8-32之间
r=16,
# 需要应用LoRA的目标模块列表
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
"gate_proj", "up_proj", "down_proj", # FFN相关层
],
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大LoRA的更新影响越大。
lora_alpha=16,
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
# 如果数据集较小建议设置0.1左右。
lora_dropout=0,
# 是否对bias参数进行微调,none表示不微调bias
# none: 不微调偏置参数;
# all: 微调所有参数;
# lora_only: 只微调LoRA参数。
bias="none",
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
# 会略微降低训练速度,但可以显著减少显存使用
use_gradient_checkpointing="unsloth",
# 随机数种子,用于结果复现
random_state=3407,
# 是否使用rank-stabilized LoRA,这里不使用
# 会略微降低训练速度,但可以显著减少显存使用
use_rslora=False,
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
loftq_config=None,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
dataset = HFDataset.from_list(dataset)
dataset = dataset.map(formatting_prompts_func,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器
train_dataset=dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称
max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # 数据集处理的并行进程数
packing=False,
args=TrainingArguments(
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=114514, # 随机数种子
output_dir=get_workdir() + "/checkpoint/", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数
# report_to="tensorboard", # 将信息输出到tensorboard
),
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)
train_button.click(
fn=train_model,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input
],
outputs=output
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
from model_manage_page import model_manage_page
init_global_var("workdir")
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
# demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()