From 9fb31c46c8906516789a57788465828c309c3c49 Mon Sep 17 00:00:00 2001
From: carry
Date: Mon, 14 Apr 2025 15:18:14 +0800
Subject: [PATCH] =?UTF-8?q?feat(train):=20=E6=B7=BB=E5=8A=A0=E8=AE=AD?=
 =?UTF-8?q?=E7=BB=83=E8=BF=87=E7=A8=8B=E4=B8=AD=E7=9A=84=E6=97=A5=E5=BF=97?=
 =?UTF-8?q?=E8=AE=B0=E5=BD=95=E5=92=8C=20loss=20=E5=8F=AF=E8=A7=86?=
 =?UTF-8?q?=E5=8C=96=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 新增 LossCallback 类,用于在训练过程中记录 loss 数据
- 在训练模型函数中添加回调,实现日志记录和 loss 可视化
- 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
---
 frontend/train_page.py | 28 ++++++++++++++++++++++++++--
 tools/model.py         |  7 +++++--
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/frontend/train_page.py b/frontend/train_page.py
index 9ad226f..a8f261c 100644
--- a/frontend/train_page.py
+++ b/frontend/train_page.py
@@ -1,8 +1,8 @@
 import gradio as gr
 import sys
-import torch
 from tinydb import Query
 from pathlib import Path
+from transformers import TrainerCallback
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from global_var import get_model, get_tokenizer, get_datasets, get_workdir
@@ -45,7 +45,31 @@ def train_page():
             # 加载数据集
             dataset = get_datasets().get(Query().name == dataset_name)
             dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
-            train_model(get_model(), get_tokenizer(), dataset, get_workdir(), learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank)
+
+            class LossCallback(TrainerCallback):  # records the loss found in each training log entry
+                def __init__(self):
+                    self.loss_data = []  # one {"step": ..., "loss": ...} dict per logged loss
+                    self.log_text = ""  # accumulated "Step N: loss=X" lines
+                    self.last_output = {"text": "", "plot": None}
+                def on_log(self, args, state, control, logs=None, **kwargs):
+                    if logs and "loss" in logs:  # logs defaults to None; skip entries without a loss
+                        self.loss_data.append({
+                            "step": state.global_step,
+                            "loss": float(logs["loss"])
+                        })
+                        self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
+                        # Also echo the current step and loss to stdout.
+                        print(f"Current Step: {state.global_step}")
+                        print(f"Loss Value: {logs['loss']:.4f}")
+                        self.last_output = {  # NOTE(review): replaces the dict, so the "plot" key is dropped
+                            "text": self.log_text,
+                        }
+                    # Do not return control, so the training process is not interfered with.
+
+            train_model(get_model(), get_tokenizer(),
+                        dataset, get_workdir()+"/checkpoint",
+                        learning_rate, per_device_train_batch_size, epoch,
+                        save_steps, lora_rank, LossCallback)
 
 
         train_button.click(
diff --git a/tools/model.py b/tools/model.py
index b06c0bc..062b222 100644
--- a/tools/model.py
+++ b/tools/model.py
@@ -40,7 +40,8 @@ def train_model(
     per_device_train_batch_size: int,
     epoch: int,
     save_steps: int,
-    lora_rank: int
+    lora_rank: int,
+    trainer_callback  # callback registered on the trainer via add_callback below
 ) -> None:
     # 模型配置参数
     dtype = None  # 数据类型,None表示自动选择
@@ -115,10 +116,12 @@ def train_model(
             output_dir=output_dir,  # 保存模型检查点和训练日志
             save_strategy="steps",  # 按步保存中间权重
             save_steps=save_steps,  # 使用动态传入的保存步数
-            # report_to="tensorboard",  # 将信息输出到tensorboard
+            report_to="none",  # presumably disables external reporting integrations; confirm against TrainingArguments docs
         ),
     )
 
+    trainer.add_callback(trainer_callback)  # the caller in train_page.py passes the LossCallback class itself
+
     trainer = train_on_responses_only(
         trainer,
         instruction_part = "<|im_start|>user\n",