From 79d3eb153c5f291046fd824bfc90f7a28a6db6e1 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 13 Apr 2025 01:56:10 +0800 Subject: [PATCH] =?UTF-8?q?refactor(train=5Fpage):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E9=A1=B5=E9=9D=A2=E5=B8=83=E5=B1=80=E5=92=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除了 max_steps_input 组件,减少不必要的输入项 - 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch" - 保留 save_steps_input 组件,用于设置保存步数 - 修改 train_model 函数,移除了 max_seq_length 参数(对应已移除的 max_steps_input 输入) - 更新了 trainer.train() 方法的调用,设置 resume_from_checkpoint=False --- frontend/train_page.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index 06c33e6..3e7fcbb 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -33,9 +33,8 @@ def train_page(): # 新增超参数输入组件 learning_rate_input = gr.Number(value=2e-4, label="学习率") - per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0) - max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0) - epoch_input = gr.Number(value=1, label="训练轮数", precision=0) + per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) + epoch_input = gr.Number(value=1, label="epoch", precision=0) save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 train_button = gr.Button("开始微调") @@ -43,7 +42,7 @@ def train_page(): # 训练状态输出 output = gr.Textbox(label="训练日志", interactive=False) - def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps): + def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps): # 使用动态传入的超参数 learning_rate = float(learning_rate) per_device_train_batch_size = int(per_device_train_batch_size) @@ -135,7 +134,7 @@ def train_page(): ) # 开始训练 - trainer_stats = 
trainer.train(resume_from_checkpoint=True) + trainer_stats = trainer.train(resume_from_checkpoint=False) train_button.click( fn=train_model, @@ -143,7 +142,6 @@ def train_page(): dataset_dropdown, learning_rate_input, per_device_train_batch_size_input, - max_steps_input, epoch_input, save_steps_input ],