feat(frontend): 完成了前端微调的代码逻辑

This commit is contained in:
carry 2025-04-12 18:42:22 +08:00
parent 9784f2aed3
commit 539e14d39c

View File

@ -1,8 +1,20 @@
import gradio as gr import gradio as gr
import sys import sys
import torch
from tinydb import Query
from pathlib import Path from pathlib import Path
from datasets import Dataset as HFDataset
from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
from unsloth.chat_templates import get_chat_template
from tools import formatting_prompts_func
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets from global_var import get_model, get_tokenizer, get_datasets
from tools import formatting_prompts_func
def train_page(): def train_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@ -18,12 +30,119 @@ def train_page():
allow_custom_value=True, allow_custom_value=True,
interactive=True interactive=True
) )
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name):
# 模型配置参数
max_seq_length = 4096 # 最大序列长度
dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存
# 加载预训练模型和分词器
model = get_model()
tokenizer = get_tokenizer()
model = FastLanguageModel.get_peft_model(
# 原始模型
model,
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
# 建议: 8-32之间
r=16,
# 需要应用LoRA的目标模块列表
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
"gate_proj", "up_proj", "down_proj", # FFN相关层
],
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大LoRA的更新影响越大。
lora_alpha=16,
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
# 如果数据集较小建议设置0.1左右。
lora_dropout=0,
# 是否对bias参数进行微调,none表示不微调bias
# none: 不微调偏置参数;
# all: 微调所有参数;
# lora_only: 只微调LoRA参数。
bias="none",
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
# 会略微降低训练速度,但可以显著减少显存使用
use_gradient_checkpointing="unsloth",
# 随机数种子,用于结果复现
random_state=3407,
# 是否使用rank-stabilized LoRA,这里不使用
# 会略微降低训练速度,但可以显著减少显存使用
use_rslora=False,
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
loftq_config=None,
)
# 配置分词器
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
dataset = HFDataset.from_list(dataset)
dataset = dataset.map(formatting_prompts_func,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
print(dataset[5])
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器
train_dataset=dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称
max_seq_length=max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # 数据集处理的并行进程数提高CPU利用率
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2, # 每个GPU的训练批次大小
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
warmup_steps=5, # 预热步数,逐步增加学习率
learning_rate=2e-4, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=60, # 最大训练步数(一步 = 处理一个batch的数据
# 根据硬件支持选择训练精度
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=3407, # 随机数种子
output_dir="outputs", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=20, # 每20步保存一次中间权重
# report_to="tensorboard", # 将信息输出到tensorboard
),
)
# 开始训练resume_from_checkpoint为True表示从最新的模型开始训练
trainer_stats = trainer.train(resume_from_checkpoint = True)
train_button.click(
fn=train_model,
inputs=dataset_dropdown,
outputs=output
)
return demo return demo
if __name__ == "__main__": if __name__ == "__main__":
from global_var import init_global_var from global_var import init_global_var
from model_manage_page import model_manage_page
init_global_var("workdir") init_global_var("workdir")
demo = train_page() demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
# demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"])
demo.queue() demo.queue()
demo.launch() demo.launch()