feat(train_page): 添加 LoRA 秩动态输入功能

- 在训练页面新增 LoRA 秩输入框,使用户可以动态设置 LoRA 秩
- 更新训练模型函数,添加 LoRA 秩参数并将其用于模型配置
- 保留原有功能,仅增加 LoRA 秩相关配置
This commit is contained in:
carry 2025-04-13 21:12:02 +08:00
parent e08f0059bb
commit 0722748997

View File

@@ -36,18 +36,20 @@ def train_page():
            per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
            epoch_input = gr.Number(value=1, label="epoch", precision=0)
            save_steps_input = gr.Number(value=20, label="保存步数", precision=0)  # 新增保存步数输入框
            lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)  # 新增LoRA秩输入框
            train_button = gr.Button("开始微调")
            # 训练状态输出
            output = gr.Textbox(label="训练日志", interactive=False)
        def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            # 使用动态传入的超参数
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)  # 新增保存步数参数
            lora_rank = int(lora_rank)  # 新增LoRA秩参数

            # 模型配置参数
            dtype = None  # 数据类型,None表示自动选择
@@ -62,7 +64,7 @@ def train_page():
                model,
                # LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
                # 建议: 8-32之间
                r=lora_rank,  # 使用动态传入的LoRA秩
                # 需要应用LoRA的目标模块列表
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",  # attention相关层
@@ -161,6 +163,5 @@ if __name__ == "__main__":
    from model_manage_page import model_manage_page
    init_global_var("workdir")
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
# demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"])
    demo.queue()
    demo.launch()