Compare commits
No commits in common. "bb1d8fbd3847b0635a6567b56162bf983a6db4ba" and "0722748997aaddde8098e8c3366605e736f0e01d" have entirely different histories.
bb1d8fbd38
...
0722748997
@@ -1,6 +1,6 @@
|
|||||||
from .train_page import *
|
|
||||||
from .chat_page import *
|
from .chat_page import *
|
||||||
from .setting_page import *
|
from .setting_page import *
|
||||||
|
from .train_page import *
|
||||||
from .model_manage_page import *
|
from .model_manage_page import *
|
||||||
from .dataset_manage_page import *
|
from .dataset_manage_page import *
|
||||||
from .dataset_generate_page import *
|
from .dataset_generate_page import *
|
||||||
|
@@ -1,12 +1,9 @@
|
|||||||
import unsloth
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
import pandas as pd
|
|
||||||
from tinydb import Query
|
from tinydb import Query
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datasets import Dataset as HFDataset
|
from datasets import Dataset as HFDataset
|
||||||
from transformers import TrainerCallback
|
|
||||||
|
|
||||||
from unsloth import FastLanguageModel
|
from unsloth import FastLanguageModel
|
||||||
from trl import SFTTrainer # 用于监督微调的训练器
|
from trl import SFTTrainer # 用于监督微调的训练器
|
||||||
@@ -46,16 +43,6 @@ def train_page():
|
|||||||
# 训练状态输出
|
# 训练状态输出
|
||||||
output = gr.Textbox(label="训练日志", interactive=False)
|
output = gr.Textbox(label="训练日志", interactive=False)
|
||||||
|
|
||||||
# 添加loss曲线展示
|
|
||||||
loss_plot = gr.LinePlot(
|
|
||||||
x="step",
|
|
||||||
y="loss",
|
|
||||||
title="训练Loss曲线",
|
|
||||||
interactive=True,
|
|
||||||
width=600,
|
|
||||||
height=300
|
|
||||||
)
|
|
||||||
|
|
||||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||||
# 使用动态传入的超参数
|
# 使用动态传入的超参数
|
||||||
learning_rate = float(learning_rate)
|
learning_rate = float(learning_rate)
|
||||||
@@ -118,33 +105,6 @@ def train_page():
|
|||||||
fn_kwargs={"tokenizer": tokenizer},
|
fn_kwargs={"tokenizer": tokenizer},
|
||||||
batched=True)
|
batched=True)
|
||||||
|
|
||||||
# 创建回调类
|
|
||||||
class GradioLossCallback(TrainerCallback):
|
|
||||||
def __init__(self):
|
|
||||||
self.loss_data = []
|
|
||||||
self.log_text = ""
|
|
||||||
self.last_output = {"text": "", "plot": None}
|
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
|
||||||
print(f"on_log called with logs: {logs}") # 调试输出
|
|
||||||
if "loss" in logs:
|
|
||||||
print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出
|
|
||||||
self.loss_data.append({
|
|
||||||
"step": state.global_step,
|
|
||||||
"loss": float(logs["loss"]) # 确保转换为float
|
|
||||||
})
|
|
||||||
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
|
|
||||||
df = pd.DataFrame(self.loss_data)
|
|
||||||
print(f"DataFrame created: {df}") # 调试输出
|
|
||||||
self.last_output = {
|
|
||||||
"text": self.log_text,
|
|
||||||
"plot": df
|
|
||||||
}
|
|
||||||
return control
|
|
||||||
|
|
||||||
# 初始化回调
|
|
||||||
callback = GradioLossCallback()
|
|
||||||
|
|
||||||
# 初始化SFT训练器
|
# 初始化SFT训练器
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
model=model, # 待训练的模型
|
model=model, # 待训练的模型
|
||||||
@@ -174,7 +134,6 @@ def train_page():
|
|||||||
# report_to="tensorboard", # 将信息输出到tensorboard
|
# report_to="tensorboard", # 将信息输出到tensorboard
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
trainer.add_callback(callback)
|
|
||||||
|
|
||||||
trainer = train_on_responses_only(
|
trainer = train_on_responses_only(
|
||||||
trainer,
|
trainer,
|
||||||
@@ -184,28 +143,17 @@ def train_page():
|
|||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||||
return callback.last_output
|
|
||||||
|
|
||||||
def wrapped_train_model(*args):
|
|
||||||
print("Starting training...") # 调试输出
|
|
||||||
result = train_model(*args)
|
|
||||||
print(f"Training completed with result: {result}") # 调试输出
|
|
||||||
# 确保返回格式正确
|
|
||||||
if result and "text" in result and "plot" in result:
|
|
||||||
return result["text"], result["plot"]
|
|
||||||
return "", pd.DataFrame() # 返回默认值
|
|
||||||
|
|
||||||
train_button.click(
|
train_button.click(
|
||||||
fn=wrapped_train_model,
|
fn=train_model,
|
||||||
inputs=[
|
inputs=[
|
||||||
dataset_dropdown,
|
dataset_dropdown,
|
||||||
learning_rate_input,
|
learning_rate_input,
|
||||||
per_device_train_batch_size_input,
|
per_device_train_batch_size_input,
|
||||||
epoch_input,
|
epoch_input,
|
||||||
save_steps_input,
|
save_steps_input
|
||||||
lora_rank_input
|
|
||||||
],
|
],
|
||||||
outputs=[output, loss_plot]
|
outputs=output
|
||||||
)
|
)
|
||||||
|
|
||||||
return demo
|
return demo
|
||||||
|
Loading…
x
Reference in New Issue
Block a user