refactor(tools): 优化 train_model 函数定义

- 添加类型注解,提高代码可读性和维护性
- 使用多行格式定义函数参数,提升代码格式美观
This commit is contained in:
carry 2025-04-14 14:28:36 +08:00
parent 1a2ca3e244
commit 4f09823123

View File

@ -31,8 +31,17 @@ def formatting_prompts(examples,tokenizer):
return {"text": texts} return {"text": texts}
def train_model(model, tokenizer, dataset, output_dir, learning_rate, def train_model(
per_device_train_batch_size, epoch, save_steps, lora_rank): model,
tokenizer,
dataset: list,
output_dir: str,
learning_rate: float,
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int
) -> None:
# 模型配置参数 # 模型配置参数
dtype = None # 数据类型None表示自动选择 dtype = None # 数据类型None表示自动选择
load_in_4bit = False # 使用4bit量化加载模型以节省显存 load_in_4bit = False # 使用4bit量化加载模型以节省显存
@ -75,8 +84,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
chat_template="qwen-2.5", chat_template="qwen-2.5",
) )
dataset = HFDataset.from_list(dataset) train_dataset = HFDataset.from_list(dataset)
dataset = dataset.map(formatting_prompts, train_dataset = train_dataset.map(formatting_prompts,
fn_kwargs={"tokenizer": tokenizer}, fn_kwargs={"tokenizer": tokenizer},
batched=True) batched=True)
@ -84,7 +93,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
trainer = SFTTrainer( trainer = SFTTrainer(
model=model, # 待训练的模型 model=model, # 待训练的模型
tokenizer=tokenizer, # 分词器 tokenizer=tokenizer, # 分词器
train_dataset=dataset, # 训练数据集 train_dataset=train_dataset, # 训练数据集
dataset_text_field="text", # 数据集字段的名称 dataset_text_field="text", # 数据集字段的名称
max_seq_length=model.max_seq_length, # 最大序列长度 max_seq_length=model.max_seq_length, # 最大序列长度
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
@ -96,7 +105,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率 warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
learning_rate=learning_rate, # 学习率 learning_rate=learning_rate, # 学习率
lr_scheduler_type="linear", # 线性学习率调度器 lr_scheduler_type="linear", # 线性学习率调度器
max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据 max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16 fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
bf16=is_bfloat16_supported(), # 如果支持则使用bf16 bf16=is_bfloat16_supported(), # 如果支持则使用bf16
logging_steps=1, # 每1步记录一次日志 logging_steps=1, # 每1步记录一次日志