Compare commits

2 Commits

Author SHA1 Message Date
carry
d35475d9e8 docs(mvp): 添加项目基础文档
- 新增 LICENSE 文件,定义项目开源许可证
- 新增 README.md 文件,介绍项目的基本信息和预期技术栈
2025-04-09 11:16:45 +08:00
carry
f882b82e57 release: 完成了mvp,具有基本的模型训练,语料生成和推理功能 2025-03-26 18:08:15 +08:00
30 changed files with 1847 additions and 735 deletions

6
.gitignore vendored
View File

@@ -11,7 +11,6 @@ env/
# IDE # IDE
.vscode/ .vscode/
.idea/ .idea/
.roo
# Environment files # Environment files
.env .env
@@ -28,7 +27,4 @@ Thumbs.db
workdir/ workdir/
# cache # cache
unsloth_compiled_cache unsloth_compiled_cache
# 测试代码
test.ipynb

View File

@@ -1,11 +1,11 @@
# 基于文档驱动的自适应编码大模型微调框架 # 基于文档驱动的自适应编码大模型微调框架
## 简介 ## 简介
本人的毕业设计 本人的毕业设计,这个是mvp分支MVP 是指最小可行产品Minimum Viable Product其他功能在master分支中
### 项目概述 ### 项目概述
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。 * 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
### 项目技术 ### 项目技术(预计)
* 使用unsloth框架在GPU上实现大语言模型的qlora微调 * 使用unsloth框架在GPU上实现大语言模型的qlora微调
* 使用langchain框架编写工作流实现批量生成微调语料 * 使用langchain框架编写工作流实现批量生成微调语料

15
config/llm.py Normal file
View File

@@ -0,0 +1,15 @@
import os
from dotenv import load_dotenv
from typing import Dict, Any
def load_config() -> Dict[str, Any]:
    """Load LLM connection settings from a .env file into a nested dict."""
    load_dotenv()  # pull variables from .env into the process environment
    openai_settings = {
        key: os.getenv(env_name)
        for key, env_name in (
            ("api_key", "OPENAI_API_KEY"),
            ("base_url", "OPENAI_BASE_URL"),
            ("model_id", "OPENAI_MODEL_ID"),
        )
    }
    return {"openai": openai_settings}

94
dataset_generator.py Normal file
View File

@@ -0,0 +1,94 @@
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm
def process_markdown_file(file_path):
    """Parse one markdown file and return 'title -> path' strings for each
    non-empty leaf section of the heading tree."""
    with open(file_path, 'r', encoding='utf-8') as fh:
        tree = parse_markdown(fh.read())

    collected = []

    def walk(node, ancestors):
        # Title chain from the root down to this node.
        titles = ancestors + [node.title]
        if node.children:
            for child in node.children:
                walk(child, titles)
            return
        # Leaf node: keep it only when it actually carries content.
        if node.content:
            collected.append(' -> '.join(titles) + '\n' + node.content)

    walk(tree, [])
    return collected
