Compare commits
	
		
			2 Commits
		
	
	
		
			7907b96baa
			...
			mvp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | d35475d9e8 | ||
|   | f882b82e57 | 
							
								
								
									
										21
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | MIT License | ||||||
|  |  | ||||||
|  | Copyright (c) 2022 C-a-r-r-y | ||||||
|  |  | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  | of this software and associated documentation files (the "Software"), to deal | ||||||
|  | in the Software without restriction, including without limitation the rights | ||||||
|  | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  | copies of the Software, and to permit persons to whom the Software is | ||||||
|  | furnished to do so, subject to the following conditions: | ||||||
|  |  | ||||||
|  | The above copyright notice and this permission notice shall be included in all | ||||||
|  | copies or substantial portions of the Software. | ||||||
|  |  | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  | SOFTWARE. | ||||||
							
								
								
									
										15
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | # 基于文档驱动的自适应编码大模型微调框架 | ||||||
|  | ## 简介 | ||||||
|  | 本人的毕业设计,这个是mvp分支(MVP 是指最小可行产品Minimum Viable Product),其他功能在master分支中 | ||||||
|  | ### 项目概述 | ||||||
|  |  | ||||||
|  | * 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。 | ||||||
|  |  | ||||||
|  | ### 项目技术(预计) | ||||||
|  |  | ||||||
|  | * 使用unsloth框架在GPU上实现大语言模型的qlora微调 | ||||||
|  | * 使用langchain框架编写工作流实现批量生成微调语料 | ||||||
|  | * 使用tinydb和sqlite实现数据的持久化 | ||||||
|  | * 使用gradio框架实现前端展示 | ||||||
|  |  | ||||||
|  | **施工中......** | ||||||
							
								
								
									
										15
									
								
								config/llm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								config/llm.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | import os | ||||||
|  | from dotenv import load_dotenv | ||||||
|  | from typing import Dict, Any | ||||||
|  |  | ||||||
|  | def load_config() -> Dict[str, Any]: | ||||||
|  |     """从.env文件加载配置""" | ||||||
|  |     load_dotenv() | ||||||
|  |      | ||||||
|  |     return { | ||||||
|  |         "openai": { | ||||||
|  |             "api_key": os.getenv("OPENAI_API_KEY"), | ||||||
|  |             "base_url": os.getenv("OPENAI_BASE_URL"), | ||||||
|  |             "model_id": os.getenv("OPENAI_MODEL_ID") | ||||||
|  |         } | ||||||
|  |     } | ||||||
							
								
								
									
										94
									
								
								dataset_generator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								dataset_generator.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | |||||||
