Compare commits

...

2 Commits

Author SHA1 Message Date
carry
2d39b91764 feat(train_page): 添加模型训练超参数配置功能
- 新增学习率、批次大小、最大训练步数等超参数输入组件
- 实现超参数在训练过程中的动态应用
- 调整训练参数以适应不同硬件环境
- 优化训练过程,支持按步数保存模型
2025-04-13 01:04:27 +08:00
carry
5094febcb4 refactor(global_var): 重构全局变量管理并添加工作目录功能
- 添加 _workdir 全局变量以存储工作目录路径
- 在 init_global_var 函数中初始化 _workdir
- 新增 get_workdir 函数以获取工作目录路径
- 调整全局变量的定义和初始化顺序
2025-04-13 00:54:55 +08:00
2 changed files with 46 additions and 30 deletions

View File

@ -13,16 +13,16 @@ from unsloth.chat_templates import get_chat_template
from tools import formatting_prompts_func
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import formatting_prompts_func
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
@ -30,14 +30,27 @@ def train_page():
allow_custom_value=True,
interactive=True
)
# 新增超参数输入组件
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="每设备训练批次大小", precision=0)
max_steps_input = gr.Number(value=60, label="最大训练步数", precision=0)
epoch_input = gr.Number(value=1, label="训练轮数", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name):
def train_model(dataset_name, max_seq_length, learning_rate, per_device_train_batch_size, epoch, save_steps):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
# 模型配置参数
max_seq_length = 4096 # 最大序列长度
dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存
@ -78,7 +91,6 @@ def train_page():
loftq_config=None,
)
# 配置分词器
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
@ -91,48 +103,50 @@ def train_page():
dataset = dataset.map(formatting_prompts_func,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
print(dataset[5])
# 初始化SFT训练器
trainer = SFTTrainer(
model=model, # 待训练的模型
model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器
train_dataset=dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称
max_seq_length=max_seq_length, # 最大序列长度
max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # 数据集处理的并行进程数提高CPU利用率
dataset_num_proc=1, # 数据集处理的并行进程数
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2, # 每个GPU的训练批次大小
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
warmup_steps=5, # 预热步数,逐步增加学习率
learning_rate=2e-4, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=60, # 最大训练步数(一步 = 处理一个batch的数据
# 根据硬件支持选择训练精度
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=3407, # 随机数种子
output_dir="outputs", # 保存模型检查点和训练日志
seed=114514, # 随机数种子
output_dir=get_workdir() + "/checkpoint/", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=20, # 每20步保存一次中间权重
save_steps=save_steps, # 使用动态传入的保存步数
# report_to="tensorboard", # 将信息输出到tensorboard
),
)
# 开始训练resume_from_checkpoint为True表示从最新的模型开始训练
trainer_stats = trainer.train(resume_from_checkpoint = True)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=True)
train_button.click(
fn=train_model,
inputs=dataset_dropdown,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
max_steps_input,
epoch_input,
save_steps_input
],
outputs=output
)

View File

@ -5,14 +5,19 @@ _prompt_store = None
_sql_engine = None
_docs = None
_datasets = None
_model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
"""Initialize all global variables"""
global _prompt_store, _sql_engine, _docs, _datasets
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)
_datasets = get_all_dataset(workdir)
_workdir = workdir
def get_workdir():
return _workdir
def get_prompt_store():
return _prompt_store
@ -42,9 +47,6 @@ def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
_model = None
_tokenizer = None
def get_model():
return _model