feat(train_page): 添加训练 Loss 曲线显示功能

- 在训练页面添加了 Loss 曲线图表
- 实现了 GradioLossCallback 类用于记录训练过程中的 Loss 数据
- 修改了训练函数,通过回调函数收集 Loss 信息并更新图表
- 优化了训练函数的返回值结构,支持同时返回文本日志和 Loss 数据
This commit is contained in:
carry 2025-04-13 21:49:43 +08:00
parent 4558929c52
commit bb1d8fbd38

View File

@ -1,9 +1,12 @@
import unsloth
import gradio as gr
import sys
import torch
import pandas as pd
from tinydb import Query
from pathlib import Path
from datasets import Dataset as HFDataset
from transformers import TrainerCallback
from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
@ -43,6 +46,16 @@ def train_page():
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
# 添加loss曲线展示
loss_plot = gr.LinePlot(
x="step",
y="loss",
title="训练Loss曲线",
interactive=True,
width=600,
height=300
)
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
@ -105,6 +118,33 @@ def train_page():
fn_kwargs={"tokenizer": tokenizer},
batched=True)
# 创建回调类
class GradioLossCallback(TrainerCallback):
def __init__(self):
self.loss_data = []
self.log_text = ""
self.last_output = {"text": "", "plot": None}
def on_log(self, args, state, control, logs=None, **kwargs):
print(f"on_log called with logs: {logs}") # 调试输出
if "loss" in logs:
print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出
self.loss_data.append({
"step": state.global_step,
"loss": float(logs["loss"]) # 确保转换为float
})
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
df = pd.DataFrame(self.loss_data)
print(f"DataFrame created: {df}") # 调试输出
self.last_output = {
"text": self.log_text,
"plot": df
}
return control
# 初始化回调
callback = GradioLossCallback()
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
@ -134,6 +174,7 @@ def train_page():
# report_to="tensorboard", # 将信息输出到tensorboard
),
)
trainer.add_callback(callback)
trainer = train_on_responses_only(
trainer,
@ -143,17 +184,28 @@ def train_page():
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)
return callback.last_output
def wrapped_train_model(*args):
print("Starting training...") # 调试输出
result = train_model(*args)
print(f"Training completed with result: {result}") # 调试输出
# 确保返回格式正确
if result and "text" in result and "plot" in result:
return result["text"], result["plot"]
return "", pd.DataFrame() # 返回默认值
train_button.click(
fn=train_model,
fn=wrapped_train_model,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input
save_steps_input,
lora_rank_input
],
outputs=output
outputs=[output, loss_plot]
)
return demo