From 4f09823123b13ff427173916fc4d854c74e4c04b Mon Sep 17 00:00:00 2001
From: carry
Date: Mon, 14 Apr 2025 14:28:36 +0800
Subject: [PATCH] refactor(tools): improve the train_model function definition
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add type annotations to improve code readability and maintainability
- Split the parameter list across multiple lines for cleaner formatting
---
 tools/model.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/tools/model.py b/tools/model.py
index 37d5a91..b06c0bc 100644
--- a/tools/model.py
+++ b/tools/model.py
@@ -31,8 +31,17 @@ def formatting_prompts(examples,tokenizer):
 
     return {"text": texts}
 
-def train_model(model, tokenizer, dataset, output_dir, learning_rate,
-                per_device_train_batch_size, epoch, save_steps, lora_rank):
+def train_model(
+    model,
+    tokenizer,
+    dataset: list,
+    output_dir: str,
+    learning_rate: float,
+    per_device_train_batch_size: int,
+    epoch: int,
+    save_steps: int,
+    lora_rank: int
+) -> None:
     # Model configuration parameters
     dtype = None  # Data type; None means auto-detect
     load_in_4bit = False  # Load the model with 4-bit quantization to save VRAM
@@ -75,8 +84,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
         chat_template="qwen-2.5",
     )
 
-    dataset = HFDataset.from_list(dataset)
-    dataset = dataset.map(formatting_prompts,
+    train_dataset = HFDataset.from_list(dataset)
+    train_dataset = train_dataset.map(formatting_prompts,
                           fn_kwargs={"tokenizer": tokenizer},
                           batched=True)
 
@@ -84,7 +93,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
     trainer = SFTTrainer(
         model=model,  # The model to be trained
        tokenizer=tokenizer,  # Tokenizer
-        train_dataset=dataset,  # Training dataset
+        train_dataset=train_dataset,  # Training dataset
         dataset_text_field="text",  # Name of the text field in the dataset
         max_seq_length=model.max_seq_length,  # Maximum sequence length
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
@@ -96,7 +105,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
             warmup_steps=int(epoch * 0.1),  # Warmup steps to ramp up the learning rate gradually
             learning_rate=learning_rate,  # Learning rate
             lr_scheduler_type="linear",  # Linear learning-rate scheduler
-            max_steps=int(epoch * len(dataset)/per_device_train_batch_size),  # Maximum training steps (one step = one batch of data)
+            max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size),  # Maximum training steps (one step = one batch of data)
             fp16=not is_bfloat16_supported(),  # Use fp16 if bf16 is not supported
             bf16=is_bfloat16_supported(),  # Use bf16 when supported
             logging_steps=1,  # Log every step
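
For context, a minimal sketch of how the refactored signature might be invoked. This is not part of the patch: the unsloth loader is an assumption inferred from the is_bfloat16_supported and qwen-2.5 chat-template usage above, and the checkpoint name, dataset fields, output path, and hyperparameter values are all hypothetical placeholders.

    # Hypothetical caller for the refactored train_model; all values are illustrative.
    from unsloth import FastLanguageModel  # assumption: model/tokenizer are loaded via unsloth

    from tools.model import train_model

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-7B-Instruct",  # hypothetical checkpoint
        max_seq_length=2048,
        dtype=None,          # auto-detect, mirroring the default inside train_model
        load_in_4bit=False,
    )

    # A plain list of dicts, matching the new `dataset: list` annotation;
    # the field names are placeholders for whatever formatting_prompts expects.
    dataset = [
        {"instruction": "Explain LoRA in one sentence.", "output": "..."},
    ]

    train_model(
        model,
        tokenizer,
        dataset=dataset,
        output_dir="outputs/qwen-lora",  # hypothetical path
        learning_rate=2e-4,              # illustrative value
        per_device_train_batch_size=2,
        epoch=3,
        save_steps=100,
        lora_rank=16,
    )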