refactor(tools): optimize the train_model function definition

- Add type annotations to improve code readability and maintainability
- Define function parameters across multiple lines for cleaner formatting
parent 1a2ca3e244
commit 4f09823123
@@ -31,8 +31,17 @@ def formatting_prompts(examples,tokenizer):
     return {"text": texts}


-def train_model(model, tokenizer, dataset, output_dir, learning_rate,
-                per_device_train_batch_size, epoch, save_steps, lora_rank):
+def train_model(
+        model,
+        tokenizer,
+        dataset: list,
+        output_dir: str,
+        learning_rate: float,
+        per_device_train_batch_size: int,
+        epoch: int,
+        save_steps: int,
+        lora_rank: int
+) -> None:
     # Model configuration parameters
     dtype = None  # data type; None means auto-select
     load_in_4bit = False  # load the model with 4-bit quantization to save VRAM
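For reference, the added annotations are ordinary runtime metadata on the signature. A minimal, self-contained sketch; the stub below only mirrors the new parameter list and is not the repo's train_model:

import inspect

# Stub with the same annotated parameters as the new definition above;
# purely illustrative, it performs no training.
def train_model_stub(
        model,
        tokenizer,
        dataset: list,
        output_dir: str,
        learning_rate: float,
        per_device_train_batch_size: int,
        epoch: int,
        save_steps: int,
        lora_rank: int,
) -> None:
    pass

print(inspect.signature(train_model_stub))
# (model, tokenizer, dataset: list, output_dir: str, learning_rate: float, ...)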
@@ -75,8 +84,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
         chat_template="qwen-2.5",
     )

-    dataset = HFDataset.from_list(dataset)
-    dataset = dataset.map(formatting_prompts,
+    train_dataset = HFDataset.from_list(dataset)
+    train_dataset = train_dataset.map(formatting_prompts,
                           fn_kwargs={"tokenizer": tokenizer},
                           batched=True)

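The renamed train_dataset follows the usual datasets pattern: build a Dataset from a list of dicts, then map a batched formatting function with extra arguments via fn_kwargs. A minimal sketch of that pattern, using a toy stand-in for formatting_prompts rather than the repo's function:

from datasets import Dataset

# Toy stand-in for formatting_prompts: with batched=True the mapper receives
# whole columns as lists, so the "text" column is built list-wise.
def add_text(examples, suffix):
    return {"text": [q + suffix for q in examples["question"]]}

train_dataset = Dataset.from_list([{"question": "1 + 1 = ?"}, {"question": "2 + 2 = ?"}])
train_dataset = train_dataset.map(add_text, fn_kwargs={"suffix": " Answer:"}, batched=True)
print(train_dataset[0]["text"])  # "1 + 1 = ? Answer:"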
@@ -84,7 +93,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
     trainer = SFTTrainer(
         model=model,  # the model to be trained
         tokenizer=tokenizer,  # tokenizer
-        train_dataset=dataset,  # training dataset
+        train_dataset=train_dataset,  # training dataset
         dataset_text_field="text",  # name of the text field in the dataset
         max_seq_length=model.max_seq_length,  # maximum sequence length
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
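A side note on the data_collator passed above: DataCollatorForSeq2Seq pads input_ids/attention_mask with the tokenizer's pad token and pads labels with -100 so padded positions are ignored by the loss. A small sketch, assuming a hypothetical tokenizer checkpoint rather than the one this repo trains:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = collator(features)
print(batch["labels"])  # the shorter row is padded with -100, not the pad token id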
@@ -96,7 +105,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
             warmup_steps=int(epoch * 0.1),  # warm-up steps, gradually ramping up the learning rate
             learning_rate=learning_rate,  # learning rate
             lr_scheduler_type="linear",  # linear learning-rate scheduler
-            max_steps=int(epoch * len(dataset)/per_device_train_batch_size),  # maximum training steps (one step = one batch of data)
+            max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size),  # maximum training steps (one step = one batch of data)
             fp16=not is_bfloat16_supported(),  # use fp16 if bf16 is not supported
             bf16=is_bfloat16_supported(),  # use bf16 if supported
             logging_steps=1,  # log every step
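To make the max_steps formula concrete, a worked example with hypothetical numbers (3 epochs, 1000 samples, per-device batch size 4; none of these values come from this repo):

epoch = 3
num_samples = 1000                     # stands in for len(train_dataset)
per_device_train_batch_size = 4

max_steps = int(epoch * num_samples / per_device_train_batch_size)
warmup_steps = int(epoch * 0.1)        # same warm-up rule as in the diff

print(max_steps)     # 750  (each step consumes one batch of 4 samples)
print(warmup_steps)  # 0    (integer truncation; epoch >= 10 is needed for any warm-up steps)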