feat(train): 添加训练过程中的日志记录和 loss 可视化功能
- 新增 LossCallback 类,用于在训练过程中记录 loss 数据 - 在训练模型函数中添加回调,实现日志记录和 loss 可视化 - 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
This commit is contained in:
parent
4f09823123
commit
9fb31c46c8
@ -1,8 +1,8 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
import torch
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from transformers import TrainerCallback
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||||
@ -45,7 +45,31 @@ def train_page():
|
||||
# 加载数据集
|
||||
dataset = get_datasets().get(Query().name == dataset_name)
|
||||
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
|
||||
train_model(get_model(), get_tokenizer(), dataset, get_workdir(), learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank)
|
||||
|
||||
class LossCallback(TrainerCallback):
|
||||
def __init__(self):
|
||||
self.loss_data = []
|
||||
self.log_text = ""
|
||||
self.last_output = {"text": "", "plot": None}
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if "loss" in logs:
|
||||
self.loss_data.append({
|
||||
"step": state.global_step,
|
||||
"loss": float(logs["loss"])
|
||||
})
|
||||
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
|
||||
# 添加以下两行print语句
|
||||
print(f"Current Step: {state.global_step}")
|
||||
print(f"Loss Value: {logs['loss']:.4f}")
|
||||
self.last_output = {
|
||||
"text": self.log_text,
|
||||
}
|
||||
# 不返回 control,避免干预训练过程
|
||||
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, get_workdir()+"/checkpoint",
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank, LossCallback)
|
||||
|
||||
|
||||
train_button.click(
|
||||
|
@ -40,7 +40,8 @@ def train_model(
|
||||
per_device_train_batch_size: int,
|
||||
epoch: int,
|
||||
save_steps: int,
|
||||
lora_rank: int
|
||||
lora_rank: int,
|
||||
trainer_callback
|
||||
) -> None:
|
||||
# 模型配置参数
|
||||
dtype = None # 数据类型,None表示自动选择
|
||||
@ -115,10 +116,12 @@ def train_model(
|
||||
output_dir=output_dir, # 保存模型检查点和训练日志
|
||||
save_strategy="steps", # 按步保存中间权重
|
||||
save_steps=save_steps, # 使用动态传入的保存步数
|
||||
# report_to="tensorboard", # 将信息输出到tensorboard
|
||||
report_to="none",
|
||||
),
|
||||
)
|
||||
|
||||
trainer.add_callback(trainer_callback)
|
||||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|im_start|>user\n",
|
||||
|
Loading…
x
Reference in New Issue
Block a user