diff --git a/frontend/train_page.py b/frontend/train_page.py index 06c33e6..3e7fcbb 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -33,9 +33,8 @@ def train_page(): # 新增超参数输入组件 learning_rate_input = gr.Number(value=2e-4, label="学习率") - per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0) - max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0) - epoch_input = gr.Number(value=1, label="训练轮数", precision=0) + per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) + epoch_input = gr.Number(value=1, label="epoch", precision=0) save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 train_button = gr.Button("开始微调") @@ -43,7 +42,7 @@ def train_page(): # 训练状态输出 output = gr.Textbox(label="训练日志", interactive=False) - def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps): + def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps): # 使用动态传入的超参数 learning_rate = float(learning_rate) per_device_train_batch_size = int(per_device_train_batch_size) @@ -135,7 +134,7 @@ def train_page(): ) # 开始训练 - trainer_stats = trainer.train(resume_from_checkpoint=True) + trainer_stats = trainer.train(resume_from_checkpoint=False) train_button.click( fn=train_model, @@ -143,7 +142,6 @@ def train_page(): dataset_dropdown, learning_rate_input, per_device_train_batch_size_input, - max_steps_input, epoch_input, save_steps_input ],