Compare commits

8 Commits
mvp ... f93f213a31

Commits in this range (SHA1):

f93f213a31
10b4c29bda
b1e98ca913
2d5a5277ae
519a5f3773
1f4d491694
8ce4f1e373
3395b860e4
1  .gitignore (vendored)

@@ -11,6 +11,7 @@ env/
 # IDE
 .vscode/
 .idea/
+.roo

 # Environment files
 .env
21  LICENSE (deleted)

@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2022 C-a-r-r-y
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
15  README.md (deleted)

@@ -1,15 +0,0 @@
-# A Document-Driven Adaptive Fine-Tuning Framework for Coding LLMs
-## Introduction
-My graduation project. This is the mvp branch (MVP = Minimum Viable Product); the other features live on the master branch.
-### Project Overview
-
-* Deeply parse a private library's documentation and other resources to generate instruction-style corpora, then fine-tune a large language model on that private library.
-
-### Planned Tech Stack
-
-* QLoRA fine-tuning of large language models on GPU with the unsloth framework
-* langchain workflows for batch generation of the fine-tuning corpus
-* Data persistence with tinydb and sqlite
-* Frontend built with the gradio framework
-
-**Work in progress......**
config/llm.py (deleted)

@@ -1,15 +0,0 @@
-import os
-from dotenv import load_dotenv
-from typing import Dict, Any
-
-def load_config() -> Dict[str, Any]:
-    """Load configuration from the .env file"""
-    load_dotenv()
-
-    return {
-        "openai": {
-            "api_key": os.getenv("OPENAI_API_KEY"),
-            "base_url": os.getenv("OPENAI_BASE_URL"),
-            "model_id": os.getenv("OPENAI_MODEL_ID")
-        }
-    }
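For orientation, a minimal usage sketch (not part of the diff) of how the deleted load_config() was consumed elsewhere on this branch; the OPENAI_* variable names come from the os.getenv calls above and would live in a .env file:

# Hypothetical sketch; assumes a .env defining OPENAI_API_KEY, OPENAI_BASE_URL and OPENAI_MODEL_ID.
from config.llm import load_config

config = load_config()
print(config["openai"]["base_url"], config["openai"]["model_id"])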
@@ -1,94 +0,0 @@
-import os
-import json
-from tools.parse_markdown import parse_markdown, MarkdownNode
-from tools.openai_api import generate_json_via_llm
-from prompt.base import create_dataset
-from config.llm import load_config
-from tqdm import tqdm
-
-def process_markdown_file(file_path):
-    with open(file_path, 'r', encoding='utf-8') as f:
-        content = f.read()
-
-    root = parse_markdown(content)
-    results = []
-
-    def traverse(node, parent_titles):
-        current_titles = parent_titles.copy()
-        current_titles.append(node.title)
-
-        if not node.children:  # leaf node
-            if node.content:
-                full_text = ' -> '.join(current_titles) + '\n' + node.content
-                results.append(full_text)
-        else:
-            for child in node.children:
-                traverse(child, current_titles)
-
-    traverse(root, [])
-    return results
-
-def find_markdown_files(directory):
-    markdown_files = []
-    for root, dirs, files in os.walk(directory):
-        for file in files:
-            if file.endswith('.md'):
-                markdown_files.append(os.path.join(root, file))
-    return markdown_files
-
-def process_all_markdown(doc_dir):
-    all_results = []
-
-    markdown_files = find_markdown_files(doc_dir)
-    for file_path in markdown_files:
-        results = process_markdown_file(file_path)
-        all_results.extend(results)
-
-    return all_results
-
-def save_dataset(dataset, output_dir):
-    os.makedirs(output_dir, exist_ok=True)
-    output_path = os.path.join(output_dir, 'dataset.json')
-    with open(output_path, 'w', encoding='utf-8') as f:
-        json.dump(dataset, f, ensure_ascii=False, indent=2)
-
-if __name__ == "__main__":
-    # Parse the markdown docs
-    results = process_all_markdown('workdir/my_docs')
-
-    # Load the LLM config
-    config = load_config()
-
-    dataset = []
-    # Wrap the outer loop in tqdm to show a progress bar
-    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
-        for _ in range(3):
-            prompt = create_dataset.create(
-                "LLaMA-Factory",  # project name
-                content,  # document content
-                """{
-                    "dataset":[
-                        {
-                            "question":"",
-                            "answer":""
-                        }
-                    ]
-                }"""
-            )
-
-            # Call the LLM to generate JSON
-            try:
-                result = generate_json_via_llm(
-                    prompt=prompt,
-                    base_url=config["openai"]["base_url"],
-                    api_key=config["openai"]["api_key"],
-                    model_id=config["openai"]["model_id"]
-                )
-                print(json.loads(result)["dataset"])
-                dataset.extend(json.loads(result)["dataset"])
-            except Exception as e:
-                print(f"生成数据集时出错: {e}")
-
-    # Save the dataset
-    save_dataset(dataset, 'workdir/dataset2')
-    print(f"数据集已生成,共{len(dataset)}条数据")
3  db/__init__.py (new file)

@@ -0,0 +1,3 @@
+from .init_db import get_engine, initialize_db
+
+__all__ = ['get_engine', 'initialize_db']
79  db/init_db.py (new file)

@@ -0,0 +1,79 @@
+from sqlmodel import SQLModel, create_engine, Session
+from sqlmodel import select
+from typing import Optional
+import os
+from pathlib import Path
+import sys
+from dotenv import load_dotenv
+from sqlalchemy.engine import Engine
+
+# Add the project root to sys.path so other project modules can be imported
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema.dataset_generation import APIProvider
+
+# Global variable holding the database engine instance
+_engine: Optional[Engine] = None
+
+def get_engine(workdir: str) -> Engine:
+    """
+    Get the database engine instance. If the engine has not been created yet, create a new one and return it.
+
+    Args:
+        workdir (str): Working directory path, used to determine where the database file is stored.
+
+    Returns:
+        Engine: SQLAlchemy database engine instance.
+    """
+    global _engine
+    if not _engine:
+        # Create the database directory (if it does not exist)
+        db_dir = os.path.join(workdir, "db")
+        os.makedirs(db_dir, exist_ok=True)
+        # Define the database file path
+        db_path = os.path.join(db_dir, "db.sqlite")
+        # Build the database URL
+        db_url = f"sqlite:///{db_path}"
+        # Create the database engine
+        _engine = create_engine(db_url)
+    return _engine
+
+def initialize_db(engine: Engine) -> None:
+    """
+    Initialize the database: create all tables and insert the initial data (if not present).
+
+    Args:
+        engine (Engine): SQLAlchemy database engine instance.
+    """
+    # Create all tables
+    SQLModel.metadata.create_all(engine)
+    # Load environment variables
+    load_dotenv()
+    # Read the API-related configuration from environment variables
+    api_key = os.getenv("API_KEY")
+    base_url = os.getenv("BASE_URL")
+    model_id = os.getenv("MODEL_ID")
+
+    # Insert the initial data only if all required environment variables are present
+    if api_key and base_url and model_id:
+        with Session(engine) as session:
+            # Check whether an APIProvider record already exists
+            statement = select(APIProvider).limit(1)
+            existing_provider = session.exec(statement).first()
+
+            # If not, insert a new APIProvider record
+            if not existing_provider:
+                provider = APIProvider(
+                    base_url=base_url,
+                    model_id=model_id,
+                    api_key=api_key
+                )
+                session.add(provider)
+                session.commit()
+
+if __name__ == "__main__":
+    # Define the working directory path
+    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
+    # Get the database engine
+    engine = get_engine(workdir)
+    # Initialize the database
+    initialize_db(engine)
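For orientation, a minimal sketch (not part of the diff) of how other modules could use these helpers, assuming only the exports shown in db/__init__.py and the APIProvider model defined later in schema/dataset_generation.py:

# Hypothetical sketch: obtain the engine, ensure the tables exist, read back the seeded provider.
from sqlmodel import Session, select
from db import get_engine, initialize_db
from schema.dataset_generation import APIProvider

engine = get_engine("workdir")
initialize_db(engine)
with Session(engine) as session:
    provider = session.exec(select(APIProvider)).first()
    print(provider)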
9  frontend/chat_page.py (new file)

@@ -0,0 +1,9 @@
+import gradio as gr
+
+def chat_page():
+    with gr.Blocks() as demo:
+        gr.Markdown("## 聊天")
+        with gr.Row():
+            with gr.Column():
+                pass
+    return demo
9  frontend/setting_page.py (new file)

@@ -0,0 +1,9 @@
+import gradio as gr
+
+def setting_page():
+    with gr.Blocks() as demo:
+        gr.Markdown("## 设置")
+        with gr.Row():
+            with gr.Column():
+                pass
+    return demo
9  frontend/train_page.py (new file)

@@ -0,0 +1,9 @@
+import gradio as gr
+
+def train_page():
+    with gr.Blocks() as demo:
+        gr.Markdown("## 微调")
+        with gr.Row():
+            with gr.Column():
+                pass
+    return demo
23  main.py (new file)

@@ -0,0 +1,23 @@
+import gradio as gr
+from frontend.setting_page import setting_page
+from frontend.chat_page import chat_page
+from frontend.train_page import train_page
+def main():
+    setting_demo = setting_page()
+    chat_demo = chat_page()
+    train_demo = train_page()
+
+    with gr.Blocks() as app:
+        gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
+        with gr.Tabs():
+            with gr.TabItem("微调"):
+                train_demo.render()
+            with gr.TabItem("聊天"):
+                chat_demo.render()
+            with gr.TabItem("设置"):
+                setting_demo.render()
+
+    app.launch()
+
+if __name__ == "__main__":
+    main()
prompt/base.py (deleted)

@@ -1,25 +0,0 @@
-class create_dataset:
-    """Class for building the fine-tuning dataset prompt template"""
-
-    template = """
-    项目名为:{}
-    请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
-    文档节选:{}
-    按照如下json格式返回:{}
-    """
-
-    @staticmethod
-    def create(*args: any) -> str:
-        """Build the dataset prompt from any number of arguments
-
-        Args:
-            *args: Any number of arguments, filled into the template in order
-
-        Returns:
-            The formatted template string
-        """
-        return create_dataset.template.format(*args)
-
-
-if __name__=="__main__":
-    print(create_dataset.create("a", "b", "c"))
@@ -1,2 +1,4 @@
 openai>=1.0.0
 python-dotenv>=1.0.0
+pydantic>=2.0.0
+gradio>=3.0.0
4  schema/__init__.py (new file)

@@ -0,0 +1,4 @@
+from .dataset import *
+from .dataset_generation import APIProvider, LLMResponse, LLMRequest
+from .md_doc import MarkdownNode
+from .prompt import promptTempleta
@@ -1,9 +0,0 @@
-from pydantic import BaseModel, RootModel
-from typing import List
-
-class QAPair(BaseModel):
-    question: str
-    response: str
-
-class QAArray(RootModel):
-    root: List[QAPair]
51  schema/dataset_generation.py (new file)

@@ -0,0 +1,51 @@
+from datetime import datetime, timezone
+from typing import Optional
+from sqlmodel import SQLModel, Relationship, Field
+
+class APIProvider(SQLModel, table=True):
+    id: Optional[int] = Field(default=None, primary_key=True)
+    base_url: str = Field(..., description="API的基础URL")
+    model_id: str = Field(..., description="API使用的模型ID")
+    api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
+    created_at: datetime = Field(
+        default_factory=lambda: datetime.now(timezone.utc),
+        description="记录创建时间"
+    )
+
+class LLMResponse(SQLModel):
+    timestamp: datetime = Field(
+        default_factory=lambda: datetime.now(timezone.utc),
+        description="响应的时间戳"
+    )
+    response_id: str = Field(..., description="响应的唯一ID")
+    tokens_usage: dict = Field(default_factory=lambda: {
+        "prompt_tokens": 0,
+        "completion_tokens": 0,
+        "prompt_cache_hit_tokens": None,
+        "prompt_cache_miss_tokens": None
+    }, description="token使用信息")
+    response_content: dict = Field(default_factory=dict, description="API响应的内容")
+    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
+    llm_parameters: dict = Field(default_factory=lambda: {
+        "temperature": None,
+        "max_tokens": None,
+        "top_p": None,
+        "frequency_penalty": None,
+        "presence_penalty": None,
+        "seed": None
+    }, description="API的生成参数")
+
+class LLMRequest(SQLModel):
+    prompt: str = Field(..., description="发送给API的提示词")
+    provider_id: int = Field(foreign_key="apiprovider.id")
+    provider: APIProvider = Relationship()
+    format: Optional[str] = Field(default=None, description="API响应的格式")
+    response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
+    error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
+    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
+    total_tokens_usage: dict = Field(default_factory=lambda: {
+        "prompt_tokens": 0,
+        "completion_tokens": 0,
+        "prompt_cache_hit_tokens": None,
+        "prompt_cache_miss_tokens": None
+    }, description="token使用信息")
13  schema/md_doc.py (new file)

@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from typing import List, Optional
+
+class MarkdownNode(BaseModel):
+    level: int = Field(default=0, description="节点层级")
+    title: str = Field(default="Root", description="节点标题")
+    content: Optional[str] = Field(default=None, description="节点内容")
+    children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+MarkdownNode.model_rebuild()
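For orientation, a minimal sketch (not part of the diff) of how this pydantic model nests; the sample titles and content are placeholders:

# Hypothetical sketch: build a two-level MarkdownNode tree and dump it (pydantic v2).
from schema.md_doc import MarkdownNode

child = MarkdownNode(level=2, title="Install", content="pip install ...")
root = MarkdownNode(level=1, title="Guide", children=[child])
print(root.model_dump())  # nested dict with level, title, content, children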
tools/openai_api.py (deleted)

@@ -1,69 +0,0 @@
-import json
-from openai import OpenAI
-
-def generate_json_via_llm(
-    prompt: str,
-    base_url: str,
-    api_key: str,
-    model_id: str
-) -> str:
-    client = OpenAI(
-        api_key=api_key,
-        base_url=base_url
-    )
-
-    try:
-
-
-        response = client.chat.completions.create(
-            model=model_id,
-            messages=[
-                {
-                    "role": "user",
-                    "content": prompt
-                }
-            ],
-            response_format={
-                'type': 'json_object'
-            }
-        )
-
-        return response.choices[0].message.content
-    except Exception as e:
-        raise RuntimeError(f"API请求失败: {e}")
-
-
-if __name__ == "__main__":
-    import sys
-    import os
-    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-    from config.llm import load_config
-    # Add the project root to sys.path
-
-    # Example usage
-    try:
-        config = load_config()
-        print(config)
-        result = generate_json_via_llm(
-            prompt="""测试,随便生成点什么,返回json格式的字符串,格式如下
-            {
-                "dataset":[
-                    {
-                        "question":"",
-                        "answer":""
-                    },
-                    {
-                        "question":"",
-                        "answer":""
-                    }
-                    ......
-                ]
-            }
-            """,
-            base_url=config["openai"]["base_url"],
-            api_key=config["openai"]["api_key"],
-            model_id=config["openai"]["model_id"],
-        )
-        print(result)
-    except Exception as e:
-        print(f"错误: {e}")
tools/parse_markdown.py

@@ -1,28 +1,24 @@
 import re
+import sys
+from pathlib import Path

-class MarkdownNode:
-    def __init__(self, level=0, title="Root"):
-        self.level = level
-        self.title = title
-        self.content = ""  # store the merged content as a string
-        self.children = []
-
-    def __repr__(self):
-        return f"({self.level}) {self.title}"
-
-    def add_child(self, child):
-        self.children.append(child)
-
-    def print_tree(self, indent=0):
-        prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
-        print(f"{prefix}{self.title}")
-        if self.content:
-            content_prefix = "│ " * indent + "├─ [内容]"
-            print(content_prefix)
-            for line in self.content.split('\n'):
-                print("│ " * indent + "│ " + line)
-        for child in self.children:
-            child.print_tree(indent + 1)
+# Add the project root to sys.path
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema import MarkdownNode
+
+def add_child(parent, child):
+    parent.children.append(child)
+
+def print_tree(node, indent=0):
+    prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
+    print(f"{prefix}{node.title}")
+    if node.content:
+        content_prefix = "│ " * indent + "├─ [内容]"
+        print(content_prefix)
+        for line in node.content.split('\n'):
+            print("│ " * indent + "│ " + line)
+    for child in node.children:
+        print_tree(child, indent + 1)

 def parse_markdown(markdown):
     lines = markdown.split('\n')

@@ -51,10 +47,10 @@ def parse_markdown(markdown):
         if match:
             level = len(match.group(1))
             title = match.group(2)
-            node = MarkdownNode(level, title)
+            node = MarkdownNode(level=level, title=title, content="", children=[])
             while stack[-1].level >= level:
                 stack.pop()
-            stack[-1].add_child(node)
+            add_child(stack[-1], node)
             stack.append(node)
         else:
             if stack[-1].content:

@@ -65,9 +61,9 @@

 if __name__=="__main__":
     # Read the Markdown content from a file
-    with open("example.md", "r", encoding="utf-8") as f:
+    with open("workdir/example.md", "r", encoding="utf-8") as f:
         markdown = f.read()

     # Parse the Markdown and print the tree structure
     root = parse_markdown(markdown)
-    root.print_tree()
+    print_tree(root)
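A quick way to exercise the rewritten parser without workdir/example.md; a hedged sketch, not part of the diff, using only parse_markdown and print_tree as defined above:

# Hypothetical sketch: parse a small in-memory markdown string and print its tree.
from tools.parse_markdown import parse_markdown, print_tree

sample = "# Guide\nintro text\n## Install\nrun the installer\n"
print_tree(parse_markdown(sample))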
1534  train.ipynb

File diff suppressed because it is too large.
70  trainer.py (deleted)

@@ -1,70 +0,0 @@
-from unsloth import FastLanguageModel
-import torch
-
-# Basic configuration parameters
-max_seq_length = 4096  # maximum sequence length
-dtype = None  # auto-detect the data type
-load_in_4bit = True  # use 4-bit quantization to reduce memory usage
-
-# Load the pretrained model and tokenizer
-model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name = "workdir\model\Qwen2.5-3B-Instruct-bnb-4bit",  # Qwen2.5-3B-Instruct in bnb 4-bit
-    max_seq_length = max_seq_length,
-    dtype = dtype,
-    load_in_4bit = load_in_4bit,
-)
-
-model = FastLanguageModel.get_peft_model(
-    model,
-    r = 64,  # LoRA rank; controls the number of trainable parameters
-    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
-                      "gate_proj", "up_proj", "down_proj",],  # target modules to train
-    lora_alpha = 64,  # LoRA scaling factor
-    lora_dropout = 0,  # LoRA dropout rate
-    bias = "none",  # whether to train bias terms
-    use_gradient_checkpointing = "unsloth",  # use gradient checkpointing to save VRAM
-    random_state = 114514,  # random seed
-    use_rslora = False,  # whether to use rank-stabilized LoRA
-    loftq_config = None,  # LoftQ configuration
-)
-
-from unsloth.chat_templates import get_chat_template
-# Configure the tokenizer to use the qwen-2.5 chat template
-tokenizer = get_chat_template(
-    tokenizer,
-    chat_template="qwen-2.5",
-)
-
-def formatting_prompts_func(examples):
-    """Format conversation data
-    Args:
-        examples: a dict containing lists of conversations
-    Returns:
-        a dict containing the formatted text
-    """
-    questions = examples["question"]
-    answer = examples["answer"]
-
-    # Combine question and answer into chat-style conversations
-    convos = [
-        [{"role": "user", "content": q}, {"role": "assistant", "content": r}]
-        for q, r in zip(questions, answer)
-    ]
-
-    # Format the conversations with tokenizer.apply_chat_template
-    texts = [
-        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
-        for convo in convos
-    ]
-
-    return {"text": texts}
-
-from unsloth.chat_templates import standardize_sharegpt
-
-# Load the dataset
-from datasets import load_dataset
-dataset = load_dataset("json", data_files="workdir\dataset\dataset.json")
-dataset = dataset.map(formatting_prompts_func, batched = True)
-
-print(dataset[5])
-print(dataset[5]["text"])