def find_markdown_files(directory):
    """Recursively collect the paths of all .md files under *directory*."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(directory)
        for filename in filenames
        if filename.endswith('.md')
    ]
def process_all_markdown(doc_dir):
    """Parse every markdown file under *doc_dir* and merge all section texts."""
    merged = []
    for md_path in find_markdown_files(doc_dir):
        merged.extend(process_markdown_file(md_path))
    return merged
def save_dataset(dataset, output_dir):
    """Write *dataset* as pretty-printed UTF-8 JSON to <output_dir>/dataset.json."""
    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, 'dataset.json')
    with open(target, 'w', encoding='utf-8') as sink:
        json.dump(dataset, sink, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    # Parse every markdown document under the docs directory into text chunks.
    results = process_all_markdown('workdir/my_docs')
    # Load LLM connection settings from .env.
    config = load_config()
    dataset = []
    # tqdm wraps the outer loop to show generation progress.
    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
        # Query the LLM three times per chunk for more sample variety.
        for _ in range(3):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,  # document excerpt
                """{
    "dataset":[
        {
            "question":"",
            "answer":""
        }
    ]
}"""
            )
            # Ask the LLM for JSON matching the requested schema.
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"]
                )
                # Parse once and reuse — the original called json.loads on the
                # same payload twice.
                qa_pairs = json.loads(result)["dataset"]
                print(qa_pairs)
                dataset.extend(qa_pairs)
            except Exception as e:
                print(f"生成数据集时出错: {e}")
    # Persist the accumulated dataset.
    save_dataset(dataset, 'workdir/dataset2')
    print(f"数据集已生成,共{len(dataset)}条数据")

View File

@@ -1,9 +0,0 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
__all__ = [
"get_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store"
]

View File

@@ -1,79 +0,0 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine:
    """
    Return the process-wide SQLite engine, creating it on first call.

    Args:
        workdir (str): Working directory; the database file is stored in <workdir>/db.

    Returns:
        Engine: SQLAlchemy database engine instance (one shared instance per process).
    """
    global _engine
    if not _engine:
        # Create the db directory if it does not exist yet.
        db_dir = os.path.join(workdir, "db")
        os.makedirs(db_dir, exist_ok=True)
        # Path of the SQLite database file.
        db_path = os.path.join(db_dir, "db.sqlite")
        # sqlite:/// URL understood by SQLAlchemy.
        db_url = f"sqlite:///{db_path}"
        # Lazily create and cache the engine.
        _engine = create_engine(db_url)
    return _engine
def initialize_sqlite_db(engine: Engine) -> None:
    """
    Create all table schemas and seed an initial APIProvider row from env vars.

    Args:
        engine (Engine): SQLAlchemy database engine instance.
    """
    # Create every table declared on SQLModel metadata.
    SQLModel.metadata.create_all(engine)
    # Load environment variables from .env.
    load_dotenv()
    # API connection settings for the seed record.
    api_key = os.getenv("API_KEY")
    base_url = os.getenv("BASE_URL")
    model_id = os.getenv("MODEL_ID")
    # Seed only when every required variable is present.
    if api_key and base_url and model_id:
        with Session(engine) as session:
            # Check whether an APIProvider record already exists.
            statement = select(APIProvider).limit(1)
            existing_provider = session.exec(statement).first()
            # Insert the seed provider only into an empty table.
            if not existing_provider:
                provider = APIProvider(
                    base_url=base_url,
                    model_id=model_id,
                    api_key=api_key
                )
                session.add(provider)
                session.commit()


if __name__ == "__main__":
    # Resolve the working directory relative to this file.
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    # Get the shared engine and initialize the database.
    engine = get_sqlite_engine(workdir)
    initialize_sqlite_db(engine)

View File

@@ -1,62 +0,0 @@
import os
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# 全局变量用于存储TinyDB实例
_db_instance: Optional[TinyDB] = None
# 自定义存储类用于格式化JSON数据
def get_prompt_tinydb(workdir: str) -> TinyDB:
    """
    Return the process-wide TinyDB instance, creating it on first call.

    Args:
        workdir (str): Working directory; the JSON store is kept in <workdir>/db.

    Returns:
        TinyDB: TinyDB database instance (one shared instance per process).
    """
    global _db_instance
    if not _db_instance:
        # Create the db directory if it does not exist yet.
        db_dir = os.path.join(workdir, "db")
        os.makedirs(db_dir, exist_ok=True)
        # Path of the JSON-backed prompt store.
        db_path = os.path.join(db_dir, "prompts.json")
        # Lazily create and cache the TinyDB instance.
        _db_instance = TinyDB(db_path)
    return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
    """
    Seed the prompt-template store with a default template.

    Args:
        db (TinyDB): TinyDB database instance.
    """
    # Insert the default template only when the store is completely empty;
    # otherwise leave existing data untouched.
    if not db.all():
        db.insert(promptTempleta(
            id=0,
            name="default",
            description="默认提示词模板",
            content="""项目名为:{ project_name }
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{ content }""").model_dump())


if __name__ == "__main__":
    # Resolve the working directory relative to this file.
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    # Get the shared store and seed it.
    db = get_prompt_tinydb(workdir)
    initialize_prompt_store(db)

View File

@@ -1,6 +0,0 @@
from .chat_page import *
from .setting_page import *
from .train_page import *
from .dataset_manage_page import *
from .dataset_generate_page import *
from .prompt_manage_page import *

View File

@@ -1,9 +0,0 @@
import gradio as gr
def chat_page():
    """Placeholder chat tab: just a heading inside an empty layout."""
    with gr.Blocks() as page:
        gr.Markdown("## 聊天")
        with gr.Row():
            with gr.Column():
                pass  # chat widgets to be added later
    return page

View File

@@ -1,41 +0,0 @@
import gradio as gr
from global_var import docs, scan_docs_directory, prompt_store
def dataset_generate_page():
    """Dataset-generation tab: choose a document and a prompt template."""
    with gr.Blocks() as page:
        gr.Markdown("## 数据集生成")
        with gr.Row():
            with gr.Column():
                # Document dropdown, pre-selecting the first available doc.
                available_docs = [str(d.name) for d in scan_docs_directory("workdir")]
                first_doc = available_docs[0] if available_docs else None
                doc_dropdown = gr.Dropdown(
                    choices=available_docs,
                    value=first_doc,
                    label="选择文档",
                    allow_custom_value=True,
                    interactive=True
                )
                doc_state = gr.State(value=first_doc)
            with gr.Column():
                # Template dropdown, labelled "<id> <name>", pre-selecting the first.
                templates = prompt_store.all()
                template_labels = [f"{t['id']} {t['name']}" for t in templates]
                first_template = template_labels[0] if template_labels else None
                prompt_dropdown = gr.Dropdown(
                    choices=template_labels,
                    value=first_template,
                    label="选择模板",
                    allow_custom_value=True,
                    interactive=True
                )
                prompt_state = gr.State(value=first_template)
        # Keep the State objects in sync with the dropdown selections.
        doc_dropdown.change(lambda v: v, inputs=doc_dropdown, outputs=doc_state)
        prompt_dropdown.change(lambda v: v, inputs=prompt_dropdown, outputs=prompt_state)
    return page

View File

@@ -1,9 +0,0 @@
import gradio as gr
def dataset_manage_page():
    """Placeholder dataset-management tab: just a heading inside an empty layout."""
    with gr.Blocks() as page:
        gr.Markdown("## 数据集管理")
        with gr.Row():
            with gr.Column():
                pass  # management widgets to be added later
    return page

View File

@@ -1,116 +0,0 @@
import gradio as gr
from typing import List
from global_var import prompt_store
from schema.prompt import promptTempleta
def prompt_manage_page():
    """Prompt-template management tab: CRUD over the TinyDB prompt store.

    The selected table row is tracked as a closure variable via ``nonlocal``.
    The original used ``global selected_row`` against a module global that was
    never initialized, so clicking edit/delete before any selection raised a
    NameError instead of the intended gr.Error.
    """
    selected_row = None  # currently selected row: [id, name, description, content]

    def get_prompts() -> List[List[str]]:
        """Return every stored template as a table row."""
        try:
            prompts = prompt_store.all()
            return [
                [p["id"], p["name"], p["description"], p["content"]]
                for p in prompts
            ]
        except Exception as e:
            raise gr.Error(f"获取提示词失败: {str(e)}")

    def add_prompt(name, description, content):
        """Insert a new template and clear the input boxes."""
        try:
            new_prompt = promptTempleta(
                name=name if name else "",
                description=description if description else "",
                content=content if content else ""
            )
            prompt_id = prompt_store.insert(new_prompt.model_dump())
            # TinyDB assigns the doc_id on insert; mirror it into the "id" field.
            prompt_store.update({"id": prompt_id}, doc_ids=[prompt_id])
            return get_prompts(), "", "", ""  # refreshed table + cleared inputs
        except Exception as e:
            raise gr.Error(f"添加失败: {str(e)}")

    def edit_prompt():
        """Write the edited values of the selected row back to the store."""
        if not selected_row:
            raise gr.Error("请先选择要编辑的行")
        try:
            prompt_store.update({
                "name": selected_row[1] if selected_row[1] else "",
                "description": selected_row[2] if selected_row[2] else "",
                "content": selected_row[3] if selected_row[3] else ""
            }, doc_ids=[selected_row[0]])
            return get_prompts()
        except Exception as e:
            raise gr.Error(f"编辑失败: {str(e)}")

    def delete_prompt():
        """Remove the selected row from the store."""
        if not selected_row:
            raise gr.Error("请先选择要删除的行")
        try:
            prompt_store.remove(doc_ids=[selected_row[0]])
            return get_prompts()
        except Exception as e:
            raise gr.Error(f"删除失败: {str(e)}")

    def select_record(evt: gr.SelectData):
        """Remember the row the user clicked in the table."""
        nonlocal selected_row
        selected_row = evt.row_value

    with gr.Blocks() as demo:
        gr.Markdown("## 提示词模板管理")
        with gr.Row():
            with gr.Column(scale=1):
                name_input = gr.Textbox(label="模板名称")
                description_input = gr.Textbox(label="模板描述")
                content_input = gr.Textbox(label="模板内容", lines=10)
                add_button = gr.Button("添加新模板", variant="primary")
            with gr.Column(scale=3):
                prompt_table = gr.DataFrame(
                    headers=["id", "名称", "描述", "内容"],
                    datatype=["number", "str", "str", "str"],
                    interactive=True,
                    value=get_prompts(),
                    wrap=True,
                    col_count=(4, "auto")
                )
        with gr.Row():
            refresh_button = gr.Button("刷新数据", variant="secondary")
            edit_button = gr.Button("编辑选中行", variant="primary")
            delete_button = gr.Button("删除选中行", variant="stop")
        refresh_button.click(
            fn=get_prompts,
            outputs=[prompt_table],
            queue=False
        )
        add_button.click(
            fn=add_prompt,
            inputs=[name_input, description_input, content_input],
            outputs=[prompt_table, name_input, description_input, content_input]
        )
        prompt_table.select(select_record, [], [], show_progress="hidden")
        edit_button.click(
            fn=edit_prompt,
            inputs=[],
            outputs=[prompt_table]
        )
        delete_button.click(
            fn=delete_prompt,
            inputs=[],
            outputs=[prompt_table]
        )
    return demo

View File

@@ -1,126 +0,0 @@
import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import sql_engine
def setting_page():
    """API-provider management tab: CRUD over the APIProvider table.

    The selected table row is tracked as a closure variable via ``nonlocal``.
    The original used ``global selected_row`` against a module global that was
    never initialized, so clicking edit/delete before any selection raised a
    NameError instead of the intended gr.Error.
    """
    selected_row = None  # currently selected row: [id, model_id, base_url, api_key]

    def get_providers() -> List[List[str]]:
        """Return every APIProvider as a table row."""
        try:
            with Session(sql_engine) as session:
                providers = session.exec(select(APIProvider)).all()
                return [
                    [p.id, p.model_id, p.base_url, p.api_key or ""]
                    for p in providers
                ]
        except Exception as e:
            raise gr.Error(f"获取数据失败: {str(e)}")

    def add_provider(model_id, base_url, api_key):
        """Insert a new provider and clear the input boxes."""
        try:
            with Session(sql_engine) as session:
                new_provider = APIProvider(
                    model_id=model_id if model_id else None,
                    base_url=base_url if base_url else None,
                    api_key=api_key if api_key else None
                )
                session.add(new_provider)
                session.commit()
                session.refresh(new_provider)
                return get_providers(), "", "", ""  # refreshed table + cleared inputs
        except Exception as e:
            raise gr.Error(f"添加失败: {str(e)}")

    def edit_provider():
        """Write the edited values of the selected row back to the DB."""
        if not selected_row:
            raise gr.Error("请先选择要编辑的行")
        try:
            with Session(sql_engine) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                provider.model_id = selected_row[1] if selected_row[1] else None
                provider.base_url = selected_row[2] if selected_row[2] else None
                provider.api_key = selected_row[3] if selected_row[3] else None
                session.add(provider)
                session.commit()
                session.refresh(provider)
                return get_providers()
        except Exception as e:
            raise gr.Error(f"编辑失败: {str(e)}")

    def delete_provider():
        """Delete the selected provider from the DB."""
        if not selected_row:
            raise gr.Error("请先选择要删除的行")
        try:
            with Session(sql_engine) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                session.delete(provider)
                session.commit()
                return get_providers()
        except Exception as e:
            raise gr.Error(f"删除失败: {str(e)}")

    def select_record(evt: gr.SelectData):
        """Remember the row the user clicked in the table."""
        nonlocal selected_row
        selected_row = evt.row_value

    with gr.Blocks() as demo:
        gr.Markdown("## API Provider 管理")
        with gr.Row():
            with gr.Column(scale=1):
                model_id_input = gr.Textbox(label="Model ID")
                base_url_input = gr.Textbox(label="Base URL")
                api_key_input = gr.Textbox(label="API Key")
                add_button = gr.Button("添加新API", variant="primary")
            with gr.Column(scale=3):
                provider_table = gr.DataFrame(
                    headers=["id", "model id", "base URL", "API Key"],
                    datatype=["number", "str", "str", "str"],
                    interactive=True,
                    value=get_providers(),
                    wrap=True,
                    col_count=(4, "auto")
                )
        with gr.Row():
            refresh_button = gr.Button("刷新数据", variant="secondary")
            edit_button = gr.Button("编辑选中行", variant="primary")
            delete_button = gr.Button("删除选中行", variant="stop")
        refresh_button.click(
            fn=get_providers,
            outputs=[provider_table],
            queue=False  # immediate refresh, no queueing
        )
        add_button.click(
            fn=add_provider,
            inputs=[model_id_input, base_url_input, api_key_input],
            outputs=[provider_table, model_id_input, base_url_input, api_key_input]
        )
        provider_table.select(select_record, [], [], show_progress="hidden")
        edit_button.click(
            fn=edit_provider,
            inputs=[],
            outputs=[provider_table]
        )
        delete_button.click(
            fn=delete_provider,
            inputs=[],
            outputs=[provider_table]
        )
    return demo

View File

@@ -1,9 +0,0 @@
import gradio as gr
def train_page():
    """Placeholder fine-tuning tab: just a heading inside an empty layout."""
    with gr.Blocks() as page:
        gr.Markdown("## 微调")
        with gr.Row():
            with gr.Column():
                pass  # training widgets to be added later
    return page

View File

@@ -1,6 +0,0 @@
from db import get_sqlite_engine,get_prompt_tinydb
from tools import scan_docs_directory
# Process-wide singletons shared across the frontend pages.
prompt_store = get_prompt_tinydb("workdir")  # TinyDB store for prompt templates
sql_engine = get_sqlite_engine("workdir")  # SQLite engine for API-provider records
docs = scan_docs_directory("workdir")  # documents discovered under workdir/docs

26
main.py
View File

@@ -1,26 +0,0 @@
import gradio as gr
from frontend.setting_page import setting_page
from frontend import *
from db import initialize_sqlite_db,initialize_prompt_store
from global_var import sql_engine,prompt_store
if __name__ == "__main__":
    # Ensure both backing stores exist and are seeded before building the UI.
    initialize_sqlite_db(sql_engine)
    initialize_prompt_store(prompt_store)
    with gr.Blocks() as app:
        gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
        with gr.Tabs():
            # One tab per feature; each page function builds its own gr.Blocks.
            with gr.TabItem("模型推理"):
                chat_page()
            with gr.TabItem("模型微调"):
                train_page()
            with gr.TabItem("数据集生成"):
                dataset_generate_page()
            with gr.TabItem("数据集管理"):
                dataset_manage_page()
            with gr.TabItem("提示词模板管理"):
                prompt_manage_page()
            with gr.TabItem("设置"):
                setting_page()
    # Start the Gradio server.
    app.launch()

25
prompt/base.py Normal file
View File

@@ -0,0 +1,25 @@
class create_dataset:
    """Prompt-template factory for generating fine-tuning dataset requests."""

    # Positional slots: project name, document excerpt, expected JSON schema.
    template = """
