gzhu-biyesheji/frontend/train_page.py
carry 4f7926aec6 feat(train_page): 实现训练目录自动递增功能
- 在 training 文件夹下创建递增的目录结构
- 确保 training 文件夹存在
- 扫描现有目录,生成下一个可用的目录编号
- 更新训练模型函数,使用新的训练目录
2025-04-14 16:46:29 +08:00

84 lines
3.4 KiB
Python

import os
import gradio as gr
import sys
from tinydb import Query
from pathlib import Path
from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
lora_rank = int(lora_rank) # 新增LoRA秩参数
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
# 扫描 training 文件夹并生成递增目录
training_dir = get_workdir() + "/training"
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
new_training_dir = os.path.join(training_dir, str(next_dir_number))
train_model(get_model(), get_tokenizer(),
dataset, new_training_dir,
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank)
train_button.click(
fn=start_training,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input,
lora_rank_input
],
outputs=output
)
return demo
if __name__ == "__main__":
    # Standalone launcher for trying this page next to the model-management page.
    from global_var import init_global_var
    from model_manage_page import model_manage_page

    # Global state must be initialised before the pages are built:
    # train_page() calls get_datasets() at construction time.
    # NOTE(review): "workdir" appears to be the working-directory argument; confirm.
    init_global_var("workdir")
    # The second tab is the fine-tuning page, so its label matches that page's
    # "微调" heading.
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"])
    demo.queue()
    demo.launch()