refactor(train_page): 优化训练页面布局和功能
- 移除了 max_steps_input 组件，减少不必要的输入项
- 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch"
- 新增 save_steps_input 组件，用于设置保存步数
- 修改 train_model 函数，移除了 max_steps 参数
- 更新了 trainer.train() 方法的调用，设置 resume_from_checkpoint=False
This commit is contained in:
parent
80dae7c6e2
commit
79d3eb153c
@ -33,9 +33,8 @@ def train_page():
|
||||
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
|
||||
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="epoch", precision=0)
|
||||
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
|
||||
|
||||
train_button = gr.Button("开始微调")
|
||||
@ -43,7 +42,7 @@ def train_page():
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
@ -135,7 +134,7 @@ def train_page():
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=True)
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||
|
||||
train_button.click(
|
||||
fn=train_model,
|
||||
@ -143,7 +142,6 @@ def train_page():
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
max_steps_input,
|
||||
epoch_input,
|
||||
save_steps_input
|
||||
],
|
||||
|
Loading…
x
Reference in New Issue
Block a user