import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
from tinydb import Query
from transformers import TrainerCallback

# Append this file's grandparent directory to sys.path before the project-local
# imports below (presumably where global_var, tools and train live — verify).
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import find_available_port
from train import train_model


def _next_training_dir(training_root):
    """Return the path of the next numbered run directory under *training_root*.

    *training_root* is created if it does not exist. Entries whose names are
    all digits are treated as previous runs; the result is numbered one past
    the largest of them (1 when there are none). The returned directory itself
    is not created here.
    """
    os.makedirs(training_root, exist_ok=True)
    run_numbers = [
        int(entry) for entry in os.listdir(training_root) if entry.isdigit()
    ]
    return os.path.join(training_root, str(max(run_numbers, default=0) + 1))


def _stop_process(process, timeout=5):
    """Terminate *process* and reap it, escalating to kill() on timeout."""
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def train_page():
    """Build the fine-tuning page.

    The page offers a dataset dropdown, hyperparameter inputs and a start
    button on the left, and an HTML pane for an embedded TensorBoard view on
    the right.

    Returns:
        The ``gr.Blocks`` instance containing the page.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 微调")
        # Collect dataset names and pre-select the first one, if any.
        datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
        initial_dataset = datasets_list[0] if datasets_list else None

        with gr.Row():
            with gr.Column(scale=1):
                dataset_dropdown = gr.Dropdown(
                    choices=datasets_list,
                    value=initial_dataset,  # initially selected item
                    label="选择数据集",
                    allow_custom_value=True,
                    interactive=True,
                )
                # Hyperparameter inputs.
                learning_rate_input = gr.Number(value=2e-4, label="学习率")
                per_device_train_batch_size_input = gr.Number(
                    value=1, label="batch size", precision=0
                )
                epoch_input = gr.Number(value=1, label="epoch", precision=0)
                save_steps_input = gr.Number(
                    value=20, label="保存步数", precision=0
                )
                lora_rank_input = gr.Number(
                    value=16, label="LoRA秩", precision=0
                )
                train_button = gr.Button("开始微调")

                # Training status output.
                output = gr.Textbox(label="训练状态", interactive=False)

            with gr.Column(scale=3):
                # Pane that receives the TensorBoard iframe.
                tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")

        def start_training(
            dataset_name,
            learning_rate,
            per_device_train_batch_size,
            epoch,
            save_steps,
            lora_rank,
        ):
            """Run one fine-tuning job while serving TensorBoard for it.

            This is a generator: it yields a ``(status_text, iframe_html)``
            pair once before training starts, then calls ``train_model``.

            Raises:
                gr.Error: if the dataset name is unknown or training fails.
            """
            # Coerce the values coming from the number inputs.
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)
            lora_rank = int(lora_rank)

            # Load the dataset. The dropdown accepts custom values, so the
            # lookup can come back empty.
            dataset_record = get_datasets().get(Query().name == dataset_name)
            if dataset_record is None:
                raise gr.Error(f"未找到数据集: {dataset_name}")
            dataset = [
                item["message"][0] for item in dataset_record["dataset_items"]
            ]

            # Pick the next numbered run directory under <workdir>/training.
            training_dir = os.path.join(get_workdir(), "training")
            new_training_dir = _next_training_dir(training_dir)

            # Probe for a free port, starting at 6006.
            tensorboard_port = find_available_port(6006)
            print(f"TensorBoard 将使用端口: {tensorboard_port}")

            tensorboard_logdir = os.path.join(new_training_dir, "logs")
            os.makedirs(tensorboard_logdir, exist_ok=True)
            # Output is discarded rather than piped: nothing reads the pipes,
            # and an unread pipe can block the child once its buffer fills.
            tensorboard_process = subprocess.Popen(
                [
                    "tensorboard",
                    "--logdir", tensorboard_logdir,
                    "--port", str(tensorboard_port),
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            print("TensorBoard 已启动,日志目录:", tensorboard_logdir)

            # Build the iframe pointing at the TensorBoard instance.
            tensorboard_url = f"http://localhost:{tensorboard_port}"
            tensorboard_iframe_value = (
                f'<iframe src="{tensorboard_url}" '
                f'width="100%" height="1000px"></iframe>'
            )

            try:
                # The yield lives inside the try so that closing the generator
                # at this point still reaches the cleanup in `finally`.
                # The two values map to the textbox and the HTML pane.
                yield "训练开始...", tensorboard_iframe_value
                train_model(
                    get_model(),
                    get_tokenizer(),
                    dataset,
                    new_training_dir,
                    learning_rate,
                    per_device_train_batch_size,
                    epoch,
                    save_steps,
                    lora_rank,
                )
            except Exception as e:
                # Surface the failure in the UI; keep the cause chained.
                raise gr.Error(str(e)) from e
            finally:
                # Always stop and reap the TensorBoard subprocess.
                _stop_process(tensorboard_process)
                print("TensorBoard 子进程已终止")

        train_button.click(
            fn=start_training,
            inputs=[
                dataset_dropdown,
                learning_rate_input,
                per_device_train_batch_size_input,
                epoch_input,
                save_steps_input,
                lora_rank_input,
            ],
            outputs=[output, tensorboard_iframe],
        )

    return demo


if __name__ == "__main__":
    from global_var import init_global_var
    from model_manage_page import model_manage_page

    init_global_var("workdir")
    # The second tab hosts the fine-tuning page, so it is labelled "微调"
    # (it was previously labelled "聊天").
    demo = gr.TabbedInterface(
        [model_manage_page(), train_page()], ["模型管理", "微调"]
    )
    demo.queue()
    demo.launch()