refactor(tools): improve the train_model function definition

- Add type annotations to improve code readability and maintainability
- Define function parameters in a multi-line format for a cleaner code layout
carry 2025-04-14 14:28:36 +08:00
parent 1a2ca3e244
commit 4f09823123


@@ -31,8 +31,17 @@ def formatting_prompts(examples,tokenizer):
     return {"text": texts}
-def train_model(model, tokenizer, dataset, output_dir, learning_rate,
-                per_device_train_batch_size, epoch, save_steps, lora_rank):
+def train_model(
+    model,
+    tokenizer,
+    dataset: list,
+    output_dir: str,
+    learning_rate: float,
+    per_device_train_batch_size: int,
+    epoch: int,
+    save_steps: int,
+    lora_rank: int
+) -> None:
     # Model configuration parameters
     dtype = None  # Data type; None means auto-detect
     load_in_4bit = False  # Whether to load the model in 4-bit quantization to save VRAM
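
For context, a minimal sketch of how the newly annotated signature might be called; every argument value below is an illustrative assumption, not taken from this repository:

    train_model(
        model=model,                      # model and tokenizer assumed loaded elsewhere
        tokenizer=tokenizer,
        dataset=records,                  # a plain list of dicts, per the `dataset: list` annotation
        output_dir="outputs",
        learning_rate=2e-4,
        per_device_train_batch_size=4,
        epoch=3,
        save_steps=100,
        lora_rank=16,
    )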
@@ -75,8 +84,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
         chat_template="qwen-2.5",
     )
-    dataset = HFDataset.from_list(dataset)
-    dataset = dataset.map(formatting_prompts,
+    train_dataset = HFDataset.from_list(dataset)
+    train_dataset = train_dataset.map(formatting_prompts,
                           fn_kwargs={"tokenizer": tokenizer},
                           batched=True)
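
Beyond readability, the rename means the function no longer rebinds its dataset parameter (a plain list, per the new annotation) to an HFDataset object mid-body. For reference, a hedged sketch of what formatting_prompts presumably looks like, working back from its return {"text": texts} line and the qwen-2.5 chat template configured above; the "conversations" field name is an assumption:

    def formatting_prompts(examples, tokenizer):
        # Render each conversation into one training string using the
        # tokenizer's chat template ("qwen-2.5", set earlier in this file).
        texts = [
            tokenizer.apply_chat_template(convo, tokenize=False)
            for convo in examples["conversations"]
        ]
        return {"text": texts}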
@@ -84,7 +93,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
     trainer = SFTTrainer(
         model=model,  # Model to be trained
         tokenizer=tokenizer,  # Tokenizer
-        train_dataset=dataset,  # Training dataset
+        train_dataset=train_dataset,  # Training dataset
         dataset_text_field="text",  # Name of the dataset text field
         max_seq_length=model.max_seq_length,  # Maximum sequence length
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
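
The identifiers used in these hunks imply roughly the following imports at the top of the file; the exact import paths are assumptions inferred from the names, matching the usual unsloth/trl training setup:

    from datasets import Dataset as HFDataset
    from transformers import DataCollatorForSeq2Seq, TrainingArguments
    from trl import SFTTrainer
    from unsloth import is_bfloat16_supported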
@@ -96,7 +105,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
             warmup_steps=int(epoch * 0.1),  # Warmup steps that gradually ramp up the learning rate
             learning_rate=learning_rate,  # Learning rate
             lr_scheduler_type="linear",  # Linear learning-rate scheduler
-            max_steps=int(epoch * len(dataset)/per_device_train_batch_size),  # Max training steps (one step = one batch of data)
+            max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size),  # Max training steps (one step = one batch of data)
             fp16=not is_bfloat16_supported(),  # Use fp16 when bf16 is unsupported
             bf16=is_bfloat16_supported(),  # Use bf16 when supported
             logging_steps=1,  # Log every step
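
A worked example of the max_steps and warmup_steps arithmetic, using assumed values that do not come from this commit:

    epoch = 3
    dataset_size = 1000                # assumed len(train_dataset)
    per_device_train_batch_size = 4
    max_steps = int(epoch * dataset_size / per_device_train_batch_size)  # 750
    warmup_steps = int(epoch * 0.1)    # 0 -- rounds to zero whenever epoch < 10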