Compare commits

11 Commits

Author SHA1 Message Date
carry
841e14a093 feat(frontend): 添加数据集页面并重构主页面布局
- 新增 dataset_page 模块,实现数据集页面的基本布局
- 重构 main.py 中的页面加载方式,使用列表收集所有页面
- 更新主页面布局,将聊天页面作为第一个选项卡
- 调整设置页面的加载方式,直接使用函数调用
2025-04-06 22:49:37 +08:00
carry
2ff077bb1c refactor(frontend): 重构前端页面导入方式
- 在 main.py 中使用更简洁的导入方式
- 新增 __init__.py 文件以简化前端页面的导入
2025-04-06 22:46:31 +08:00
carry
513b639bce feat(frontend): 添加了设置页面的api provider展示 2025-04-06 22:05:56 +08:00
carry
f93f213a31 feat(db): 添加数据库连接和初始化功能
- 新增 db/__init__.py 文件,提供数据库连接和初始化的接口
- 导入 get_engine 和 initialize_db 函数,方便外部使用
2025-04-06 21:27:25 +08:00
carry
10b4c29bda docs(db): 修改了代码注释 2025-04-06 21:26:53 +08:00
carry
b1e98ca913 feat(db): 初始化数据库并创建 APIProvider 表
- 新增 init_db.py 文件,实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件,定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时,如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID,会自动添加一条 APIProvider 记录
2025-04-06 19:59:23 +08:00
carry
2d5a5277ae refactor(schema): 更新 prompt 导入
- 将 prompt_templeta 重命名为 promptTempleta,以符合驼峰命名规范
- 优化导入语句格式
2025-04-06 19:39:43 +08:00
carry
519a5f3773 feat(frontend): 添加前端页面模块并实现基本布局
- 新增 chat_page.py、setting_page.py 和 train_page.py 文件,分别实现聊天、设置和微调页面的基本布局
- 添加 main.py 文件,集成所有页面并创建主应用
- 在 requirements.txt 中添加 gradio 依赖
2025-04-06 14:49:01 +08:00
carry
1f4d491694 build: 添加 pydantic 依赖 2025-04-05 01:00:33 +08:00
carry
8ce4f1e373 chore: 添加 .roo 到 .gitignore 文件
- 在 .gitignore 文件中添加 .roo 目录,以忽略相关文件
2025-04-05 00:59:42 +08:00
carry
3395b860e4 refactor(parse_markdown): 重构 Markdown 解析逻辑并使用 Pydantic 模型
将 MarkdownNode 类重构为使用 Pydantic 模型,提高代码的可维护性和类型安全性。同时,将解析逻辑与节点操作分离,简化代码结构。
2025-04-04 20:50:39 +08:00
23 changed files with 268 additions and 1878 deletions

1
.gitignore vendored
View File

@@ -11,6 +11,7 @@ env/
# IDE # IDE
.vscode/ .vscode/
.idea/ .idea/
.roo
# Environment files # Environment files
.env .env

21
LICENSE
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2022 C-a-r-r-y
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,15 +0,0 @@
# 基于文档驱动的自适应编码大模型微调框架
## 简介
本人的毕业设计,这个是mvp分支MVP 是指最小可行产品Minimum Viable Product其他功能在master分支中
### 项目概述
* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
### 项目技术(预计)
* 使用unsloth框架在GPU上实现大语言模型的qlora微调
* 使用langchain框架编写工作流实现批量生成微调语料
* 使用tinydb和sqlite实现数据的持久化
* 使用gradio框架实现前端展示
**施工中......**

View File

@@ -1,15 +0,0 @@
import os
from dotenv import load_dotenv
from typing import Dict, Any
def load_config() -> Dict[str, Any]:
"""从.env文件加载配置"""
load_dotenv()
return {
"openai": {
"api_key": os.getenv("OPENAI_API_KEY"),
"base_url": os.getenv("OPENAI_BASE_URL"),
"model_id": os.getenv("OPENAI_MODEL_ID")
}
}

View File

@@ -1,94 +0,0 @@
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm
def process_markdown_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
root = parse_markdown(content)
results = []
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # 叶子节点
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def find_markdown_files(directory):
markdown_files = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.md'):
markdown_files.append(os.path.join(root, file))
return markdown_files
def process_all_markdown(doc_dir):
all_results = []
markdown_files = find_markdown_files(doc_dir)
for file_path in markdown_files:
results = process_markdown_file(file_path)
all_results.extend(results)
return all_results
def save_dataset(dataset, output_dir):
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'dataset.json')
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
# 解析markdown文档
results = process_all_markdown('workdir/my_docs')
# 加载LLM配置
config = load_config()
dataset = []
# 使用tqdm包装外部循环以显示进度条
for content in tqdm(results, desc="生成数据集进度", unit="文档"):
for _ in range(3):
prompt = create_dataset.create(
"LLaMA-Factory", # 项目名
content, # 文档内容
"""{
"dataset":[
{
"question":"",
"answer":""
}
]
}"""
)
# 调用LLM生成JSON
try:
result = generate_json_via_llm(
prompt=prompt,
base_url=config["openai"]["base_url"],
api_key=config["openai"]["api_key"],
model_id=config["openai"]["model_id"]
)
print(json.loads(result)["dataset"])
dataset.extend(json.loads(result)["dataset"])
except Exception as e:
print(f"生成数据集时出错: {e}")
# 保存数据集
save_dataset(dataset, 'workdir/dataset2')
print(f"数据集已生成,共{len(dataset)}条数据")

