Compare commits

bb1d8fbd3847b0635a6567b56162bf983a6db4ba..0722748997aaddde8098e8c3366605e736f0e01d

No commits in common. "bb1d8fbd3847b0635a6567b56162bf983a6db4ba" and "0722748997aaddde8098e8c3366605e736f0e01d" have entirely different histories.

2 changed files with 5 additions and 57 deletions

Changed file 1 (page-module __init__.py):

@@ -1,6 +1,6 @@
-from .train_page import *
 from .chat_page import *
 from .setting_page import *
+from .train_page import *
 from .model_manage_page import *
 from .dataset_manage_page import *
 from .dataset_generate_page import *
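For context on this first change: with star imports, an __init__.py executes the listed modules top to bottom, and a later import * rebinds any colliding names, so listing order can matter when modules share names or have import-time side effects. A minimal, self-contained sketch (the "pages" package here is hypothetical, written to a temp directory):

import sys
import tempfile
from pathlib import Path

pkg = Path(tempfile.mkdtemp()) / "pages"
pkg.mkdir()
(pkg / "a.py").write_text("tag = 'a'\n")
(pkg / "b.py").write_text("tag = 'b'\n")
# Listing order in __init__.py decides execution order and which 'tag' survives.
(pkg / "__init__.py").write_text("from .a import *\nfrom .b import *\n")

sys.path.insert(0, str(pkg.parent))
import pages

print(pages.tag)  # -> 'b': the later star import rebound the name

Moving from .train_page import * below the chat and setting pages therefore only changes behavior if those modules collide or rely on import order; here it appears to pair with the second file importing unsloth at module load.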

Changed file 2 (the train page module):

@@ -1,12 +1,9 @@
+import unsloth
 import gradio as gr
 import sys
 import torch
-import pandas as pd
 from tinydb import Query
 from pathlib import Path
 from datasets import Dataset as HFDataset
-from transformers import TrainerCallback
 from unsloth import FastLanguageModel
 from trl import SFTTrainer  # trainer for supervised fine-tuning
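The one addition in this block, import unsloth as the very first line, matches unsloth's documented convention: the library patches transformers and TRL at import time, so it should be imported before them (it warns when it is not). A sketch of the convention, with a hypothetical model choice:

import unsloth  # first, so its patches apply before transformers/TRL load
from unsloth import FastLanguageModel
from trl import SFTTrainer

# Hypothetical model; the repo's actual choice is outside the hunks shown.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct",
    max_seq_length=2048,
    load_in_4bit=True,  # 4-bit quantized loading to cut VRAM use
)

The two deletions, pandas and TrainerCallback, track the removal of the loss-plot callback in the hunks below.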
@@ -46,16 +43,6 @@ def train_page():
     # training status output
     output = gr.Textbox(label="Training Log", interactive=False)
-    # add a loss-curve display
-    loss_plot = gr.LinePlot(
-        x="step",
-        y="loss",
-        title="Training Loss Curve",
-        interactive=True,
-        width=600,
-        height=300
-    )
     def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
         # use the dynamically passed-in hyperparameters
         learning_rate = float(learning_rate)
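The deleted block above is the loss-curve component. For reference, the usual gr.LinePlot pattern it followed: the component is bound to DataFrame columns, and an event handler updates it by returning a fresh DataFrame (a sketch with made-up data):

import gradio as gr
import pandas as pd

with gr.Blocks() as demo:
    plot = gr.LinePlot(x="step", y="loss", title="Training Loss Curve",
                       width=600, height=300)
    refresh = gr.Button("Refresh")

    def fake_losses():
        # Stand-in for real trainer logs.
        return pd.DataFrame({"step": [10, 20, 30], "loss": [2.1, 1.4, 1.1]})

    refresh.click(fn=fake_losses, outputs=plot)

demo.launch()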
@@ -118,33 +105,6 @@ def train_page():
             fn_kwargs={"tokenizer": tokenizer},
             batched=True)
-        # create the callback class
-        class GradioLossCallback(TrainerCallback):
-            def __init__(self):
-                self.loss_data = []
-                self.log_text = ""
-                self.last_output = {"text": "", "plot": None}
-            def on_log(self, args, state, control, logs=None, **kwargs):
-                print(f"on_log called with logs: {logs}")  # debug output
-                if "loss" in logs:
-                    print(f"Recording loss: {logs['loss']} at step {state.global_step}")  # debug output
-                    self.loss_data.append({
-                        "step": state.global_step,
-                        "loss": float(logs["loss"])  # make sure the value is a float
-                    })
-                    self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
-                    df = pd.DataFrame(self.loss_data)
-                    print(f"DataFrame created: {df}")  # debug output
-                    self.last_output = {
-                        "text": self.log_text,
-                        "plot": df
-                    }
-                return control
-        # initialize the callback
-        callback = GradioLossCallback()
         # initialize the SFT trainer
         trainer = SFTTrainer(
             model=model,  # the model to be trained
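The deleted GradioLossCallback builds on the transformers TrainerCallback hook: on_log fires each time the Trainer logs (every logging_steps), with the scalar metrics for that step in logs. Stripped of the debug prints and Gradio plumbing, the core of such a callback is a short sketch (including a logs-is-None guard the original lacked):

from transformers import TrainerCallback

class LossRecorder(TrainerCallback):
    # Collects {"step", "loss"} records as training logs arrive.
    def __init__(self):
        self.history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:  # guard against logs=None
            self.history.append({"step": state.global_step,
                                 "loss": float(logs["loss"])})
        return control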
@@ -174,7 +134,6 @@ def train_page():
                 # report_to="tensorboard",  # send metrics to TensorBoard
             ),
         )
-        trainer.add_callback(callback)
         trainer = train_on_responses_only(
             trainer,
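The train_on_responses_only call that survives (its remaining arguments fall outside this hunk) is unsloth's helper for masking prompt tokens out of the loss, so gradients come only from assistant responses. A sketch using the Llama-3 chat-template markers from unsloth's docs; the repo's actual marker strings are not shown in this diff:

from unsloth.chat_templates import train_on_responses_only

trainer = train_on_responses_only(
    trainer,
    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# Label ids for everything before each response are set to -100, so the
# cross-entropy loss ignores the prompt tokens.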
@@ -184,28 +143,17 @@ def train_page():
         # start training
         trainer_stats = trainer.train(resume_from_checkpoint=False)
-        return callback.last_output
-    def wrapped_train_model(*args):
-        print("Starting training...")  # debug output
-        result = train_model(*args)
-        print(f"Training completed with result: {result}")  # debug output
-        # make sure the return format is correct
-        if result and "text" in result and "plot" in result:
-            return result["text"], result["plot"]
-        return "", pd.DataFrame()  # fall back to default values
     train_button.click(
-        fn=wrapped_train_model,
+        fn=train_model,
         inputs=[
            dataset_dropdown,
            learning_rate_input,
            per_device_train_batch_size_input,
            epoch_input,
-           save_steps_input,
-           lora_rank_input
+           save_steps_input
        ],
-        outputs=[output, loss_plot]
+        outputs=output
     )
     return demo
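Two observations on this final hunk. First, the new wiring passes five inputs to a train_model that, per the unchanged def line above, still declares six parameters (including lora_rank); unless the signature changes somewhere outside the hunks shown, the click handler would raise a TypeError at call time. Second, the removed design could not actually stream: wrapped_train_model returned once, after trainer.train() finished, so the LinePlot would only ever update at the very end. The usual Gradio fix is a generator handler that trains in a worker thread and yields as the callback reports losses. A sketch, not from the repo:

import queue
import threading

import pandas as pd
from transformers import TrainerCallback

class QueueLossCallback(TrainerCallback):
    # Pushes (step, loss) onto a queue the UI thread can drain.
    def __init__(self, q):
        self.q = q

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            self.q.put((state.global_step, float(logs["loss"])))
        return control

def train_and_stream(trainer):
    # `trainer` is an already-configured SFTTrainer (hypothetical wiring).
    q = queue.Queue()
    trainer.add_callback(QueueLossCallback(q))
    worker = threading.Thread(target=trainer.train, daemon=True)
    worker.start()

    rows, text = [], ""
    while worker.is_alive() or not q.empty():
        try:
            step, loss = q.get(timeout=1.0)
        except queue.Empty:
            continue
        rows.append({"step": step, "loss": loss})
        text += f"Step {step}: loss={loss:.4f}\n"
        yield text, pd.DataFrame(rows)  # Gradio re-renders on every yield

Wired as train_button.click(fn=train_and_stream, ..., outputs=[output, loss_plot]), each yield would refresh both the log textbox and the plot while training runs.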