Compare commits
2 Commits
Author | SHA1 | Date
---|---|---
 | d35475d9e8 |
 | f882b82e57 |
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 C-a-r-r-y

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
15
README.md
Normal file
@@ -0,0 +1,15 @@
# A Document-Driven Adaptive Fine-Tuning Framework for Coding Large Language Models

## Introduction

This is my graduation project. This is the mvp branch (MVP: Minimum Viable Product); the other features live on the master branch.

### Project Overview

* Deeply parse a private library's documentation and other resources to generate an instruction-style corpus, then use it to fine-tune a large language model specifically for that private library.

### Planned Tech Stack

* Use the unsloth framework to run QLoRA fine-tuning of large language models on GPU
* Use the langchain framework to build workflows that batch-generate the fine-tuning corpus
* Use tinydb and sqlite for data persistence
* Use the gradio framework for the front-end demo

**Under construction...**
15
config/llm.py
Normal file
@@ -0,0 +1,15 @@
import os
from dotenv import load_dotenv
from typing import Dict, Any

def load_config() -> Dict[str, Any]:
    """Load configuration from the .env file."""
    load_dotenv()

    return {
        "openai": {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "base_url": os.getenv("OPENAI_BASE_URL"),
            "model_id": os.getenv("OPENAI_MODEL_ID")
        }
    }
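`load_config()` just reads three variables through python-dotenv, so the repo expects a `.env` file at the project root. A minimal sketch with placeholder values (nothing here comes from the repo itself):

```
OPENAI_API_KEY=sk-your-key-here
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL_ID=gpt-4o-mini
```

Any OpenAI-compatible endpoint should work, since `base_url` is passed straight to the `OpenAI` client in `tools/openai_api.py`.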
94
dataset_generator.py
Normal file
@@ -0,0 +1,94 @@
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm


def process_markdown_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    root = parse_markdown(content)
    results = []

    def traverse(node, parent_titles):
        current_titles = parent_titles.copy()
        current_titles.append(node.title)

        if not node.children:  # leaf node
            if node.content:
                full_text = ' -> '.join(current_titles) + '\n' + node.content
                results.append(full_text)
        else:
            for child in node.children:
                traverse(child, current_titles)

    traverse(root, [])
    return results


def find_markdown_files(directory):
    markdown_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.md'):
                markdown_files.append(os.path.join(root, file))
    return markdown_files


def process_all_markdown(doc_dir):
    all_results = []

    markdown_files = find_markdown_files(doc_dir)
    for file_path in markdown_files:
        results = process_markdown_file(file_path)
        all_results.extend(results)

    return all_results


def save_dataset(dataset, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'dataset.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Parse the markdown documentation
    results = process_all_markdown('workdir/my_docs')

    # Load the LLM configuration
    config = load_config()

    dataset = []
    # Wrap the outer loop with tqdm to show a progress bar
    for content in tqdm(results, desc="Generating dataset", unit="doc"):
        for _ in range(3):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,  # document content
                """{
                "dataset":[
                    {
                        "question":"",
                        "answer":""
                    }
                ]
                }"""
            )

            # Call the LLM to generate JSON
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"]
                )
                qa_pairs = json.loads(result)["dataset"]  # parse once, reuse below
                print(qa_pairs)
                dataset.extend(qa_pairs)
            except Exception as e:
                print(f"Error while generating the dataset: {e}")

    # Save the dataset
    save_dataset(dataset, 'workdir/dataset2')
    print(f"Dataset generated: {len(dataset)} entries in total")
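For orientation, `workdir/dataset2/dataset.json` ends up as a flat list of the accumulated question/answer objects, in the shape the prompt requests. A sample of what one saved file might contain (content illustrative, not from an actual run):

```json
[
  {
    "question": "What is LLaMA-Factory used for?",
    "answer": "..."
  },
  {
    "question": "How do I launch its web UI?",
    "answer": "..."
  }
]
```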
25
prompt/base.py
Normal file
@@ -0,0 +1,25 @@
from typing import Any


class create_dataset:
    """Builds the prompt template used to generate the fine-tuning dataset."""

    template = """
    The project name is: {}
    Based on the following excerpt from this project's official documentation, create a suitable conversational dataset to serve as corpus for fine-tuning a small model that understands the project. Cover as many of the document's information points as possible, and write in Chinese.
    Document excerpt: {}
    Return the result in the following JSON format: {}
    """

    @staticmethod
    def create(*args: Any) -> str:
        """Fill the dataset template with any number of arguments.

        Args:
            *args: Any number of arguments, substituted into the template in order.

        Returns:
            The formatted template string.
        """
        return create_dataset.template.format(*args)


if __name__ == "__main__":
    print(create_dataset.create("a", "b", "c"))
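Running the module directly fills the three placeholders with `"a"`, `"b"`, and `"c"`, which makes the expected argument order (project name, document excerpt, JSON format) easy to verify:

```
The project name is: a
Based on the following excerpt from this project's official documentation, ...
Document excerpt: b
Return the result in the following JSON format: c
```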
9
schema/dataset.py
Normal file
@@ -0,0 +1,9 @@
from pydantic import BaseModel, RootModel
from typing import List


class QAPair(BaseModel):
    question: str
    response: str


class QAArray(RootModel):
    root: List[QAPair]
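These models are not referenced by `dataset_generator.py` in this diff, but they could validate the LLM's output before it is appended to the dataset. A minimal sketch assuming pydantic v2 (note the field-name mismatch: `QAPair` expects `response`, while the generation prompt asks for `answer`, so one side would need aligning before wiring this in):

```python
from pydantic import ValidationError
from schema.dataset import QAArray

raw = [{"question": "What is QLoRA?", "response": "LoRA fine-tuning on top of a 4-bit quantized base model."}]
try:
    qa_list = QAArray.model_validate(raw)  # raises ValidationError on a bad shape
    print(qa_list.root[0].question)
except ValidationError as e:
    print(f"Invalid dataset entry: {e}")
```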
69
tools/openai_api.py
Normal file
@@ -0,0 +1,69 @@
import json
from openai import OpenAI


def generate_json_via_llm(
    prompt: str,
    base_url: str,
    api_key: str,
    model_id: str
) -> str:
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            response_format={
                'type': 'json_object'
            }
        )

        return response.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"API request failed: {e}")


if __name__ == "__main__":
    import sys
    import os

    # Add the project root to sys.path so config can be imported
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    from config.llm import load_config

    # Example usage
    try:
        config = load_config()
        print(config)
        result = generate_json_via_llm(
            prompt="""Test: generate something arbitrary and return it as a JSON-formatted string, in the following format
            {
                "dataset":[
                    {
                        "question":"",
                        "answer":""
                    },
                    {
                        "question":"",
                        "answer":""
                    }
                    ......
                ]
            }
            """,
            base_url=config["openai"]["base_url"],
            api_key=config["openai"]["api_key"],
            model_id=config["openai"]["model_id"],
        )
        print(result)
    except Exception as e:
        print(f"Error: {e}")
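A caveat on the design: `response_format={'type': 'json_object'}` depends on the endpoint actually implementing OpenAI's JSON mode, and OpenAI's API additionally expects the word "JSON" to appear in the messages when it is enabled (the prompts in this repo do describe a JSON format). Even then the reply is only guaranteed to be syntactically valid JSON, not to match the requested schema, so callers such as `dataset_generator.py` still need to parse and ideally validate it.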
1534
train.ipynb
Normal file
File diff suppressed because it is too large
70
trainer.py
Normal file
@@ -0,0 +1,70 @@
from unsloth import FastLanguageModel
import torch

# Basic configuration parameters
max_seq_length = 4096  # maximum sequence length
dtype = None  # auto-detect the data type
load_in_4bit = True  # use 4-bit quantization to reduce memory usage

# Load the pretrained model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "workdir/model/Qwen2.5-3B-Instruct-bnb-4bit",  # Qwen2.5 3B instruct model
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 64,  # LoRA rank, controls the number of trainable parameters
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],  # target modules to train
    lora_alpha = 64,  # LoRA scaling factor
    lora_dropout = 0,  # LoRA dropout rate
    bias = "none",  # whether to train bias terms
    use_gradient_checkpointing = "unsloth",  # gradient checkpointing to save VRAM
    random_state = 114514,  # random seed
    use_rslora = False,  # whether to use rank-stabilized LoRA
    loftq_config = None,  # LoftQ configuration
)

from unsloth.chat_templates import get_chat_template

# Configure the tokenizer with the qwen-2.5 chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen-2.5",
)

def formatting_prompts_func(examples):
    """Format conversation data.

    Args:
        examples: a dict holding lists of questions and answers

    Returns:
        a dict holding the formatted text
    """
    questions = examples["question"]
    answers = examples["answer"]

    # Combine each question and answer into a conversation
    convos = [
        [{"role": "user", "content": q}, {"role": "assistant", "content": r}]
        for q, r in zip(questions, answers)
    ]

    # Render the conversations with tokenizer.apply_chat_template
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]

    return {"text": texts}

from unsloth.chat_templates import standardize_sharegpt

# Load the dataset; take the "train" split so integer indexing below works
from datasets import load_dataset
dataset = load_dataset("json", data_files="workdir/dataset/dataset.json", split="train")
dataset = dataset.map(formatting_prompts_func, batched = True)

print(dataset[5])
print(dataset[5]["text"])
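To make the formatting step concrete: with the qwen-2.5 template, each entry's `"text"` field wraps the turns in `<|im_start|>`/`<|im_end|>` markers, roughly as sketched below (Qwen's template may also prepend a default system message; the content is illustrative):

```
<|im_start|>user
How do I enable gradient checkpointing in unsloth?<|im_end|>
<|im_start|>assistant
Pass use_gradient_checkpointing="unsloth" to FastLanguageModel.get_peft_model(...).<|im_end|>
```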