Compare commits

...

2 Commits

Author SHA1 Message Date
carry
79d3eb153c refactor(train_page): 优化训练页面布局和功能
- 移除了 max_steps_input 组件,减少不必要的输入项
- 将 per_device_train_batch_size_input 和 epoch_input 的标签简化为 "batch size" 和 "epoch"
- 新增 save_steps_input 组件,用于设置保存步数
- 修改 train_model 函数签名,移除了 max_seq_length 参数,并从按钮点击事件的输入列表中移除了 max_steps_input
- 更新了 trainer.train() 方法的调用,设置 resume_from_checkpoint=False
2025-04-13 01:56:10 +08:00
carry
80dae7c6e2 fix(global_var): 修复 _workdir 未在 init_global_var 中声明为 global 的 bug 2025-04-13 01:52:05 +08:00
2 changed files with 5 additions and 7 deletions

View File

@ -33,9 +33,8 @@ def train_page():
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
train_button = gr.Button("开始微调")
@ -43,7 +42,7 @@ def train_page():
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
@ -135,7 +134,7 @@ def train_page():
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=True)
trainer_stats = trainer.train(resume_from_checkpoint=False)
train_button.click(
fn=train_model,
@ -143,7 +142,6 @@ def train_page():
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
max_steps_input,
epoch_input,
save_steps_input
],

View File

@ -9,7 +9,7 @@ _model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _docs, _datasets
global _prompt_store, _sql_engine, _docs, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)