feat(train_page): 实现训练目录自动递增功能

- 在 training 文件夹下创建递增的目录结构
- 确保 training 文件夹存在
- 扫描现有目录,生成下一个可用的目录编号
- 更新训练模型函数,使用新的训练目录
This commit is contained in:
carry 2025-04-14 16:46:29 +08:00
parent 148f4afb25
commit 4f7926aec6

View File

@ -1,3 +1,4 @@
import os
import gradio as gr
import sys
from tinydb import Query
@ -42,16 +43,23 @@ def train_page():
epoch = int(epoch)
save_steps = int(save_steps)  # 新增保存步数参数
lora_rank = int(lora_rank)  # 新增LoRA秩参数
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
# 扫描 training 文件夹并生成递增目录
training_dir = get_workdir() + "/training"
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
new_training_dir = os.path.join(training_dir, str(next_dir_number))
train_model(get_model(), get_tokenizer(),
            dataset, new_training_dir,
            learning_rate, per_device_train_batch_size, epoch,
            save_steps, lora_rank)
train_button.click(
    fn=start_training,
    inputs=[