feat(train_page): add dynamic LoRA rank input
- Add a LoRA rank input box to the training page so users can set the LoRA rank dynamically
- Update the train_model function to take a LoRA rank parameter and use it in the model configuration
- Keep existing behavior unchanged; only the LoRA rank related configuration is added
parent e08f0059bb
commit 0722748997
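A minimal sketch of how the new input would reach train_model, assuming the usual Gradio pattern. The .click() wiring is not shown in the hunks below, and dataset_name_input / learning_rate_input are placeholder names for the existing components:

# Hypothetical wiring (not part of this commit's hunks): append the new
# lora_rank_input to the inputs list so its value arrives as the lora_rank
# parameter added to train_model below.
train_button.click(
    fn=train_model,
    inputs=[
        dataset_name_input,                  # placeholder name for the dataset selector
        learning_rate_input,                 # placeholder name for the learning-rate field
        per_device_train_batch_size_input,
        epoch_input,
        save_steps_input,
        lora_rank_input,                     # new: LoRA rank from the UI
    ],
    outputs=output,
)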
@@ -36,18 +36,20 @@ def train_page():
 per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
 epoch_input = gr.Number(value=1, label="epoch", precision=0)
 save_steps_input = gr.Number(value=20, label="保存步数", precision=0)  # save-steps input box
+lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)  # new: LoRA rank input box

 train_button = gr.Button("开始微调")

 # training status output
 output = gr.Textbox(label="训练日志", interactive=False)

-def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
+def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
 # use the dynamically passed hyperparameters
 learning_rate = float(learning_rate)
 per_device_train_batch_size = int(per_device_train_batch_size)
 epoch = int(epoch)
 save_steps = int(save_steps)  # save-steps parameter
+lora_rank = int(lora_rank)  # new: LoRA rank parameter

 # model configuration parameters
 dtype = None  # data type; None means auto-select
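Since gr.Number can return None or a non-positive value, a small guard around the int() coercion above could fall back to the UI default of 16. This is only a sketch of an optional hardening step, not something the commit adds:

def normalize_lora_rank(value, default=16):
    # Coerce the raw Gradio value to a usable LoRA rank; fall back to the
    # default when the field is empty or holds a non-positive number.
    try:
        rank = int(value)
    except (TypeError, ValueError):
        return default
    return rank if rank > 0 else default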
@@ -62,7 +64,7 @@ def train_page():
 model,
 # LoRA rank: controls the dimension of the low-rank matrices; a larger value means more trainable parameters and potentially better performance, but higher training cost
 # recommended range: 8-32
-r=16,
+r=lora_rank,  # use the dynamically passed LoRA rank
 # list of target modules that LoRA is applied to
 target_modules=[
 "q_proj", "k_proj", "v_proj", "o_proj",  # attention-related layers
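For reference, the same rank-dependent configuration expressed with Hugging Face PEFT. This is an illustrative assumption only: the hunk shows the keyword arguments but not which LoRA wrapper the project actually calls, and lora_alpha / lora_dropout are placeholder values not taken from the diff:

from peft import LoraConfig, get_peft_model

def apply_lora(model, lora_rank):
    # Build a LoRA adapter whose rank comes from the UI instead of a hard-coded 16.
    config = LoraConfig(
        r=lora_rank,                                               # dynamic LoRA rank
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],   # attention projections, as in the diff
        lora_alpha=16,       # placeholder scaling factor, not shown in the hunk
        lora_dropout=0.0,    # placeholder, not shown in the hunk
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, config)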
@@ -161,6 +163,5 @@ if __name__ == "__main__":
 from model_manage_page import model_manage_page
 init_global_var("workdir")
 demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
-# demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"])
 demo.queue()
 demo.launch()