Compare commits

..

No commits in common. "79d3eb153c5f291046fd824bfc90f7a28a6db6e1" and "2d39b9176488a62ef24ce9bc5d49fba22d62606b" have entirely different histories.

2 changed files with 7 additions and 5 deletions

View File

@ -33,8 +33,9 @@ def train_page():
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
train_button = gr.Button("开始微调")
@ -42,7 +43,7 @@ def train_page():
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
@ -134,7 +135,7 @@ def train_page():
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)
trainer_stats = trainer.train(resume_from_checkpoint=True)
train_button.click(
fn=train_model,
@ -142,6 +143,7 @@ def train_page():
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
max_steps_input,
epoch_input,
save_steps_input
],

View File

@ -9,7 +9,7 @@ _model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _docs, _datasets, _workdir
global _prompt_store, _sql_engine, _docs, _datasets
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)