131 lines
5.4 KiB
Python
131 lines
5.4 KiB
Python
import os
|
||
from datasets import Dataset as HFDataset
|
||
from unsloth import FastLanguageModel
|
||
from trl import SFTTrainer # 用于监督微调的训练器
|
||
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
|
||
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
|
||
from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
||
def get_model_name(model):
|
||
return os.path.basename(model.name_or_path)
|
||
def formatting_prompts(examples,tokenizer):
|
||
"""格式化对话数据的函数
|
||
Args:
|
||
examples: 包含对话列表的字典
|
||
Returns:
|
||
包含格式化文本的字典
|
||
"""
|
||
questions = examples["question"]
|
||
answer = examples["answer"]
|
||
|
||
convos = [
|
||
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
|
||
for q, r in zip(questions, answer)
|
||
]
|
||
|
||
# 使用tokenizer.apply_chat_template格式化对话
|
||
texts = [
|
||
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
|
||
for convo in convos
|
||
]
|
||
|
||
return {"text": texts}
|
||
|
||
|
||
def train_model(
|
||
model,
|
||
tokenizer,
|
||
dataset: list,
|
||
train_dir: str,
|
||
learning_rate: float,
|
||
per_device_train_batch_size: int,
|
||
epoch: int,
|
||
save_steps: int,
|
||
lora_rank: int,
|
||
trainer_callback=None
|
||
) -> None:
|
||
|
||
model = FastLanguageModel.get_peft_model(
|
||
# 原始模型
|
||
model,
|
||
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
|
||
# 建议: 8-32之间
|
||
r=lora_rank, # 使用动态传入的LoRA秩
|
||
# 需要应用LoRA的目标模块列表
|
||
target_modules=[
|
||
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
|
||
"gate_proj", "up_proj", "down_proj", # FFN相关层
|
||
],
|
||
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大,LoRA的更新影响越大。
|
||
lora_alpha=16,
|
||
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
|
||
# 如果数据集较小,建议设置0.1左右。
|
||
lora_dropout=0,
|
||
# 是否对bias参数进行微调,none表示不微调bias
|
||
# none: 不微调偏置参数;
|
||
# all: 微调所有参数;
|
||
# lora_only: 只微调LoRA参数。
|
||
bias="none",
|
||
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
|
||
# 会略微降低训练速度,但可以显著减少显存使用
|
||
use_gradient_checkpointing="unsloth",
|
||
# 随机数种子,用于结果复现
|
||
random_state=3407,
|
||
# 是否使用rank-stabilized LoRA,这里不使用
|
||
# 会略微降低训练速度,但可以显著减少显存使用
|
||
use_rslora=False,
|
||
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
|
||
loftq_config=None,
|
||
)
|
||
|
||
tokenizer = get_chat_template(
|
||
tokenizer,
|
||
chat_template="qwen-2.5",
|
||
)
|
||
|
||
train_dataset = HFDataset.from_list(dataset)
|
||
train_dataset = train_dataset.map(formatting_prompts,
|
||
fn_kwargs={"tokenizer": tokenizer},
|
||
batched=True)
|
||
|
||
# 初始化SFT训练器
|
||
trainer = SFTTrainer(
|
||
model=model, # 待训练的模型
|
||
tokenizer=tokenizer, # 分词器
|
||
train_dataset=train_dataset, # 训练数据集
|
||
dataset_text_field="text", # 数据集字段的名称
|
||
max_seq_length=model.max_seq_length, # 最大序列长度
|
||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||
dataset_num_proc=1, # 数据集处理的并行进程数
|
||
packing=False,
|
||
args=TrainingArguments(
|
||
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
|
||
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
|
||
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
|
||
learning_rate=learning_rate, # 学习率
|
||
lr_scheduler_type="linear", # 线性学习率调度器
|
||
max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据)
|
||
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
|
||
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
|
||
logging_steps=1, # 每1步记录一次日志
|
||
optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果
|
||
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
|
||
seed=114514, # 随机数种子
|
||
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
|
||
save_strategy="steps", # 按步保存中间权重
|
||
save_steps=save_steps, # 使用动态传入的保存步数
|
||
logging_dir=train_dir + "/logs", # 日志文件存储路径
|
||
report_to="tensorboard", # 使用TensorBoard记录日志
|
||
),
|
||
)
|
||
|
||
if trainer_callback is not None:
|
||
trainer.add_callback(trainer_callback)
|
||
|
||
trainer = train_on_responses_only(
|
||
trainer,
|
||
instruction_part = "<|im_start|>user\n",
|
||
response_part = "<|im_start|>assistant\n",
|
||
)
|
||
|
||
# 开始训练
|
||
trainer_stats = trainer.train(resume_from_checkpoint=False) |