diff --git a/frontend/train_page.py b/frontend/train_page.py index 6aa3d78..17e70b8 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -1,9 +1,12 @@ +import unsloth import gradio as gr import sys import torch +import pandas as pd from tinydb import Query from pathlib import Path -from datasets import Dataset as HFDataset +from datasets import Dataset as HFDataset +from transformers import TrainerCallback from unsloth import FastLanguageModel from trl import SFTTrainer # 用于监督微调的训练器 @@ -43,6 +46,16 @@ def train_page(): # 训练状态输出 output = gr.Textbox(label="训练日志", interactive=False) + # 添加loss曲线展示 + loss_plot = gr.LinePlot( + x="step", + y="loss", + title="训练Loss曲线", + interactive=True, + width=600, + height=300 + ) + def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): # 使用动态传入的超参数 learning_rate = float(learning_rate) @@ -105,6 +118,33 @@ def train_page(): fn_kwargs={"tokenizer": tokenizer}, batched=True) + # 创建回调类 + class GradioLossCallback(TrainerCallback): + def __init__(self): + self.loss_data = [] + self.log_text = "" + self.last_output = {"text": "", "plot": None} + + def on_log(self, args, state, control, logs=None, **kwargs): + print(f"on_log called with logs: {logs}") # 调试输出 + if logs and "loss" in logs: + print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出 + self.loss_data.append({ + "step": state.global_step, + "loss": float(logs["loss"]) # 确保转换为float + }) + self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n" + df = pd.DataFrame(self.loss_data) + print(f"DataFrame created: {df}") # 调试输出 + self.last_output = { + "text": self.log_text, + "plot": df + } + return control + + # 初始化回调 + callback = GradioLossCallback() + # 初始化SFT训练器 trainer = SFTTrainer( model=model, # 待训练的模型 @@ -134,6 +174,7 @@ def train_page(): # report_to="tensorboard", # 将信息输出到tensorboard ), ) + trainer.add_callback(callback) trainer = train_on_responses_only( trainer, @@ -143,17 +184,28 @@ 
def train_page(): # 开始训练 trainer_stats = trainer.train(resume_from_checkpoint=False) + return callback.last_output + def wrapped_train_model(*args): + print("Starting training...") # 调试输出 + result = train_model(*args) + print(f"Training completed with result: {result}") # 调试输出 + # 确保返回格式正确 + if result and "text" in result and "plot" in result: + return result["text"], result["plot"] + return "", pd.DataFrame(columns=["step", "loss"]) # 返回默认值 + train_button.click( - fn=train_model, + fn=wrapped_train_model, inputs=[ dataset_dropdown, learning_rate_input, per_device_train_batch_size_input, epoch_input, - save_steps_input + save_steps_input, + lora_rank_input ], - outputs=output + outputs=[output, loss_plot] ) return demo