feat(train_page): 添加训练 Loss 曲线显示功能
- 在训练页面添加了 Loss 曲线图表
- 实现了 GradioLossCallback 类用于记录训练过程中的 Loss 数据
- 修改了训练函数,通过回调函数收集 Loss 信息并更新图表
- 优化了训练函数的返回值结构,支持同时返回文本日志和 Loss 数据
This commit is contained in:
parent
4558929c52
commit
bb1d8fbd38
@ -1,9 +1,12 @@
|
||||
import unsloth
|
||||
import gradio as gr
|
||||
import sys
|
||||
import torch
|
||||
import pandas as pd
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from datasets import Dataset as HFDataset
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from unsloth import FastLanguageModel
|
||||
from trl import SFTTrainer # 用于监督微调的训练器
|
||||
@ -43,6 +46,16 @@ def train_page():
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
# 添加loss曲线展示
|
||||
loss_plot = gr.LinePlot(
|
||||
x="step",
|
||||
y="loss",
|
||||
title="训练Loss曲线",
|
||||
interactive=True,
|
||||
width=600,
|
||||
height=300
|
||||
)
|
||||
|
||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
@ -105,6 +118,33 @@ def train_page():
|
||||
fn_kwargs={"tokenizer": tokenizer},
|
||||
batched=True)
|
||||
|
||||
# 创建回调类
|
||||
class GradioLossCallback(TrainerCallback):
    """Trainer callback that records per-step training loss so the Gradio
    UI can show a running text log and a loss-curve DataFrame.

    Fixes vs. the original: ``on_log`` previously did ``"loss" in logs``
    without checking for ``None`` — the Trainer can invoke ``on_log`` with
    ``logs=None`` (the parameter's own default), which raised ``TypeError``.
    Debug ``print`` statements were also removed.
    """

    def __init__(self):
        # One record per logged step: {"step": int, "loss": float}.
        self.loss_data = []
        # Accumulated human-readable log shown in the training Textbox.
        self.log_text = ""
        # Latest snapshot read by the caller once training finishes.
        self.last_output = {"text": "", "plot": None}

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Capture the loss each time the trainer emits a log entry.

        ``logs`` may be None or may lack a "loss" key (e.g. the final
        summary log), so guard before subscripting.
        """
        if logs and "loss" in logs:
            self.loss_data.append({
                "step": state.global_step,
                "loss": float(logs["loss"]),  # ensure a plain float for plotting
            })
            self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
            self.last_output = {
                "text": self.log_text,
                "plot": pd.DataFrame(self.loss_data),
            }
        return control
|
||||
|
||||
# 初始化回调
|
||||
callback = GradioLossCallback()
|
||||
|
||||
# 初始化SFT训练器
|
||||
trainer = SFTTrainer(
|
||||
model=model, # 待训练的模型
|
||||
@ -134,6 +174,7 @@ def train_page():
|
||||
# report_to="tensorboard", # 将信息输出到tensorboard
|
||||
),
|
||||
)
|
||||
trainer.add_callback(callback)
|
||||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
@ -143,17 +184,28 @@ def train_page():
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||
return callback.last_output
|
||||
|
||||
def wrapped_train_model(*args):
    """Run ``train_model`` and adapt its dict result to the two Gradio
    outputs ``(log_text, loss_dataframe)``, falling back to safe defaults
    when the result is missing or malformed.
    """
    print("Starting training...")  # 调试输出
    outcome = train_model(*args)
    print(f"Training completed with result: {outcome}")  # 调试输出
    # Only unpack when the result is truthy and carries both expected keys.
    if not outcome or "text" not in outcome or "plot" not in outcome:
        return "", pd.DataFrame()
    return outcome["text"], outcome["plot"]
|
||||
|
||||
train_button.click(
|
||||
fn=train_model,
|
||||
fn=wrapped_train_model,
|
||||
inputs=[
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
epoch_input,
|
||||
save_steps_input
|
||||
save_steps_input,
|
||||
lora_rank_input
|
||||
],
|
||||
outputs=output
|
||||
outputs=[output, loss_plot]
|
||||
)
|
||||
|
||||
return demo
|
||||
|
Loading…
x
Reference in New Issue
Block a user