From 2d39b9176488a62ef24ce9bc5d49fba22d62606b Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 13 Apr 2025 01:04:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(train=5Fpage):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E8=B6=85=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增学习率、批次大小、最大训练步数等超参数输入组件 - 实现超参数在训练过程中的动态应用 - 调整训练参数以适应不同硬件环境 - 优化训练过程,支持按步数保存模型 --- frontend/train_page.py | 64 +++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index 2613556..06c33e6 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -13,16 +13,16 @@ from unsloth.chat_templates import get_chat_template from tools import formatting_prompts_func sys.path.append(str(Path(__file__).resolve().parent.parent)) -from global_var import get_model, get_tokenizer, get_datasets +from global_var import get_model, get_tokenizer, get_datasets, get_workdir from tools import formatting_prompts_func def train_page(): with gr.Blocks() as demo: - gr.Markdown("## 微调") # 获取数据集列表并设置初始值 datasets_list = [str(ds["name"]) for ds in get_datasets().all()] initial_dataset = datasets_list[0] if datasets_list else None + dataset_dropdown = gr.Dropdown( choices=datasets_list, value=initial_dataset, # 设置初始选中项 @@ -30,14 +30,27 @@ def train_page(): allow_custom_value=True, interactive=True ) + + # 新增超参数输入组件 + learning_rate_input = gr.Number(value=2e-4, label="学习率") + per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0) + max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0) + epoch_input = gr.Number(value=1, label="训练轮数", precision=0) + save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 + train_button = gr.Button("开始微调") # 训练状态输出 output = gr.Textbox(label="训练日志", interactive=False) - def train_model(dataset_name): + def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps): + # 使用动态传入的超参数 + learning_rate = float(learning_rate) + per_device_train_batch_size = int(per_device_train_batch_size) + epoch = int(epoch) + save_steps = int(save_steps) # 新增保存步数参数 + # 模型配置参数 - max_seq_length = 4096 # 最大序列长度 dtype = None # 数据类型,None表示自动选择 load_in_4bit = False # 使用4bit量化加载模型以节省显存 @@ -78,7 +91,6 @@ def train_page(): loftq_config=None, ) - # 配置分词器 tokenizer = get_chat_template( tokenizer, chat_template="qwen-2.5", @@ -91,48 +103,50 @@ def train_page(): dataset = dataset.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True) - print(dataset[5]) - # 初始化SFT训练器 trainer = SFTTrainer( - model=model, # 待训练的模型 + model=model, # 待训练的模型 tokenizer=tokenizer, # 分词器 train_dataset=dataset, # 训练数据集 dataset_text_field="text", # 数据集字段的名称 - max_seq_length=max_seq_length, # 最大序列长度 + max_seq_length=model.max_seq_length, # 最大序列长度 data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), - dataset_num_proc=1, # 数据集处理的并行进程数,提高CPU利用率 + dataset_num_proc=1, # 数据集处理的并行进程数 packing=False, args=TrainingArguments( - per_device_train_batch_size=2, # 每个GPU的训练批次大小 + per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小 gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size - warmup_steps=5, # 预热步数,逐步增加学习率 - learning_rate=2e-4, # 学习率 - lr_scheduler_type="linear", # 线性学习率调度器 - max_steps=60, # 最大训练步数(一步 = 处理一个batch的数据) - # 根据硬件支持选择训练精度 + warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率 + learning_rate=learning_rate, # 学习率 + lr_scheduler_type="linear", # 线性学习率调度器 + max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据) fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16 bf16=is_bfloat16_supported(), # 如果支持则使用bf16 logging_steps=1, # 每1步记录一次日志 optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果 weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合 - seed=3407, # 随机数种子 - output_dir="outputs", # 保存模型检查点和训练日志 + seed=114514, # 随机数种子 + output_dir=get_workdir() + "/checkpoint/", # 保存模型检查点和训练日志 save_strategy="steps", # 按步保存中间权重 - save_steps=20, # 每20步保存一次中间权重 + save_steps=save_steps, # 使用动态传入的保存步数 # report_to="tensorboard", # 将信息输出到tensorboard ), ) - - # 开始训练,resume_from_checkpoint为True表示从最新的模型开始训练 - trainer_stats = trainer.train(resume_from_checkpoint = True) - - + # 开始训练 + trainer_stats = trainer.train(resume_from_checkpoint=True) + train_button.click( fn=train_model, - inputs=dataset_dropdown, + inputs=[ + dataset_dropdown, + learning_rate_input, + per_device_train_batch_size_input, + max_steps_input, + epoch_input, + save_steps_input + ], outputs=output )