3
db/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .init_db import get_engine, initialize_db
__all__ = ['get_engine', 'initialize_db']

79
db/init_db.py Normal file
View File

@@ -0,0 +1,79 @@
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
import os
from pathlib import Path
import sys
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine:
"""
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
Args:
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
Returns:
Engine: SQLAlchemy 数据库引擎实例。
"""
global _engine
if not _engine:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "db.sqlite")
# 创建数据库URL
db_url = f"sqlite:///{db_path}"
# 创建数据库引擎
_engine = create_engine(db_url)
return _engine
def initialize_db(engine: Engine) -> None:
"""
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
Args:
engine (Engine): SQLAlchemy 数据库引擎实例。
"""
# 创建所有表结构
SQLModel.metadata.create_all(engine)
# 加载环境变量
load_dotenv()
# 从环境变量中获取API相关配置
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
# 如果所有必要的环境变量都存在,则插入初始数据
if api_key and base_url and model_id:
with Session(engine) as session:
# 查询是否已存在APIProvider记录
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
# 如果不存在则插入新的APIProvider记录
if not existing_provider:
provider = APIProvider(
base_url=base_url,
model_id=model_id,
api_key=api_key
)
session.add(provider)
session.commit()
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎
engine = get_engine(workdir)
# 初始化数据库
initialize_db(engine)

4
frontend/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .chat_page import chat_page
from .setting_page import setting_page
from .train_page import train_page
from .dataset_page import dataset_page

9
frontend/chat_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def chat_page():
with gr.Blocks() as demo:
gr.Markdown("## 聊天")
with gr.Row():
with gr.Column():
pass
return demo

9
frontend/dataset_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def dataset_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集")
with gr.Row():
with gr.Column():
pass
return demo

35
frontend/setting_page.py Normal file
View File

@@ -0,0 +1,35 @@
import gradio as gr
from typing import List, Dict
from sqlmodel import Session, select
from db import get_engine
from schema import APIProvider
import os
# 获取数据库引擎
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
def setting_page():
def get_providers() -> List[List[str]]:
with Session(engine) as session:
providers = session.exec(select(APIProvider)).all()
return [
[p.id, p.model_id, p.base_url, p.api_key or ""]
for p in providers
]
with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理")
with gr.Row():
# API Provider列表
with gr.Column(scale=2):
provider_table = gr.DataFrame(
headers=["id" , "model id", "URL", "API Key"],
datatype=["number","str", "str", "str"],
interactive=True,
value=get_providers(),
wrap=True,
col_count=(4, "fixed")
)
return demo

9
frontend/train_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
with gr.Row():
with gr.Column():
pass
return demo

27
main.py Normal file
View File

@@ -0,0 +1,27 @@
import gradio as gr
from frontend.setting_page import setting_page
from frontend import chat_page,setting_page,train_page,dataset_page
from db import initialize_db as init_db,get_engine
if __name__ == "__main__":
init_db(get_engine("workdir"))
pages = []
pages.append(setting_page())
pages.append(chat_page())
pages.append(train_page())
pages.append(dataset_page())
with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs():
with gr.TabItem("聊天"):
chat_page()
with gr.TabItem("微调"):
train_page()
with gr.TabItem("数据集"):
dataset_page()
with gr.TabItem("设置"):
setting_page()
app.launch()

View File

@@ -1,25 +0,0 @@
class create_dataset:
"""用于生成微调数据集模板的类"""
template = """
项目名为:{}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{}
按照如下json格式返回{}
"""
@staticmethod
def create(*args: any) -> str:
"""根据提供的任意数量参数生成数据集模板
Args:
*args: 任意数量的参数,将按顺序填充到模板中
Returns:
格式化后的模板字符串
"""
return create_dataset.template.format(*args)
if __name__=="__main__":
print(create_dataset.create("a", "b", "c"))

View File

@@ -1,2 +1,4 @@
openai>=1.0.0 openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=3.0.0

4
schema/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta

View File

@@ -1,9 +0,0 @@
from pydantic import BaseModel, RootModel
from typing import List
class QAPair(BaseModel):
question: str
response: str
class QAArray(RootModel):
root: List[QAPair]

View File

