diff --git a/tools/model.py b/tools/model.py
index 37d5a91..b06c0bc 100644
--- a/tools/model.py
+++ b/tools/model.py
@@ -31,8 +31,17 @@ def formatting_prompts(examples,tokenizer):
     return {"text": texts}
 
 
-def train_model(model, tokenizer, dataset, output_dir, learning_rate,
-                per_device_train_batch_size, epoch, save_steps, lora_rank):
+def train_model(
+    model,
+    tokenizer,
+    dataset: list,
+    output_dir: str,
+    learning_rate: float,
+    per_device_train_batch_size: int,
+    epoch: int,
+    save_steps: int,
+    lora_rank: int
+) -> None:
     # Model configuration parameters
     dtype = None  # Data type; None means auto-select
     load_in_4bit = False  # Load the model with 4-bit quantization to save GPU memory
@@ -75,8 +84,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
         chat_template="qwen-2.5",
     )
 
-    dataset = HFDataset.from_list(dataset)
-    dataset = dataset.map(formatting_prompts,
+    train_dataset = HFDataset.from_list(dataset)
+    train_dataset = train_dataset.map(formatting_prompts,
                           fn_kwargs={"tokenizer": tokenizer},
                           batched=True)
 
@@ -84,7 +93,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
     trainer = SFTTrainer(
         model=model,  # Model to be trained
         tokenizer=tokenizer,  # Tokenizer
-        train_dataset=dataset,  # Training dataset
+        train_dataset=train_dataset,  # Training dataset
         dataset_text_field="text",  # Name of the dataset text field
         max_seq_length=model.max_seq_length,  # Maximum sequence length
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
@@ -96,7 +105,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
             warmup_steps=int(epoch * 0.1),  # Warmup steps to ramp up the learning rate
             learning_rate=learning_rate,  # Learning rate
             lr_scheduler_type="linear",  # Linear learning rate scheduler
-            max_steps=int(epoch * len(dataset)/per_device_train_batch_size),  # Max training steps (one step = one batch of data)
+            max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size),  # Max training steps (one step = one batch of data)
             fp16=not is_bfloat16_supported(),  # Use fp16 if bf16 is not supported
             bf16=is_bfloat16_supported(),  # Use bf16 if supported
             logging_steps=1,  # Log every step
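
For reviewers, a minimal usage sketch of the refactored signature (not part of the patch). Only the train_model parameter list comes from this diff; the model/tokenizer loading via unsloth's FastLanguageModel, the record field names, and every concrete value below are illustrative assumptions, since none of them appears in this hunk.

from unsloth import FastLanguageModel
from tools.model import train_model  # assumes tools/ is importable from the repo root

# Hypothetical base model; the module elsewhere uses unsloth + the qwen-2.5 chat template.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-7B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=False,
)

# Hypothetical records; the field names must match whatever formatting_prompts
# expects, which is not visible in this hunk.
records = [
    {"instruction": "Translate 'hello' to French.", "output": "bonjour"},
]

train_model(
    model,
    tokenizer,
    dataset=records,
    output_dir="outputs/qwen2.5-lora",  # illustrative path
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    epoch=3,
    save_steps=100,
    lora_rank=16,
)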