train: 更新模型训练功能和日志记录方式
- 修改训练目录结构,将检查点和日志分开保存 - 添加 TensorBoard 日志记录支持 - 移除自定义 LossCallback 类,简化训练流程 - 更新训练参数和回调机制,提高代码可读性 - 在 requirements.txt 中添加 tensorboardX 依赖
This commit is contained in:
@@ -35,13 +35,13 @@ def train_model(
|
||||
model,
|
||||
tokenizer,
|
||||
dataset: list,
|
||||
output_dir: str,
|
||||
train_dir: str,
|
||||
learning_rate: float,
|
||||
per_device_train_batch_size: int,
|
||||
epoch: int,
|
||||
save_steps: int,
|
||||
lora_rank: int,
|
||||
trainer_callback
|
||||
trainer_callback=None
|
||||
) -> None:
|
||||
# 模型配置参数
|
||||
dtype = None # 数据类型,None表示自动选择
|
||||
@@ -113,14 +113,16 @@ def train_model(
|
||||
optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果
|
||||
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
|
||||
seed=114514, # 随机数种子
|
||||
output_dir=output_dir, # 保存模型检查点和训练日志
|
||||
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
|
||||
save_strategy="steps", # 按步保存中间权重
|
||||
save_steps=save_steps, # 使用动态传入的保存步数
|
||||
report_to="none",
|
||||
logging_dir=train_dir + "/logs", # 日志文件存储路径
|
||||
report_to="tensorboard", # 使用TensorBoard记录日志
|
||||
),
|
||||
)
|
||||
|
||||
trainer.add_callback(trainer_callback)
|
||||
|
||||
if trainer_callback is not None:
|
||||
trainer.add_callback(trainer_callback)
|
||||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
|
Reference in New Issue
Block a user