feat(train_page): 添加训练 Loss 曲线显示功能
- 在训练页面添加了 Loss 曲线图表 - 实现了 GradioLossCallback 类用于记录训练过程中的 Loss 数据 - 修改了训练函数,通过回调函数收集 Loss 信息并更新图表 - 优化了训练函数的返回值结构,支持同时返回文本日志和 Loss 数据
This commit is contained in:
parent
4558929c52
commit
bb1d8fbd38
@ -1,9 +1,12 @@
|
|||||||
|
import unsloth
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
import pandas as pd
|
||||||
from tinydb import Query
|
from tinydb import Query
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datasets import Dataset as HFDataset
|
from datasets import Dataset as HFDataset
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
from unsloth import FastLanguageModel
|
from unsloth import FastLanguageModel
|
||||||
from trl import SFTTrainer # 用于监督微调的训练器
|
from trl import SFTTrainer # 用于监督微调的训练器
|
||||||
@ -43,6 +46,16 @@ def train_page():
|
|||||||
# 训练状态输出
|
# 训练状态输出
|
||||||
output = gr.Textbox(label="训练日志", interactive=False)
|
output = gr.Textbox(label="训练日志", interactive=False)
|
||||||
|
|
||||||
|
# 添加loss曲线展示
|
||||||
|
loss_plot = gr.LinePlot(
|
||||||
|
x="step",
|
||||||
|
y="loss",
|
||||||
|
title="训练Loss曲线",
|
||||||
|
interactive=True,
|
||||||
|
width=600,
|
||||||
|
height=300
|
||||||
|
)
|
||||||
|
|
||||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||||
# 使用动态传入的超参数
|
# 使用动态传入的超参数
|
||||||
learning_rate = float(learning_rate)
|
learning_rate = float(learning_rate)
|
||||||
@ -105,6 +118,33 @@ def train_page():
|
|||||||
fn_kwargs={"tokenizer": tokenizer},
|
fn_kwargs={"tokenizer": tokenizer},
|
||||||
batched=True)
|
batched=True)
|
||||||
|
|
||||||
|
# 创建回调类
|
||||||
|
class GradioLossCallback(TrainerCallback):
|
||||||
|
def __init__(self):
|
||||||
|
self.loss_data = []
|
||||||
|
self.log_text = ""
|
||||||
|
self.last_output = {"text": "", "plot": None}
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
print(f"on_log called with logs: {logs}") # 调试输出
|
||||||
|
if "loss" in logs:
|
||||||
|
print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出
|
||||||
|
self.loss_data.append({
|
||||||
|
"step": state.global_step,
|
||||||
|
"loss": float(logs["loss"]) # 确保转换为float
|
||||||
|
})
|
||||||
|
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
|
||||||
|
df = pd.DataFrame(self.loss_data)
|
||||||
|
print(f"DataFrame created: {df}") # 调试输出
|
||||||
|
self.last_output = {
|
||||||
|
"text": self.log_text,
|
||||||
|
"plot": df
|
||||||
|
}
|
||||||
|
return control
|
||||||
|
|
||||||
|
# 初始化回调
|
||||||
|
callback = GradioLossCallback()
|
||||||
|
|
||||||
# 初始化SFT训练器
|
# 初始化SFT训练器
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
model=model, # 待训练的模型
|
model=model, # 待训练的模型
|
||||||
@ -134,6 +174,7 @@ def train_page():
|
|||||||
# report_to="tensorboard", # 将信息输出到tensorboard
|
# report_to="tensorboard", # 将信息输出到tensorboard
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
trainer = train_on_responses_only(
|
trainer = train_on_responses_only(
|
||||||
trainer,
|
trainer,
|
||||||
@ -143,17 +184,28 @@ def train_page():
|
|||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||||
|
return callback.last_output
|
||||||
|
|
||||||
|
def wrapped_train_model(*args):
|
||||||
|
print("Starting training...") # 调试输出
|
||||||
|
result = train_model(*args)
|
||||||
|
print(f"Training completed with result: {result}") # 调试输出
|
||||||
|
# 确保返回格式正确
|
||||||
|
if result and "text" in result and "plot" in result:
|
||||||
|
return result["text"], result["plot"]
|
||||||
|
return "", pd.DataFrame() # 返回默认值
|
||||||
|
|
||||||
train_button.click(
|
train_button.click(
|
||||||
fn=train_model,
|
fn=wrapped_train_model,
|
||||||
inputs=[
|
inputs=[
|
||||||
dataset_dropdown,
|
dataset_dropdown,
|
||||||
learning_rate_input,
|
learning_rate_input,
|
||||||
per_device_train_batch_size_input,
|
per_device_train_batch_size_input,
|
||||||
epoch_input,
|
epoch_input,
|
||||||
save_steps_input
|
save_steps_input,
|
||||||
|
lora_rank_input
|
||||||
],
|
],
|
||||||
outputs=output
|
outputs=[output, loss_plot]
|
||||||
)
|
)
|
||||||
|
|
||||||
return demo
|
return demo
|
||||||
|
Loading…
x
Reference in New Issue
Block a user