From 0722748997aaddde8098e8c3366605e736f0e01d Mon Sep 17 00:00:00 2001 From: carry Date: Sun, 13 Apr 2025 21:12:02 +0800 Subject: [PATCH] =?UTF-8?q?feat(train=5Fpage):=20=E6=B7=BB=E5=8A=A0=20LoRA?= =?UTF-8?q?=20=E7=A7=A9=E5=8A=A8=E6=80=81=E8=BE=93=E5=85=A5=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在训练页面新增 LoRA 秩输入框,使用户可以动态设置 LoRA 秩 - 更新训练模型函数,添加 LoRA 秩参数并将其用于模型配置 - 保留原有功能,仅增加 LoRA 秩相关配置 --- frontend/train_page.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index d312dc8..6aa3d78 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -36,18 +36,20 @@ def train_page(): per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) epoch_input = gr.Number(value=1, label="epoch", precision=0) save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 + lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框 train_button = gr.Button("开始微调") # 训练状态输出 output = gr.Textbox(label="训练日志", interactive=False) - def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps): + def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): # 使用动态传入的超参数 learning_rate = float(learning_rate) per_device_train_batch_size = int(per_device_train_batch_size) epoch = int(epoch) save_steps = int(save_steps) # 新增保存步数参数 + lora_rank = int(lora_rank) # 新增LoRA秩参数 # 模型配置参数 dtype = None # 数据类型,None表示自动选择 @@ -62,7 +64,7 @@ def train_page(): model, # LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大 # 建议: 8-32之间 - r=16, + r=lora_rank, # 使用动态传入的LoRA秩 # 需要应用LoRA的目标模块列表 target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", # attention相关层 @@ -161,6 +163,5 @@ if __name__ == "__main__": from model_manage_page import model_manage_page init_global_var("workdir") demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"]) - # demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"]) demo.queue() demo.launch() \ No newline at end of file