From 4f7926aec667af4d5a2aaa1b8768c1ef0b7c5751 Mon Sep 17 00:00:00 2001 From: carry Date: Mon, 14 Apr 2025 16:46:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(train=5Fpage):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=9B=AE=E5=BD=95=E8=87=AA=E5=8A=A8=E9=80=92?= =?UTF-8?q?=E5=A2=9E=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 training 文件夹下创建递增的目录结构 - 确保 training 文件夹存在 - 扫描现有目录,生成下一个可用的目录编号 - 更新训练模型函数,使用新的训练目录 --- frontend/train_page.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index 27d307f..56d0ea7 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -1,3 +1,4 @@ +import os import gradio as gr import sys from tinydb import Query @@ -36,21 +37,28 @@ def train_page(): output = gr.Textbox(label="训练日志", interactive=False) def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): - # 使用动态传入的超参数 + # 使用动态传入的超参数 learning_rate = float(learning_rate) per_device_train_batch_size = int(per_device_train_batch_size) epoch = int(epoch) save_steps = int(save_steps) # 新增保存步数参数 lora_rank = int(lora_rank) # 新增LoRA秩参数 + # 加载数据集 dataset = get_datasets().get(Query().name == dataset_name) dataset = [ds["message"][0] for ds in dataset["dataset_items"]] - train_model(get_model(), get_tokenizer(), - dataset, get_workdir() + "/training" + "/1", - learning_rate, per_device_train_batch_size, epoch, - save_steps, lora_rank) + # 扫描 training 文件夹并生成递增目录 + training_dir = get_workdir() + "/training" + os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在 + existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()] + next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1 + new_training_dir = os.path.join(training_dir, str(next_dir_number)) + train_model(get_model(), get_tokenizer(), + dataset, new_training_dir, + learning_rate, per_device_train_batch_size, epoch, + save_steps, lora_rank) train_button.click( fn=start_training,