gzhu-biyesheji/train.ipynb


{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sbtwc\\.conda\\envs\\unsloth_env\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Unsloth: OpenAI failed to import - ignoring for now.\n",
"🦥 Unsloth Zoo will now patch everything to make training faster!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sbtwc\\.conda\\envs\\unsloth_env\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n",
" GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==((====))== Unsloth 2025.3.9: Fast Qwen2 patching. Transformers: 4.48.1.\n",
" \\\\ /| NVIDIA GeForce RTX 3060 Laptop GPU. Num GPUs = 1. Max memory: 6.0 GB. Platform: Windows.\n",
"O^O/ \\_/ \\ Torch: 2.6.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.2.0\n",
"\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
" \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Unsloth 2025.3.9 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.\n"
]
}
],
"source": [
"from unsloth import FastLanguageModel\n",
"import torch\n",
"\n",
"# 基础配置参数\n",
"max_seq_length = 4096 # 最大序列长度\n",
"dtype = None # 自动检测数据类型\n",
"load_in_4bit = True # 使用4位量化以减少内存使用\n",
"\n",
"# 加载预训练模型和分词器\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name = \"workdir/model/Qwen2.5-3B-Instruct-bnb-4bit\", # 选择Qwen2.5 32B指令模型\n",
" max_seq_length = max_seq_length,\n",
" dtype = dtype,\n",
" load_in_4bit = load_in_4bit,\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r = 64, # LoRA秩,控制可训练参数数量\n",
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",], # 需要训练的目标模块\n",
" lora_alpha = 64, # LoRA缩放因子\n",
" lora_dropout = 0, # LoRA dropout率\n",
" bias = \"none\", # 是否训练偏置项\n",
" use_gradient_checkpointing = \"unsloth\", # 使用梯度检查点节省显存\n",
" random_state = 114514, # 随机数种子\n",
" use_rslora = False, # 是否使用稳定版LoRA\n",
" loftq_config = None, # LoftQ配置\n",
")\n",
"\n",
"from unsloth.chat_templates import get_chat_template\n",
"# 配置分词器使用qwen-2.5对话模板\n",
"tokenizer = get_chat_template(\n",
" tokenizer,\n",
" chat_template=\"qwen-2.5\",\n",
")\n"
]
},
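{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (a minimal sketch): the PEFT-wrapped model returned by\n",
"# get_peft_model should expose PEFT's standard print_trainable_parameters()\n",
"# helper, which reports trainable vs. total parameters for the LoRA setup above.\n",
"model.print_trainable_parameters()"
]
},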
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 595/595 [00:00<00:00, 7000.21 examples/s]\n"
]
}
],
"source": [
"# 加载数据集\n",
"import json\n",
"\n",
"def formatting_prompts_func(examples):\n",
" \"\"\"格式化对话数据的函数\n",
" Args:\n",
" examples: 包含对话列表的字典\n",
" Returns:\n",
" 包含格式化文本的字典\n",
" \"\"\"\n",
" questions = examples[\"question\"]\n",
" answer = examples[\"answer\"]\n",
" \n",
" # 将Question和Response组合成对话形式\n",
" convos = [\n",
" [{\"role\": \"user\", \"content\": q}, {\"role\": \"assistant\", \"content\": r}]\n",
" for q, r in zip(questions, answer)\n",
" ]\n",
" \n",
" # 使用tokenizer.apply_chat_template格式化对话\n",
" texts = [\n",
" tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)\n",
" for convo in convos\n",
" ]\n",
" \n",
" return {\"text\": texts}\n",
"\n",
"\n"
]
},
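{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal smoke test of formatting_prompts_func on a hypothetical two-example\n",
"# batch (the question/answer pairs below are made up for illustration only):\n",
"sample_batch = {\n",
"    \"question\": [\"What is LoRA?\", \"What is Unsloth?\"],\n",
"    \"answer\": [\"LoRA is a parameter-efficient fine-tuning method.\",\n",
"               \"Unsloth is a library that speeds up fine-tuning.\"],\n",
"}\n",
"print(formatting_prompts_func(sample_batch)[\"text\"][0])"
]
},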
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'question': 'LLaMA-Factory有哪些训练方法', 'answer': 'LLaMA-Factory提供了Pre-training和Post-training两种训练方法。', 'text': '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nLLaMA-Factory有哪些训练方法<|im_end|>\\n<|im_start|>assistant\\nLLaMA-Factory提供了Pre-training和Post-training两种训练方法。<|im_end|>\\n'}\n"
]
}
],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset(\"json\", data_files=\"workdir\\dataset\\dataset.json\",split=\"train\")\n",
"dataset = dataset.map(formatting_prompts_func, batched = True)\n",
"\n",
"print(dataset[5])"
]
},
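{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not used by the trainer below): hold out a small evaluation\n",
"# split with the standard datasets API to spot-check overfitting on 595 examples.\n",
"split = dataset.train_test_split(test_size = 0.1, seed = 114514)\n",
"train_split, eval_split = split[\"train\"], split[\"test\"]\n",
"print(len(train_split), len(eval_split))"
]
},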
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPU = NVIDIA GeForce RTX 3060 Laptop GPU. Max memory = 6.0 GB.\n",
"2.557 GB of memory reserved.\n"
]
}
],
"source": [
"from trl import SFTTrainer\n",
"from transformers import TrainingArguments, DataCollatorForSeq2Seq\n",
"from unsloth import is_bfloat16_supported\n",
"\n",
"# 配置训练器\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" train_dataset=dataset,\n",
" dataset_text_field=\"text\",\n",
" max_seq_length=max_seq_length,\n",
" data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),\n",
" dataset_num_proc=1,\n",
" packing=False,\n",
" args=TrainingArguments(\n",
" per_device_train_batch_size=1, # 每个设备的批次大小\n",
" gradient_accumulation_steps=4, # 梯度累积步数\n",
" warmup_steps=3*50, # 预热步数\n",
" max_steps=3*500, # 最大训练步数\n",
" learning_rate=1e-4, # 学习率\n",
" fp16=not is_bfloat16_supported(), # 是否使用fp16\n",
" bf16=is_bfloat16_supported(), # 是否使用bf16\n",
" logging_steps=5, # 日志记录间隔\n",
" optim=\"paged_adamw_8bit\", # 优化器\n",
" weight_decay=0.01, # 权重衰减\n",
" lr_scheduler_type=\"linear\", # 学习率调度器\n",
" seed=114514, # 随机种子\n",
" output_dir=\"workdir/checkpoint/\", # 输出目录\n",
" save_strategy=\"steps\", # 按步保存中间权重\n",
" save_steps=200, # 每多少步保存一次中间权重\n",
" report_to=\"none\", # 不使用外部日志工具\n",
" ),\n",
")\n",
"\n",
"from unsloth.chat_templates import train_on_responses_only\n",
"# 设置仅对助手回复部分计算损失\n",
"trainer = train_on_responses_only(\n",
" trainer,\n",
" instruction_part = \"<|im_start|>user\\n\",\n",
" response_part = \"<|im_start|>assistant\\n\",\n",
")\n",
"\n",
"# 获取GPU信息\n",
"gpu_stats = torch.cuda.get_device_properties(0)\n",
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
]
},
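{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A quick sketch to verify the response-only masking: decode one example's\n",
"# labels, replacing masked (-100) positions with a space token so that only the\n",
"# tokens contributing to the loss remain visible. Index 5 is an arbitrary choice.\n",
"space = tokenizer(\" \", add_special_tokens = False).input_ids[0]\n",
"print(tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5][\"labels\"]]))"
]
},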
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
" \\\\ /| Num examples = 595 | Num Epochs = 11 | Total steps = 1,500\n",
"O^O/ \\_/ \\ Batch size per device = 1 | Gradient accumulation steps = 4\n",
"\\ / Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4\n",
" \"-____-\" Trainable parameters = 119,734,272/1,818,406,912 (6.58% trained)\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='1500' max='1500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [1500/1500 1:15:43, Epoch 10/11]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>2.941600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2.629000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>2.573900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>1.995600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>1.651400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>1.505400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>1.709300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>1.530200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>1.362000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>1.413000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>1.291000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>1.365500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>65</td>\n",
" <td>1.374200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>1.313900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>1.388400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>1.292300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>85</td>\n",
" <td>1.205700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>1.162200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>95</td>\n",
" <td>1.194100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.905200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>105</td>\n",
" <td>1.107500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>110</td>\n",
" <td>0.915300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>115</td>\n",
" <td>1.197300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>0.832900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>1.005700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>130</td>\n",
" <td>0.883100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>135</td>\n",
" <td>1.002800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>140</td>\n",
" <td>0.871200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>145</td>\n",
" <td>0.896200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>150</td>\n",
" <td>0.738400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>155</td>\n",
" <td>0.821300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>160</td>\n",
" <td>0.540100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>165</td>\n",
" <td>0.528400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>170</td>\n",
" <td>0.748300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>175</td>\n",
" <td>0.530400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>180</td>\n",
" <td>0.827400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>185</td>\n",
" <td>0.462000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>190</td>\n",
" <td>0.662100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>195</td>\n",
" <td>0.460600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.489000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>205</td>\n",
" <td>0.718200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>210</td>\n",
" <td>0.560800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>215</td>\n",
" <td>0.465400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>220</td>\n",
" <td>0.563800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>225</td>\n",
" <td>0.511300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>230</td>\n",
" <td>0.633600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>235</td>\n",
" <td>0.672000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>240</td>\n",
" <td>0.512300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>245</td>\n",
" <td>0.435900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>250</td>\n",
" <td>0.602000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>255</td>\n",
" <td>0.410400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>260</td>\n",
" <td>0.444500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>265</td>\n",
" <td>0.498200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>270</td>\n",
" <td>0.474100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>275</td>\n",
" <td>0.499900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>280</td>\n",
" <td>0.472900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>285</td>\n",
" <td>0.515300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>290</td>\n",
" <td>0.710300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>295</td>\n",
" <td>0.471400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>0.435600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>305</td>\n",
" <td>0.227500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>310</td>\n",
" <td>0.329300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>315</td>\n",
" <td>0.278500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>320</td>\n",
" <td>0.199700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>325</td>\n",
" <td>0.185300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>330</td>\n",
" <td>0.319400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>335</td>\n",
" <td>0.221700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>340</td>\n",
" <td>0.199100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>345</td>\n",
" <td>0.263700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>350</td>\n",
" <td>0.162200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>355</td>\n",
" <td>0.240800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>360</td>\n",
" <td>0.233900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>365</td>\n",
" <td>0.283800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>370</td>\n",
" <td>0.234300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>375</td>\n",
" <td>0.280000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>380</td>\n",
" <td>0.421600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>385</td>\n",
" <td>0.244700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>390</td>\n",
" <td>0.263500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>395</td>\n",
" <td>0.227000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>0.200800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>405</td>\n",
" <td>0.196800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>410</td>\n",
" <td>0.226100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>415</td>\n",
" <td>0.267700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>420</td>\n",
" <td>0.166400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>425</td>\n",
" <td>0.307700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>430</td>\n",
" <td>0.295600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>435</td>\n",
" <td>0.184200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>440</td>\n",
" <td>0.196100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>445</td>\n",
" <td>0.220000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>450</td>\n",
" <td>0.137200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>455</td>\n",
" <td>0.115100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>460</td>\n",
" <td>0.157100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>465</td>\n",
" <td>0.142500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>470</td>\n",
" <td>0.161300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>475</td>\n",
" <td>0.121700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>480</td>\n",
" <td>0.182200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>485</td>\n",
" <td>0.094400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>490</td>\n",
" <td>0.135800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>495</td>\n",
" <td>0.115100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>0.144500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>505</td>\n",
" <td>0.148100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>510</td>\n",
" <td>0.099800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>515</td>\n",
" <td>0.131800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>520</td>\n",
" <td>0.150100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>525</td>\n",
" <td>0.130300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>530</td>\n",
" <td>0.153200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>535</td>\n",
" <td>0.178000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>540</td>\n",
" <td>0.107700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>545</td>\n",
" <td>0.182500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>550</td>\n",
" <td>0.151500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>555</td>\n",
" <td>0.157400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>560</td>\n",
" <td>0.186700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>565</td>\n",
" <td>0.192900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>570</td>\n",
" <td>0.137400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>575</td>\n",
" <td>0.099100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>580</td>\n",
" <td>0.094500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>585</td>\n",
" <td>0.117100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>590</td>\n",
" <td>0.150600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>595</td>\n",
" <td>0.093600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>600</td>\n",
" <td>0.090400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>605</td>\n",
" <td>0.068400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>610</td>\n",
" <td>0.107400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>615</td>\n",
" <td>0.034200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>620</td>\n",
" <td>0.075000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>625</td>\n",
" <td>0.073900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>630</td>\n",
" <td>0.078400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>635</td>\n",
" <td>0.077900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>640</td>\n",
" <td>0.065600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>645</td>\n",
" <td>0.101700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>650</td>\n",
" <td>0.084900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>655</td>\n",
" <td>0.073000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>660</td>\n",
" <td>0.100800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>665</td>\n",
" <td>0.035700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>670</td>\n",
" <td>0.076300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>675</td>\n",
" <td>0.077500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>680</td>\n",
" <td>0.060200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>685</td>\n",
" <td>0.107900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>690</td>\n",
" <td>0.109300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>695</td>\n",
" <td>0.082700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>700</td>\n",
" <td>0.075900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>705</td>\n",
" <td>0.088300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>710</td>\n",
" <td>0.112000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>715</td>\n",
" <td>0.084100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>720</td>\n",
" <td>0.127700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>725</td>\n",
" <td>0.070700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>730</td>\n",
" <td>0.085400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>735</td>\n",
" <td>0.054400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>740</td>\n",
" <td>0.083300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>745</td>\n",
" <td>0.044600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>750</td>\n",
" <td>0.025700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>755</td>\n",
" <td>0.039400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>760</td>\n",
" <td>0.056600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>765</td>\n",
" <td>0.050800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>770</td>\n",
" <td>0.042500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>775</td>\n",
" <td>0.054000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>780</td>\n",
" <td>0.061200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>785</td>\n",
" <td>0.064100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>790</td>\n",
" <td>0.048600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>795</td>\n",
" <td>0.048600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>800</td>\n",
" <td>0.050300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>805</td>\n",
" <td>0.056400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>810</td>\n",
" <td>0.051000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>815</td>\n",
" <td>0.060900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>820</td>\n",
" <td>0.054600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>825</td>\n",
" <td>0.024800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>830</td>\n",
" <td>0.027800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>835</td>\n",
" <td>0.083900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>840</td>\n",
" <td>0.046700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>845</td>\n",
" <td>0.073400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>850</td>\n",
" <td>0.030800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>855</td>\n",
" <td>0.059400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>860</td>\n",
" <td>0.027300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>865</td>\n",
" <td>0.066000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>870</td>\n",
" <td>0.080000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>875</td>\n",
" <td>0.059600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>880</td>\n",
" <td>0.052600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>885</td>\n",
" <td>0.055900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>890</td>\n",
" <td>0.042300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>895</td>\n",
" <td>0.034500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>900</td>\n",
" <td>0.019600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>905</td>\n",
" <td>0.027800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>910</td>\n",
" <td>0.012100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>915</td>\n",
" <td>0.027300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>920</td>\n",
" <td>0.036900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>925</td>\n",
" <td>0.030100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>930</td>\n",
" <td>0.027900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>935</td>\n",
" <td>0.028700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>940</td>\n",
" <td>0.061500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>945</td>\n",
" <td>0.025500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>950</td>\n",
" <td>0.020100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>955</td>\n",
" <td>0.021700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>960</td>\n",
" <td>0.026800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>965</td>\n",
" <td>0.035700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>970</td>\n",
" <td>0.029600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>975</td>\n",
" <td>0.020600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>980</td>\n",
" <td>0.032200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>985</td>\n",
" <td>0.040200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>990</td>\n",
" <td>0.015200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>995</td>\n",
" <td>0.025100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>0.027800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1005</td>\n",
" <td>0.032900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1010</td>\n",
" <td>0.048000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1015</td>\n",
" <td>0.035200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1020</td>\n",
" <td>0.017700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1025</td>\n",
" <td>0.029700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1030</td>\n",
" <td>0.041900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1035</td>\n",
" <td>0.031100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1040</td>\n",
" <td>0.038100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1045</td>\n",
" <td>0.034800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1050</td>\n",
" <td>0.020000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1055</td>\n",
" <td>0.017500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1060</td>\n",
" <td>0.019300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1065</td>\n",
" <td>0.029900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1070</td>\n",
" <td>0.011500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1075</td>\n",
" <td>0.023900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1080</td>\n",
" <td>0.012700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1085</td>\n",
" <td>0.012800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1090</td>\n",
" <td>0.025500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1095</td>\n",
" <td>0.018500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1100</td>\n",
" <td>0.007200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1105</td>\n",
" <td>0.028500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1110</td>\n",
" <td>0.024600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1115</td>\n",
" <td>0.015600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1120</td>\n",
" <td>0.020200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1125</td>\n",
" <td>0.010500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1130</td>\n",
" <td>0.023900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1135</td>\n",
" <td>0.020300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1140</td>\n",
" <td>0.031900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1145</td>\n",
" <td>0.023100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1150</td>\n",
" <td>0.006700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1155</td>\n",
" <td>0.016700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1160</td>\n",
" <td>0.020200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1165</td>\n",
" <td>0.023900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1170</td>\n",
" <td>0.014300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1175</td>\n",
" <td>0.017000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1180</td>\n",
" <td>0.034100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1185</td>\n",
" <td>0.034900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1190</td>\n",
" <td>0.020800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1195</td>\n",
" <td>0.016200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1200</td>\n",
" <td>0.013200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1205</td>\n",
" <td>0.015100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1210</td>\n",
" <td>0.013000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1215</td>\n",
" <td>0.015700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1220</td>\n",
" <td>0.006100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1225</td>\n",
" <td>0.011600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1230</td>\n",
" <td>0.016800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1235</td>\n",
" <td>0.015200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1240</td>\n",
" <td>0.013600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1245</td>\n",
" <td>0.012000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1250</td>\n",
" <td>0.017800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1255</td>\n",
" <td>0.018500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1260</td>\n",
" <td>0.010800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1265</td>\n",
" <td>0.012700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1270</td>\n",
" <td>0.008500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1275</td>\n",
" <td>0.015700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1280</td>\n",
" <td>0.016000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1285</td>\n",
" <td>0.012100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1290</td>\n",
" <td>0.019400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1295</td>\n",
" <td>0.018100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1300</td>\n",
" <td>0.009400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1305</td>\n",
" <td>0.026600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1310</td>\n",
" <td>0.006500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1315</td>\n",
" <td>0.010900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1320</td>\n",
" <td>0.026600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1325</td>\n",
" <td>0.021100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1330</td>\n",
" <td>0.012100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1335</td>\n",
" <td>0.014700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1340</td>\n",
" <td>0.018100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1345</td>\n",
" <td>0.009500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1350</td>\n",
" <td>0.008000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1355</td>\n",
" <td>0.006400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1360</td>\n",
" <td>0.007300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1365</td>\n",
" <td>0.007600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1370</td>\n",
" <td>0.010000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1375</td>\n",
" <td>0.010200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1380</td>\n",
" <td>0.013900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1385</td>\n",
" <td>0.006400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1390</td>\n",
" <td>0.005500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1395</td>\n",
" <td>0.009000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1400</td>\n",
" <td>0.010400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1405</td>\n",
" <td>0.011200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1410</td>\n",
" <td>0.007200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1415</td>\n",
" <td>0.008300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1420</td>\n",
" <td>0.006000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1425</td>\n",
" <td>0.010000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1430</td>\n",
" <td>0.008900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1435</td>\n",
" <td>0.013300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1440</td>\n",
" <td>0.015000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1445</td>\n",
" <td>0.013900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1450</td>\n",
" <td>0.016000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1455</td>\n",
" <td>0.019200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1460</td>\n",
" <td>0.012300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1465</td>\n",
" <td>0.012800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1470</td>\n",
" <td>0.011900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1475</td>\n",
" <td>0.016400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1480</td>\n",
" <td>0.008100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1485</td>\n",
" <td>0.007700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1490</td>\n",
" <td>0.006300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1495</td>\n",
" <td>0.002100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>0.008000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#trainer_stats = trainer.train()\n",
"trainer_stats = trainer.train(resume_from_checkpoint = False)\n"
]
},
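{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Post-training stats: a minimal sketch reusing the torch.cuda counters from the\n",
"# setup cell and the runtime metric returned by trainer.train().\n",
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")"
]
},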
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"llamafactory支持多种类型的微调具体支持的类型可以通过配置文件的`finetuning_type`字段来指定。例如,在`examples/inference/llamafactory_lora_sft.yaml`配置文件中,`finetuning_type`指定为`lora`表示使用LoRA微调方法。<|im_end|>\n"
]
}
],
"source": [
"FastLanguageModel.for_inference(model)\n",
"\n",
"# 准备相同的测试输入\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": \n",
" \"\"\"llama factory支持什么类型的微调\"\"\"}\n",
"]\n",
"inputs = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize = True,\n",
" add_generation_prompt = True,\n",
" return_tensors = \"pt\",\n",
").to(\"cuda\")\n",
"\n",
"# 使用TextStreamer进行流式生成\n",
"from transformers import TextStreamer\n",
"text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n",
"_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,\n",
" use_cache = True, temperature = 1.5, min_p = 0.1)"
]
}
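{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the fine-tuned LoRA adapter and tokenizer (a sketch; the output path\n",
"# workdir/lora_model is a hypothetical choice, not referenced elsewhere):\n",
"model.save_pretrained(\"workdir/lora_model\")\n",
"tokenizer.save_pretrained(\"workdir/lora_model\")"
]
}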
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}