项目名为:{}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{}
按照如下json格式返回{}
"""

    @staticmethod
    def create(*args: object) -> str:
        """Fill the template with the given positional arguments.

        Note: the original annotated ``*args`` with the builtin function
        ``any`` instead of a type; ``object`` is the correct annotation.

        Args:
            *args: Values substituted into the template in order.

        Returns:
            The formatted template string.
        """
        return create_dataset.template.format(*args)


if __name__ == "__main__":
    print(create_dataset.create("a", "b", "c"))

View File

@@ -1,6 +1,2 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=5.0.0
langchain>=0.3
tinydb>=4.0.0

View File

@@ -1,4 +0,0 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta

View File

@@ -1,28 +1,9 @@
from typing import Optional from pydantic import BaseModel, RootModel
from pydantic import BaseModel, Field from typing import List
from datetime import datetime, timezone
class doc(BaseModel): class QAPair(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID") question: str
name: str = Field(default="", description="文档名称") response: str
path: str = Field(default="", description="文档路径")
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
class Q_A(BaseModel): class QAArray(RootModel):
question: str = Field(default="", min_length=1,description="问题") root: List[QAPair]
answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: Optional[str] = Field(default=None, description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")

View File

@@ -1,51 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True):
    """DB table describing one OpenAI-compatible API endpoint (URL, model, key)."""
    id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
    base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
    model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
    api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
    # Timezone-aware UTC creation timestamp.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="记录创建时间"
    )
class LLMResponse(SQLModel):
    """Metadata captured for a single LLM API response (not a table)."""
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="响应的时间戳"
    )
    response_id: str = Field(..., description="响应的唯一ID")
    # Token accounting; cache-hit/miss counters default to None (provider-specific).
    tokens_usage: dict = Field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None
    }, description="token使用信息")
    response_content: dict = Field(default_factory=dict, description="API响应的内容")
    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
    # Generation parameters echoed back for reproducibility; all optional.
    llm_parameters: dict = Field(default_factory=lambda: {
        "temperature": None,
        "max_tokens": None,
        "top_p": None,
        "frequency_penalty": None,
        "presence_penalty": None,
        "seed": None
    }, description="API的生成参数")
class LLMRequest(SQLModel):
    """One prompt sent to a provider plus its responses and aggregate stats (not a table)."""
    prompt: str = Field(..., description="发送给API的提示词")
    provider_id: int = Field(foreign_key="apiprovider.id")
    provider: APIProvider = Relationship()
    format: Optional[str] = Field(default=None, description="API响应的格式")
    response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
    error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
    # Aggregated token usage across all responses for this request.
    total_tokens_usage: dict = Field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None
    }, description="token使用信息")

View File

@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class MarkdownNode(BaseModel):
    """One heading node in a parsed markdown tree (self-referential via children)."""
    level: int = Field(default=0, description="节点层级")
    title: str = Field(default="Root", description="节点标题")
    content: Optional[str] = Field(default=None, description="节点内容")
    children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")

    class Config:
        # Permit the recursive model structure.
        arbitrary_types_allowed = True


# Resolve the 'MarkdownNode' forward reference used in `children`.
MarkdownNode.model_rebuild()

View File

@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime, timezone
class promptTempleta(BaseModel):
    """Prompt-template record stored in TinyDB.

    NOTE(review): the class name misspells "Template"; renaming would break
    existing callers, so it is kept as-is.
    """
    id: Optional[int] = Field(default=None, description="模板ID")
    name: Optional[str] = Field(default="", description="模板名称")
    description: Optional[str] = Field(default="", description="模板描述")
    content: str = Field(default="", min_length=1, description="模板内容")
    # ISO-8601 UTC timestamp string, set at creation.
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="记录创建时间"
    )

View File

@@ -1,2 +0,0 @@
from .parse_markdown import parse_markdown
from .scan_doc_dir import *

View File

@@ -1,35 +0,0 @@
from typing import List
from schema.dataset import dataset, dataset_item, Q_A
import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset:
    """Convert a list of {"question", "answer"} dicts into a dataset model,
    assigning sequential 1-based item IDs."""
    dataset_items = [
        dataset_item(
            id=seq,
            message=[Q_A(question=entry["question"], answer=entry["answer"])],
        )
        for seq, entry in enumerate(json_data, start=1)
    ]
    return dataset(
        name="Converted Dataset",
        model_id=None,
        description="Dataset converted from JSON",
        dataset_items=dataset_items,
    )
# 示例从文件读取JSON并转换
if __name__ == "__main__":
# 假设JSON数据存储在文件中
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
json_data = json.load(file)
# 转换为dataset格式
converted_dataset = convert_json_to_dataset(json_data)
# 输出结果到文件
with open("output.json", "w", encoding="utf-8") as file:
file.write(converted_dataset.model_dump_json(indent=4))

69
tools/openai_api.py Normal file
View File

@@ -0,0 +1,69 @@
import json
from openai import OpenAI
def generate_json_via_llm(
    prompt: str,
    base_url: str,
    api_key: str,
    model_id: str
) -> str:
    """Send *prompt* to an OpenAI-compatible endpoint in JSON mode and return
    the raw reply text.

    Raises:
        RuntimeError: wrapping any failure of the underlying API call.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)
    conversation = [{"role": "user", "content": prompt}]
    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=conversation,
            # Ask the endpoint to constrain output to valid JSON.
            response_format={'type': 'json_object'},
        )
        return completion.choices[0].message.content
    except Exception as exc:
        raise RuntimeError(f"API请求失败: {exc}")
