Compare commits
No commits in common. "79d3eb153c5f291046fd824bfc90f7a28a6db6e1" and "2d39b9176488a62ef24ce9bc5d49fba22d62606b" have entirely different histories.
79d3eb153c
...
2d39b91764
@ -33,8 +33,9 @@ def train_page():
|
||||
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="epoch", precision=0)
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
|
||||
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
|
||||
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
|
||||
|
||||
train_button = gr.Button("开始微调")
|
||||
@ -42,7 +43,7 @@ def train_page():
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
|
||||
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
@ -134,7 +135,7 @@ def train_page():
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=True)
|
||||
|
||||
train_button.click(
|
||||
fn=train_model,
|
||||
@ -142,6 +143,7 @@ def train_page():
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
max_steps_input,
|
||||
epoch_input,
|
||||
save_steps_input
|
||||
],
|
||||
|
@ -9,7 +9,7 @@ _model = None
|
||||
_tokenizer = None
|
||||
_workdir = None
|
||||
def init_global_var(workdir="workdir"):
|
||||
global _prompt_store, _sql_engine, _docs, _datasets, _workdir
|
||||
global _prompt_store, _sql_engine, _docs, _datasets
|
||||
_prompt_store = get_prompt_tinydb(workdir)
|
||||
_sql_engine = get_sqlite_engine(workdir)
|
||||
_docs = scan_docs_directory(workdir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user