Compare commits
2 Commits
2d39b91764
...
79d3eb153c
Author | SHA1 | Date | |
---|---|---|---|
![]() |
79d3eb153c | ||
![]() |
80dae7c6e2 |
@ -33,9 +33,8 @@ def train_page():
|
||||
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
|
||||
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="epoch", precision=0)
|
||||
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
|
||||
|
||||
train_button = gr.Button("开始微调")
|
||||
@ -43,7 +42,7 @@ def train_page():
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
@ -135,7 +134,7 @@ def train_page():
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=True)
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||
|
||||
train_button.click(
|
||||
fn=train_model,
|
||||
@ -143,7 +142,6 @@ def train_page():
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
max_steps_input,
|
||||
epoch_input,
|
||||
save_steps_input
|
||||
],
|
||||
|
@ -9,7 +9,7 @@ _model = None
|
||||
_tokenizer = None
|
||||
_workdir = None
|
||||
def init_global_var(workdir="workdir"):
|
||||
global _prompt_store, _sql_engine, _docs, _datasets
|
||||
global _prompt_store, _sql_engine, _docs, _datasets, _workdir
|
||||
_prompt_store = get_prompt_tinydb(workdir)
|
||||
_sql_engine = get_sqlite_engine(workdir)
|
||||
_docs = scan_docs_directory(workdir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user