if __name__ == "__main__":
    import sys
    import os
    # Put the project root on sys.path so config.llm resolves when this file
    # is run directly as a script.
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    from config.llm import load_config
    # Example usage: ask the model for an arbitrary JSON payload.
    try:
        config = load_config()
        print(config)
        result = generate_json_via_llm(
            prompt="""测试随便生成点什么返回json格式的字符串,格式如下
{
"dataset":[
{
"question":"",
"answer":""
},
{
"question":"",
"answer":""
}
......
]
}
""",
            base_url=config["openai"]["base_url"],
            api_key=config["openai"]["api_key"],
            model_id=config["openai"]["model_id"],
        )
        print(result)
    except Exception as e:
        print(f"错误: {e}")

View File

@@ -1,45 +1,28 @@
import re import re
import sys
from pathlib import Path
# 添加项目根目录到sys.path class MarkdownNode:
sys.path.append(str(Path(__file__).resolve().parent.parent)) def __init__(self, level=0, title="Root"):
from schema import MarkdownNode self.level = level
self.title = title
self.content = "" # 使用字符串存储合并后的内容
self.children = []
def process_markdown_file(file_path): def __repr__(self):
with open(file_path, 'r', encoding='utf-8') as f: return f"({self.level}) {self.title}"
content = f.read()
root = parse_markdown(content)
results = []
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # 叶子节点
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def add_child(parent, child):
parent.children.append(child)
def print_tree(node, indent=0): def add_child(self, child):
prefix = "" * (indent - 1) + "└─ " if indent > 0 else "" self.children.append(child)
print(f"{prefix}{node.title}")
if node.content: def print_tree(self, indent=0):
content_prefix = "" * indent + "[内容]" prefix = "" * (indent - 1) + "" if indent > 0 else ""
print(content_prefix) print(f"{prefix}{self.title}")
for line in node.content.split('\n'): if self.content:
print("" * indent + "" + line) content_prefix = "" * indent + "├─ [内容]"
for child in node.children: print(content_prefix)
print_tree(child, indent + 1) for line in self.content.split('\n'):
print("" * indent + "" + line)
for child in self.children:
child.print_tree(indent + 1)
def parse_markdown(markdown): def parse_markdown(markdown):
lines = markdown.split('\n') lines = markdown.split('\n')
@@ -68,10 +51,10 @@ def parse_markdown(markdown):
if match: if match:
level = len(match.group(1)) level = len(match.group(1))
title = match.group(2) title = match.group(2)
node = MarkdownNode(level=level, title=title, content="", children=[]) node = MarkdownNode(level, title)
while stack[-1].level >= level: while stack[-1].level >= level:
stack.pop() stack.pop()
add_child(stack[-1], node) stack[-1].add_child(node)
stack.append(node) stack.append(node)
else: else:
if stack[-1].content: if stack[-1].content:
@@ -81,13 +64,10 @@ def parse_markdown(markdown):
return root return root
if __name__=="__main__": if __name__=="__main__":
# # 从文件读取 Markdown 内容 # 从文件读取 Markdown 内容
# with open("workdir/example.md", "r", encoding="utf-8") as f: with open("example.md", "r", encoding="utf-8") as f:
# markdown = f.read() markdown = f.read()
# # 解析 Markdown 并打印树结构 # 解析 Markdown 并打印树结构
# root = parse_markdown(markdown) root = parse_markdown(markdown)
# print_tree(root) root.print_tree()
for i in process_markdown_file("workdir/example.md"):
print("~"*20)
print(i)

