feat(train): 添加训练过程中的日志记录和 loss 可视化功能

- 新增 LossCallback 类,用于在训练过程中记录 loss 数据
- 在训练模型函数中添加回调,实现日志记录和 loss 可视化
- 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
This commit is contained in:
carry
2025-04-14 15:18:14 +08:00
parent 4f09823123
commit 9fb31c46c8
2 changed files with 31 additions and 4 deletions

View File

@@ -40,7 +40,8 @@ def train_model(
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int
lora_rank: int,
trainer_callback
) -> None:
# 模型配置参数
dtype = None # 数据类型None表示自动选择
@@ -115,10 +116,12 @@ def train_model(
output_dir=output_dir, # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数
# report_to="tensorboard", # 将信息输出到tensorboard
report_to="none",
),
)
trainer.add_callback(trainer_callback)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",