From bb1d8fbd3847b0635a6567b56162bf983a6db4ba Mon Sep 17 00:00:00 2001
From: carry
Date: Sun, 13 Apr 2025 21:49:43 +0800
Subject: [PATCH] feat(train_page): add training loss curve display
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Added a loss curve chart to the training page
- Implemented the GradioLossCallback class to record loss data during training
- Modified the training function to collect loss information through a callback and update the chart
- Restructured the training function's return value so it returns both the text log and the loss data
---
 frontend/train_page.py | 60 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/frontend/train_page.py b/frontend/train_page.py
index 6aa3d78..17e70b8 100644
--- a/frontend/train_page.py
+++ b/frontend/train_page.py
@@ -1,9 +1,12 @@
+import unsloth
 import gradio as gr
 import sys
 import torch
+import pandas as pd
 from tinydb import Query
 from pathlib import Path
-from datasets import Dataset as HFDataset
+from datasets import Dataset as HFDataset
+from transformers import TrainerCallback
 from unsloth import FastLanguageModel
 from trl import SFTTrainer  # trainer for supervised fine-tuning
 
@@ -43,6 +46,16 @@ def train_page():
 
         # training status output
         output = gr.Textbox(label="训练日志", interactive=False)
+        # loss curve display
+        loss_plot = gr.LinePlot(
+            x="step",
+            y="loss",
+            title="训练Loss曲线",
+            interactive=True,
+            width=600,
+            height=300
+        )
+
         def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
             # use the dynamically passed hyperparameters
             learning_rate = float(learning_rate)
@@ -105,6 +118,33 @@ def train_page():
 
                 fn_kwargs={"tokenizer": tokenizer}, batched=True)
 
+            # callback class for recording loss values during training
+            class GradioLossCallback(TrainerCallback):
+                def __init__(self):
+                    self.loss_data = []
+                    self.log_text = ""
+                    self.last_output = {"text": "", "plot": None}
+
+                def on_log(self, args, state, control, logs=None, **kwargs):
+                    print(f"on_log called with logs: {logs}")  # debug output
+                    if "loss" in logs:
+                        print(f"Recording loss: {logs['loss']} at step {state.global_step}")  # debug output
+                        self.loss_data.append({
+                            "step": state.global_step,
+                            "loss": float(logs["loss"])  # make sure the value is a float
+                        })
+                        self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
+                        df = pd.DataFrame(self.loss_data)
+                        print(f"DataFrame created: {df}")  # debug output
+                        self.last_output = {
+                            "text": self.log_text,
+                            "plot": df
+                        }
+                    return control
+
+            # instantiate the callback
+            callback = GradioLossCallback()
+
             # initialize the SFT trainer
             trainer = SFTTrainer(
                 model=model,  # the model to be trained
@@ -134,6 +174,7 @@ def train_page():
                     # report_to="tensorboard",  # send logs to tensorboard
                 ),
             )
+            trainer.add_callback(callback)
 
             trainer = train_on_responses_only(
                 trainer,
@@ -143,17 +184,28 @@ def train_page():
 
             # start training
             trainer_stats = trainer.train(resume_from_checkpoint=False)
+            return callback.last_output
 
+        def wrapped_train_model(*args):
+            print("Starting training...")  # debug output
+            result = train_model(*args)
+            print(f"Training completed with result: {result}")  # debug output
+            # make sure the return format is correct
+            if result and "text" in result and "plot" in result:
+                return result["text"], result["plot"]
+            return "", pd.DataFrame()  # return default values
+
         train_button.click(
-            fn=train_model,
+            fn=wrapped_train_model,
             inputs=[
                 dataset_dropdown,
                 learning_rate_input,
                 per_device_train_batch_size_input,
                 epoch_input,
-                save_steps_input
+                save_steps_input,
+                lora_rank_input
             ],
-            outputs=output
+            outputs=[output, loss_plot]
         )
 
     return demo
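
For reference, a minimal, self-contained sketch of the TrainerCallback pattern this patch relies on, runnable without a GPU or the rest of the frontend. The LossRecorder class and the SimpleNamespace stand-ins for the trainer state are illustrative only, not code from this patch: on_log appends one {"step", "loss"} row per logged step, and the accumulated rows convert into the "step"/"loss" DataFrame shape that gr.LinePlot expects.

import pandas as pd
from types import SimpleNamespace
from transformers import TrainerCallback

class LossRecorder(TrainerCallback):
    """Collects {"step", "loss"} rows the same way GradioLossCallback does."""
    def __init__(self):
        self.loss_data = []   # one dict per logged training step
        self.log_text = ""    # plain-text log for a Textbox-style output

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Not every log dict carries "loss" (e.g. eval or final summary logs),
        # so guard before recording.
        if logs and "loss" in logs:
            self.loss_data.append({"step": state.global_step,
                                   "loss": float(logs["loss"])})
            self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
        return control

# Simulate the trainer calling on_log a few times.
recorder = LossRecorder()
for step, loss in [(10, 1.92), (20, 1.41), (30, 1.07)]:
    recorder.on_log(None, SimpleNamespace(global_step=step), None, logs={"loss": loss})

df = pd.DataFrame(recorder.loss_data)  # columns "step" and "loss" match the LinePlot axes
print(recorder.log_text)
print(df)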