View File

@@ -1,32 +0,0 @@
import sys
import os
from pathlib import Path
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc
def scan_docs_directory(workdir: str):
    """Build a doc object for every subdirectory of <workdir>/docs, collecting
    the markdown files found (recursively) inside each one."""
    docs_dir = os.path.join(workdir, "docs")
    found = []
    for entry in os.listdir(docs_dir):
        entry_path = os.path.join(docs_dir, entry)
        if not os.path.isdir(entry_path):
            continue  # skip stray files directly under docs/
        markdown_files = [
            os.path.join(parent, fname)
            for parent, _dirs, fnames in os.walk(entry_path)
            for fname in fnames
            if fname.endswith(".md")
        ]
        found.append(doc(name=entry, path=entry_path, markdown_files=markdown_files))
    return found
# Manual test entry point.
if __name__ == "__main__":
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    docs = scan_docs_directory(workdir)
    print(docs)

1534
train.ipynb Normal file

File diff suppressed because it is too large Load Diff

70
trainer.py Normal file
View File

@@ -0,0 +1,70 @@
from unsloth import FastLanguageModel
import torch

# Base configuration.
max_seq_length = 4096  # maximum sequence length
dtype = None  # auto-detect compute dtype
load_in_4bit = True  # 4-bit quantization to reduce memory use