@@ -0,0 +1,51 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
base_url: str = Field(..., description="API的基础URL")
model_id: str = Field(..., description="API使用的模型ID")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
class LLMResponse(SQLModel):
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳"
)
response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: {
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

13
schema/md_doc.py Normal file
View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class MarkdownNode(BaseModel):
level: int = Field(default=0, description="节点层级")
title: str = Field(default="Root", description="节点标题")
content: Optional[str] = Field(default=None, description="节点内容")
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
class Config:
arbitrary_types_allowed = True
MarkdownNode.model_rebuild()

View File

@@ -1,69 +0,0 @@
import json
from openai import OpenAI
def generate_json_via_llm(
prompt: str,
base_url: str,
api_key: str,
model_id: str
) -> str:
client = OpenAI(
api_key=api_key,
base_url=base_url
)
try:
response = client.chat.completions.create(
model=model_id,
messages=[
{
"role": "user",
"content": prompt
}
],
response_format={
'type': 'json_object'
}
)
return response.choices[0].message.content
except Exception as e:
raise RuntimeError(f"API请求失败: {e}")
if __name__ == "__main__":
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.llm import load_config
# 将项目根目录添加到 sys.path 中
# 示例用法
try:
config = load_config()
print(config)
result = generate_json_via_llm(
prompt="""测试随便生成点什么返回json格式的字符串,格式如下
{
"dataset":[
{
"question":"",
"answer":""
},
{
"question":"",
"answer":""
}
......
]
}
""",
base_url=config["openai"]["base_url"],
api_key=config["openai"]["api_key"],
model_id=config["openai"]["model_id"],
)
print(result)
except Exception as e:
print(f"错误: {e}")

View File

@@ -1,28 +1,24 @@
import re import re
import sys
from pathlib import Path
class MarkdownNode: # 添加项目根目录到sys.path
def __init__(self, level=0, title="Root"): sys.path.append(str(Path(__file__).resolve().parent.parent))
self.level = level from schema import MarkdownNode
self.title = title
self.content = "" # 使用字符串存储合并后的内容
self.children = []
def __repr__(self): def add_child(parent, child):
return f"({self.level}) {self.title}" parent.children.append(child)
def add_child(self, child): def print_tree(node, indent=0):
self.children.append(child) prefix = "" * (indent - 1) + "└─ " if indent > 0 else ""
print(f"{prefix}{node.title}")
def print_tree(self, indent=0): if node.content:
prefix = "" * (indent - 1) + "" if indent > 0 else "" content_prefix = "" * indent + "[内容]"
print(f"{prefix}{self.title}") print(content_prefix)
if self.content: for line in node.content.split('\n'):
content_prefix = "" * indent + "├─ [内容]" print("" * indent + "" + line)
print(content_prefix) for child in node.children:
for line in self.content.split('\n'): print_tree(child, indent + 1)
print("" * indent + "" + line)
for child in self.children:
child.print_tree(indent + 1)
def parse_markdown(markdown): def parse_markdown(markdown):
lines = markdown.split('\n') lines = markdown.split('\n')
@@ -51,10 +47,10 @@ def parse_markdown(markdown):
if match: if match:
level = len(match.group(1)) level = len(match.group(1))
title = match.group(2) title = match.group(2)
node = MarkdownNode(level, title) node = MarkdownNode(level=level, title=title, content="", children=[])
while stack[-1].level >= level: while stack[-1].level >= level:
stack.pop() stack.pop()
stack[-1].add_child(node) add_child(stack[-1], node)
stack.append(node) stack.append(node)
else: else:
if stack[-1].content: if stack[-1].content:
@@ -65,9 +61,9 @@ def parse_markdown(markdown):
if __name__=="__main__": if __name__=="__main__":
# 从文件读取 Markdown 内容 # 从文件读取 Markdown 内容
with open("example.md", "r", encoding="utf-8") as f: with open("workdir/example.md", "r", encoding="utf-8") as f:
markdown = f.read() markdown = f.read()
# 解析 Markdown 并打印树结构 # 解析 Markdown 并打印树结构
root = parse_markdown(markdown) root = parse_markdown(markdown)
root.print_tree() print_tree(root)

File diff suppressed because it is too large Load Diff

View File

@@ -1,70 +0,0 @@
from unsloth import FastLanguageModel
import torch
# 基础配置参数
max_seq_length = 4096 # 最大序列长度
dtype = None # 自动检测数据类型
load_in_4bit = True # 使用4位量化以减少内存使用
# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "workdir\model\Qwen2.5-3B-Instruct-bnb-4bit", # 选择Qwen2.5 32B指令模型
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
model = FastLanguageModel.get_peft_model(
model,
r = 64, # LoRA秩,控制可训练参数数量
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",], # 需要训练的目标模块
lora_alpha = 64, # LoRA缩放因子
lora_dropout = 0, # LoRA dropout率
bias = "none", # 是否训练偏置项
use_gradient_checkpointing = "unsloth", # 使用梯度检查点节省显存
random_state = 114514, # 随机数种子
use_rslora = False, # 是否使用稳定版LoRA
loftq_config = None, # LoftQ配置
)
from unsloth.chat_templates import get_chat_template
# 配置分词器使用qwen-2.5对话模板
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
def formatting_prompts_func(examples):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
# 将Question和Response组合成对话形式
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# 使用tokenizer.apply_chat_template格式化对话
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
from unsloth.chat_templates import standardize_sharegpt
# 加载数据集
from datasets import load_dataset
dataset = load_dataset("json", data_files="workdir\dataset\dataset.json")
dataset = dataset.map(formatting_prompts_func, batched = True)
print(dataset[5])
print(dataset[5]["text"])