
- 新增 LossCallback 类,用于在训练过程中记录 loss 数据 - 在训练模型函数中添加回调,实现日志记录和 loss 可视化 - 优化训练过程中的输出信息,增加当前步数和 loss 值的打印
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
import gradio as gr
|
||
import sys
|
||
from tinydb import Query
|
||
from pathlib import Path
|
||
from transformers import TrainerCallback
|
||
|
||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||
from tools import train_model
|
||
|
||
def train_page():
    """Build the fine-tuning page of the Gradio UI.

    The page offers a dataset selector, numeric inputs for the training
    hyperparameters, a start button, and a read-only textbox that receives
    the training log once the job finishes.

    Returns:
        The ``gr.Blocks`` object holding the page.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 微调")

        # Collect the dataset names and preselect the first one, if any.
        datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
        initial_dataset = datasets_list[0] if datasets_list else None

        dataset_dropdown = gr.Dropdown(
            choices=datasets_list,
            value=initial_dataset,  # initially selected item
            label="选择数据集",
            allow_custom_value=True,
            interactive=True,
        )

        # Hyperparameter inputs.
        learning_rate_input = gr.Number(value=2e-4, label="学习率")
        per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
        epoch_input = gr.Number(value=1, label="epoch", precision=0)
        save_steps_input = gr.Number(value=20, label="保存步数", precision=0)
        lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)

        train_button = gr.Button("开始微调")

        # Training status output.
        output = gr.Textbox(label="训练日志", interactive=False)

        def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            """Run one fine-tuning job and return its loss log.

            Args:
                dataset_name: Name of the dataset record to look up.
                learning_rate: Learning rate; coerced to ``float``.
                per_device_train_batch_size: Batch size; coerced to ``int``.
                epoch: Number of epochs; coerced to ``int``.
                save_steps: Checkpoint interval in steps; coerced to ``int``.
                lora_rank: LoRA rank; coerced to ``int``.

            Returns:
                The "Step N: loss=..." lines collected during training, as a
                single string for the log textbox (empty if no loss was
                logged).

            Raises:
                gr.Error: If no dataset named ``dataset_name`` exists.
            """
            # Normalise the values coming from the UI widgets.
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)
            lora_rank = int(lora_rank)

            # Load the dataset. The dropdown accepts free-form values, so the
            # lookup can come back empty; report that instead of crashing on
            # a None subscript.
            record = get_datasets().get(Query().name == dataset_name)
            if record is None:
                raise gr.Error(f"数据集不存在: {dataset_name}")
            dataset = [ds["message"][0] for ds in record["dataset_items"]]

            # The callback is handed over as a class, so this function never
            # holds the instance. A closure-captured list lets the log lines
            # flow back here regardless of who instantiates the callback.
            log_lines = []

            class LossCallback(TrainerCallback):
                """Trainer callback that records the loss reported at each log event."""

                def __init__(self):
                    super().__init__()
                    self.loss_data = []  # one {"step", "loss"} dict per logged loss
                    self.log_text = ""  # accumulated human-readable log
                    self.last_output = {"text": "", "plot": None}

                def on_log(self, args, state, control, logs=None, **kwargs):
                    # `logs` defaults to None, and not every log event
                    # carries a loss value; skip those.
                    if not logs or "loss" not in logs:
                        return
                    step = state.global_step
                    loss = float(logs["loss"])
                    self.loss_data.append({"step": step, "loss": loss})
                    line = f"Step {step}: loss={loss:.4f}\n"
                    self.log_text += line
                    log_lines.append(line)
                    print(f"Current Step: {step}")
                    print(f"Loss Value: {loss:.4f}")
                    # Update in place so the "plot" key set in __init__ is kept.
                    self.last_output["text"] = self.log_text
                    # Do not return `control`, to avoid interfering with training.

            # NOTE(review): LossCallback is passed as a class, not an instance;
            # presumably train_model or the underlying Trainer instantiates
            # it - confirm against tools.train_model.
            train_model(get_model(), get_tokenizer(),
                        dataset, get_workdir()+"/checkpoint",
                        learning_rate, per_device_train_batch_size, epoch,
                        save_steps, lora_rank, LossCallback)

            # Feed the collected log into the output textbox.
            return "".join(log_lines)

        train_button.click(
            fn=start_training,
            inputs=[
                dataset_dropdown,
                learning_rate_input,
                per_device_train_batch_size_input,
                epoch_input,
                save_steps_input,
                lora_rank_input,
            ],
            outputs=output,
        )

    return demo
|
||
|
||
if __name__ == "__main__":
    # Standalone entry point: initialise the shared globals with the
    # "workdir" argument, then serve the model-management page and the
    # fine-tuning page as two tabs.
    from global_var import init_global_var
    from model_manage_page import model_manage_page

    init_global_var("workdir")
    # Tab titles are positional; the second tab hosts train_page(), whose
    # heading is "微调", so its title uses the same wording.
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"])
    demo.queue()
    demo.launch()