Compare commits
2 Commits
0722748997
...
bb1d8fbd38
Author | SHA1 | Date | |
---|---|---|---|
![]() |
bb1d8fbd38 | ||
![]() |
4558929c52 |
@ -1,6 +1,6 @@
|
||||
from .train_page import *
|
||||
from .chat_page import *
|
||||
from .setting_page import *
|
||||
from .train_page import *
|
||||
from .model_manage_page import *
|
||||
from .dataset_manage_page import *
|
||||
from .dataset_generate_page import *
|
||||
|
@ -1,9 +1,12 @@
|
||||
import unsloth
|
||||
import gradio as gr
|
||||
import sys
|
||||
import torch
|
||||
import pandas as pd
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import Dataset as HFDataset
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from unsloth import FastLanguageModel
|
||||
from trl import SFTTrainer # 用于监督微调的训练器
|
||||
@ -43,6 +46,16 @@ def train_page():
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
# 添加loss曲线展示
|
||||
loss_plot = gr.LinePlot(
|
||||
x="step",
|
||||
y="loss",
|
||||
title="训练Loss曲线",
|
||||
interactive=True,
|
||||
width=600,
|
||||
height=300
|
||||
)
|
||||
|
||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
@ -105,6 +118,33 @@ def train_page():
|
||||
fn_kwargs={"tokenizer": tokenizer},
|
||||
batched=True)
|
||||
|
||||
# 创建回调类
|
||||
class GradioLossCallback(TrainerCallback):
|
||||
def __init__(self):
|
||||
self.loss_data = []
|
||||
self.log_text = ""
|
||||
self.last_output = {"text": "", "plot": None}
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
print(f"on_log called with logs: {logs}") # 调试输出
|
||||
if "loss" in logs:
|
||||
print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出
|
||||
self.loss_data.append({
|
||||
"step": state.global_step,
|
||||
"loss": float(logs["loss"]) # 确保转换为float
|
||||
})
|
||||
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
|
||||
df = pd.DataFrame(self.loss_data)
|
||||
print(f"DataFrame created: {df}") # 调试输出
|
||||
self.last_output = {
|
||||
"text": self.log_text,
|
||||
"plot": df
|
||||
}
|
||||
return control
|
||||
|
||||
# 初始化回调
|
||||
callback = GradioLossCallback()
|
||||
|
||||
# 初始化SFT训练器
|
||||
trainer = SFTTrainer(
|
||||
model=model, # 待训练的模型
|
||||
@ -134,6 +174,7 @@ def train_page():
|
||||
# report_to="tensorboard", # 将信息输出到tensorboard
|
||||
),
|
||||
)
|
||||
trainer.add_callback(callback)
|
||||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
@ -143,17 +184,28 @@ def train_page():
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||
return callback.last_output
|
||||
|
||||
def wrapped_train_model(*args):
|
||||
print("Starting training...") # 调试输出
|
||||
result = train_model(*args)
|
||||
print(f"Training completed with result: {result}") # 调试输出
|
||||
# 确保返回格式正确
|
||||
if result and "text" in result and "plot" in result:
|
||||
return result["text"], result["plot"]
|
||||
return "", pd.DataFrame() # 返回默认值
|
||||
|
||||
train_button.click(
|
||||
fn=train_model,
|
||||
fn=wrapped_train_model,
|
||||
inputs=[
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
epoch_input,
|
||||
save_steps_input
|
||||
save_steps_input,
|
||||
lora_rank_input
|
||||
],
|
||||
outputs=output
|
||||
outputs=[output, loss_plot]
|
||||
)
|
||||
|
||||
return demo
|
||||
|
Loading…
x
Reference in New Issue
Block a user