# Load the pretrained model and tokenizer.
# NOTE(review): this backslash path contains invalid escape sequences ("\m")
# and only resolves on Windows — forward slashes would be portable.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "workdir\model\Qwen2.5-3B-Instruct-bnb-4bit",  # Qwen2.5 3B instruct, pre-quantized (original comment said 32B; the path says 3B)
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 64,  # LoRA rank: controls the number of trainable parameters
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],  # modules to adapt
    lora_alpha = 64,  # LoRA scaling factor
    lora_dropout = 0,  # LoRA dropout rate
    bias = "none",  # do not train bias terms
    use_gradient_checkpointing = "unsloth",  # gradient checkpointing to save VRAM
    random_state = 114514,  # RNG seed
    use_rslora = False,  # rank-stabilized LoRA disabled
    loftq_config = None,  # no LoftQ quantization-aware init
)

from unsloth.chat_templates import get_chat_template

# Configure the tokenizer with the qwen-2.5 chat template.
tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen-2.5",
)
def formatting_prompts_func(examples):
    """Render batched {"question", "answer"} columns as chat-template text.

    Args:
        examples: batch dict holding parallel "question" and "answer" lists.

    Returns:
        {"text": [...]} where each entry is the tokenizer's chat rendering
        of one user/assistant exchange.
    """
    rendered = []
    for user_msg, assistant_msg in zip(examples["question"], examples["answer"]):
        conversation = [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        ]
        rendered.append(
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
        )
    return {"text": rendered}
from unsloth.chat_templates import standardize_sharegpt

# Load the generated QA dataset. Forward slashes work on every OS — the
# original "workdir\dataset\dataset.json" relied on invalid escape
# sequences and was Windows-only.
from datasets import load_dataset
dataset = load_dataset("json", data_files="workdir/dataset/dataset.json")
dataset = dataset.map(formatting_prompts_func, batched = True)
# load_dataset returns a DatasetDict keyed by split name; the original
# `dataset[5]` raised a KeyError because 5 is not a split. Index the
# "train" split explicitly.
print(dataset["train"][5])
print(dataset["train"][5]["text"])