|  | import os | ||||||
|  | import json | ||||||
|  | from tools.parse_markdown import parse_markdown, MarkdownNode | ||||||
|  | from tools.openai_api import generate_json_via_llm | ||||||
|  | from prompt.base import create_dataset | ||||||
|  | from config.llm import load_config | ||||||
|  | from tqdm import tqdm | ||||||
|  |  | ||||||
|  | def process_markdown_file(file_path): | ||||||
|  |     with open(file_path, 'r', encoding='utf-8') as f: | ||||||
|  |         content = f.read() | ||||||
|  |      | ||||||
|  |     root = parse_markdown(content) | ||||||
|  |     results = [] | ||||||
|  |      | ||||||
|  |     def traverse(node, parent_titles): | ||||||
|  |         current_titles = parent_titles.copy() | ||||||
|  |         current_titles.append(node.title) | ||||||
|  |          | ||||||
|  |         if not node.children:  # 叶子节点 | ||||||
|  |             if node.content: | ||||||
|  |                 full_text = ' -> '.join(current_titles) + '\n' + node.content | ||||||
|  |                 results.append(full_text) | ||||||
|  |         else: | ||||||
|  |             for child in node.children: | ||||||
|  |                 traverse(child, current_titles) | ||||||
|  |      | ||||||
|  |     traverse(root, []) | ||||||
|  |     return results | ||||||
|  |  | ||||||
|  | def find_markdown_files(directory): | ||||||
|  |     markdown_files = [] | ||||||
|  |     for root, dirs, files in os.walk(directory): | ||||||
|  |         for file in files: | ||||||
|  |             if file.endswith('.md'): | ||||||
|  |                 markdown_files.append(os.path.join(root, file)) | ||||||
|  |     return markdown_files | ||||||
|  |  | ||||||
|  | def process_all_markdown(doc_dir): | ||||||
|  |     all_results = [] | ||||||
|  |      | ||||||
|  |     markdown_files = find_markdown_files(doc_dir) | ||||||
|  |     for file_path in markdown_files: | ||||||
|  |         results = process_markdown_file(file_path) | ||||||
|  |         all_results.extend(results) | ||||||
|  |      | ||||||
|  |     return all_results | ||||||
|  |  | ||||||
|  | def save_dataset(dataset, output_dir): | ||||||
|  |     os.makedirs(output_dir, exist_ok=True) | ||||||
|  |     output_path = os.path.join(output_dir, 'dataset.json') | ||||||
|  |     with open(output_path, 'w', encoding='utf-8') as f: | ||||||
|  |         json.dump(dataset, f, ensure_ascii=False, indent=2) | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # 解析markdown文档 | ||||||
|  |     results = process_all_markdown('workdir/my_docs') | ||||||
|  |      | ||||||
|  |     # 加载LLM配置 | ||||||
|  |     config = load_config() | ||||||
|  |      | ||||||
|  |     dataset = [] | ||||||
|  |     # 使用tqdm包装外部循环以显示进度条 | ||||||
|  |     for content in tqdm(results, desc="生成数据集进度", unit="文档"): | ||||||
|  |         for _ in range(3): | ||||||
|  |             prompt = create_dataset.create( | ||||||
|  |                 "LLaMA-Factory",  # 项目名 | ||||||
|  |                 content,  # 文档内容 | ||||||
|  |                 """{ | ||||||
|  |             "dataset":[ | ||||||
|  |                 { | ||||||
|  |                     "question":"", | ||||||
|  |                     "answer":"" | ||||||
|  |                 } | ||||||
|  |             ] | ||||||
|  |         }""" | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             # 调用LLM生成JSON | ||||||
|  |             try: | ||||||
|  |                 result = generate_json_via_llm( | ||||||
|  |                     prompt=prompt, | ||||||
|  |                     base_url=config["openai"]["base_url"], | ||||||
|  |                     api_key=config["openai"]["api_key"], | ||||||
|  |                     model_id=config["openai"]["model_id"] | ||||||
|  |                 ) | ||||||
|  |                 print(json.loads(result)["dataset"]) | ||||||
|  |                 dataset.extend(json.loads(result)["dataset"]) | ||||||
|  |             except Exception as e: | ||||||
|  |                 print(f"生成数据集时出错: {e}") | ||||||
|  |  | ||||||
|  |     # 保存数据集 | ||||||
|  |     save_dataset(dataset, 'workdir/dataset2') | ||||||
|  |     print(f"数据集已生成,共{len(dataset)}条数据") | ||||||
							
								
								
									
										25
									
								
								prompt/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								prompt/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | class create_dataset: | ||||||
|  |     """用于生成微调数据集模板的类""" | ||||||
|  |      | ||||||
|  |     template = """ | ||||||
|  | 项目名为:{} | ||||||
|  | 请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文 | ||||||
|  | 文档节选:{} | ||||||
|  | 按照如下json格式返回:{} | ||||||
|  | """ | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def create(*args: any) -> str: | ||||||
|  |         """根据提供的任意数量参数生成数据集模板 | ||||||
|  |          | ||||||
|  |         Args: | ||||||
|  |             *args: 任意数量的参数,将按顺序填充到模板中 | ||||||
|  |              | ||||||
|  |         Returns: | ||||||
|  |             格式化后的模板字符串 | ||||||
|  |         """ | ||||||
|  |         return create_dataset.template.format(*args) | ||||||
|  |      | ||||||
|  |  | ||||||
|  | if __name__=="__main__": | ||||||
|  |     print(create_dataset.create("a", "b", "c")) | ||||||
							
								
								
									
										9
									
								
								schema/dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								schema/dataset.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | from pydantic import BaseModel, RootModel | ||||||
|  | from typing import List | ||||||
|  |  | ||||||
|  | class QAPair(BaseModel): | ||||||
|  |     question: str | ||||||
|  |     response: str | ||||||
|  |  | ||||||
|  | class QAArray(RootModel): | ||||||
|  |     root: List[QAPair] | ||||||
							
								
								
									
										69
									
								
								tools/openai_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								tools/openai_api.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | import json | ||||||
|  | from openai import OpenAI | ||||||
|  |  | ||||||
|  | def generate_json_via_llm( | ||||||
|  |     prompt: str, | ||||||
|  |     base_url: str, | ||||||
|  |     api_key: str, | ||||||
|  |     model_id: str | ||||||
|  | ) -> str: | ||||||
|  |     client = OpenAI( | ||||||
|  |         api_key=api_key, | ||||||
|  |         base_url=base_url | ||||||
|  |     ) | ||||||
|  |      | ||||||
|  |     try: | ||||||
|  |  | ||||||
|  |          | ||||||
|  |         response = client.chat.completions.create( | ||||||
|  |             model=model_id, | ||||||
|  |             messages=[ | ||||||
|  |                 { | ||||||
|  |                     "role": "user", | ||||||
|  |                     "content": prompt | ||||||
|  |                 } | ||||||
|  |             ], | ||||||
|  |             response_format={ | ||||||
|  |                 'type': 'json_object' | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |          | ||||||
|  |         return response.choices[0].message.content | ||||||
|  |     except Exception as e: | ||||||
|  |         raise RuntimeError(f"API请求失败: {e}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     import sys | ||||||
|  |     import os | ||||||
|  |     sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | ||||||
|  |     from config.llm import load_config | ||||||
|  |     # 将项目根目录添加到 sys.path 中 | ||||||
|  |  | ||||||
|  |     # 示例用法 | ||||||
|  |     try: | ||||||
|  |         config = load_config() | ||||||
|  |         print(config) | ||||||
|  |         result = generate_json_via_llm( | ||||||
|  |             prompt="""测试,随便生成点什么,返回json格式的字符串,格式如下 | ||||||
|  | { | ||||||
|  |     "dataset":[ | ||||||
|  |         { | ||||||
|  |             "question":"", | ||||||
|  |             "answer":"" | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |             "question":"", | ||||||
|  |             "answer":"" | ||||||
|  |         } | ||||||
|  |         ...... | ||||||
|  |     ] | ||||||
|  | } | ||||||
|  |             """, | ||||||
|  |             base_url=config["openai"]["base_url"], | ||||||
|  |             api_key=config["openai"]["api_key"], | ||||||
|  |             model_id=config["openai"]["model_id"], | ||||||
|  |         ) | ||||||
|  |         print(result) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"错误: {e}") | ||||||
							
								
								
									
										1534
									
								
								train.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1534
									
								
								train.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										70
									
								
								trainer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								trainer.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | |||||||
|  | from unsloth import FastLanguageModel | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | # 基础配置参数 | ||||||
|  | max_seq_length = 4096 # 最大序列长度 | ||||||
|  | dtype = None # 自动检测数据类型 | ||||||
|  | load_in_4bit = True # 使用4位量化以减少内存使用 | ||||||
|  |  | ||||||
|  | # 加载预训练模型和分词器 | ||||||
|  | model, tokenizer = FastLanguageModel.from_pretrained( | ||||||
|  |     model_name = "workdir\model\Qwen2.5-3B-Instruct-bnb-4bit", # 选择Qwen2.5 32B指令模型 | ||||||
|  |     max_seq_length = max_seq_length, | ||||||
|  |     dtype = dtype, | ||||||
|  |     load_in_4bit = load_in_4bit, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | model = FastLanguageModel.get_peft_model( | ||||||
|  |     model, | ||||||
|  |     r = 64, # LoRA秩,控制可训练参数数量 | ||||||
|  |     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", | ||||||
|  |                       "gate_proj", "up_proj", "down_proj",], # 需要训练的目标模块 | ||||||
|  |     lora_alpha = 64, # LoRA缩放因子 | ||||||
|  |     lora_dropout = 0, # LoRA dropout率 | ||||||
|  |     bias = "none", # 是否训练偏置项 | ||||||
|  |     use_gradient_checkpointing = "unsloth", # 使用梯度检查点节省显存 | ||||||
|  |     random_state = 114514, # 随机数种子 | ||||||
|  |     use_rslora = False, # 是否使用稳定版LoRA | ||||||
|  |     loftq_config = None, # LoftQ配置 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | from unsloth.chat_templates import get_chat_template | ||||||
|  | # 配置分词器使用qwen-2.5对话模板 | ||||||
|  | tokenizer = get_chat_template( | ||||||
|  |     tokenizer, | ||||||
|  |     chat_template="qwen-2.5", | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | def formatting_prompts_func(examples): | ||||||
|  |     """格式化对话数据的函数 | ||||||
|  |     Args: | ||||||
|  |         examples: 包含对话列表的字典 | ||||||
|  |     Returns: | ||||||
|  |         包含格式化文本的字典 | ||||||
|  |     """ | ||||||
|  |     questions = examples["question"] | ||||||
|  |     answer = examples["answer"] | ||||||
|  |      | ||||||
|  |     # 将Question和Response组合成对话形式 | ||||||
|  |     convos = [ | ||||||
|  |         [{"role": "user", "content": q}, {"role": "assistant", "content": r}] | ||||||
|  |         for q, r in zip(questions, answer) | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     # 使用tokenizer.apply_chat_template格式化对话 | ||||||
|  |     texts = [ | ||||||
|  |         tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) | ||||||
|  |         for convo in convos | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     return {"text": texts} | ||||||
|  |  | ||||||
|  | from unsloth.chat_templates import standardize_sharegpt | ||||||
|  |  | ||||||
|  | # 加载数据集 | ||||||
|  | from datasets import load_dataset | ||||||
|  | dataset = load_dataset("json", data_files="workdir\dataset\dataset.json") | ||||||
|  | dataset = dataset.map(formatting_prompts_func, batched = True) | ||||||
|  |  | ||||||
|  | print(dataset[5]) | ||||||
|  | print(dataset[5]["text"]) | ||||||
		Reference in New Issue
	
	Block a user