feat(train_page): 实现训练目录自动递增功能
- 在 training 文件夹下创建递增的目录结构 - 确保 training 文件夹存在 - 扫描现有目录,生成下一个可用的目录编号 - 更新训练模型函数,使用新的训练目录
This commit is contained in:
parent
148f4afb25
commit
4f7926aec6
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
import sys
|
||||
from tinydb import Query
|
||||
@ -36,21 +37,28 @@ def train_page():
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
epoch = int(epoch)
|
||||
save_steps = int(save_steps) # 新增保存步数参数
|
||||
lora_rank = int(lora_rank) # 新增LoRA秩参数
|
||||
|
||||
# 加载数据集
|
||||
dataset = get_datasets().get(Query().name == dataset_name)
|
||||
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
|
||||
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, get_workdir() + "/training" + "/1",
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank)
|
||||
# 扫描 training 文件夹并生成递增目录
|
||||
training_dir = get_workdir() + "/training"
|
||||
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
|
||||
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
|
||||
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
|
||||
new_training_dir = os.path.join(training_dir, str(next_dir_number))
|
||||
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, new_training_dir,
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank)
|
||||
|
||||
train_button.click(
|
||||
fn=start_training,
|
||||
|
Loading…
x
Reference in New Issue
Block a user