Compare commits

2 Commits

Author SHA1 Message Date
carry
d35475d9e8 docs(mvp): add basic project documentation
- Add a LICENSE file defining the project's open-source license
- Add a README.md introducing the project and its planned tech stack
2025-04-09 11:16:45 +08:00
carry
f882b82e57 release: complete the MVP, with basic model training, corpus generation, and inference 2025-03-26 18:08:15 +08:00
34 changed files with 1847 additions and 1406 deletions

6
.gitignore vendored

@@ -11,7 +11,6 @@ env/
# IDE
.vscode/
.idea/
.roo
# Environment files
.env
@@ -29,8 +28,3 @@ workdir/
# cache
unsloth_compiled_cache
# test and reference code
test.ipynb
test.py
refer/

README.md

@@ -1,11 +1,11 @@
# Document-Driven Adaptive Fine-Tuning Framework for Coding LLMs
## Introduction
My graduation project.
My graduation project. This is the mvp branch (MVP: Minimum Viable Product); the remaining features live on the master branch.
### Project overview
* Deeply parse a private library's documentation and other resources to generate instruction-style corpora, then fine-tune a large language model for that private library.
### Project tech
### Project tech (planned)
* Use the unsloth framework to run QLoRA fine-tuning of large language models on GPU
* Use the langchain framework to build workflows that batch-generate fine-tuning corpora
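For a concrete picture, a minimal QLoRA sketch in the spirit of trainer.py later in this diff (the model path is a placeholder, not the project's actual checkpoint):

from unsloth import FastLanguageModel

# load a 4-bit quantized base model (path is an assumption for illustration)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
    max_seq_length=4096,
    load_in_4bit=True,
)
# wrap it with LoRA adapters so only a small set of weights is trained
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
)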

15
config/llm.py Normal file

@@ -0,0 +1,15 @@
import os
from dotenv import load_dotenv
from typing import Dict, Any
def load_config() -> Dict[str, Any]:
"""从.env文件加载配置"""
load_dotenv()
return {
"openai": {
"api_key": os.getenv("OPENAI_API_KEY"),
"base_url": os.getenv("OPENAI_BASE_URL"),
"model_id": os.getenv("OPENAI_MODEL_ID")
}
}
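A usage sketch for reference; the .env values below are placeholders, not real credentials:

# .env (hypothetical):
#   OPENAI_API_KEY=sk-placeholder
#   OPENAI_BASE_URL=https://api.openai.com/v1
#   OPENAI_MODEL_ID=gpt-4o-mini
from config.llm import load_config

cfg = load_config()["openai"]
print(cfg["base_url"], cfg["model_id"])  # keep api_key out of logs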

94
dataset_generator.py Normal file

@@ -0,0 +1,94 @@
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm
def process_markdown_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
root = parse_markdown(content)
results = []
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # leaf node
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def find_markdown_files(directory):
markdown_files = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.md'):
markdown_files.append(os.path.join(root, file))
return markdown_files
def process_all_markdown(doc_dir):
all_results = []
markdown_files = find_markdown_files(doc_dir)
for file_path in markdown_files:
results = process_markdown_file(file_path)
all_results.extend(results)
return all_results
def save_dataset(dataset, output_dir):
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'dataset.json')
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(dataset, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
# parse the markdown docs
results = process_all_markdown('workdir/my_docs')
# load the LLM config
config = load_config()
dataset = []
# wrap the outer loop in tqdm to show a progress bar
for content in tqdm(results, desc="生成数据集进度", unit="文档"):
for _ in range(3):
prompt = create_dataset.create(
"LLaMA-Factory", # 项目名
content, # 文档内容
"""{
"dataset":[
{
"question":"",
"answer":""
}
]
}"""
)
# call the LLM to generate JSON
try:
result = generate_json_via_llm(
prompt=prompt,
base_url=config["openai"]["base_url"],
api_key=config["openai"]["api_key"],
model_id=config["openai"]["model_id"]
)
parsed = json.loads(result)["dataset"]
print(parsed)
dataset.extend(parsed)
except Exception as e:
print(f"生成数据集时出错: {e}")
# save the dataset
save_dataset(dataset, 'workdir/dataset2')
print(f"数据集已生成,共{len(dataset)}条数据")

db/__init__.py

@@ -1,11 +0,0 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset
__all__ = [
"get_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store",
"get_all_dataset"
]

db/dataset_store.py

@@ -1,50 +0,0 @@
import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
# add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
"""
Scan all JSON files under workdir/dataset and read them as a list of dataset objects
Args:
workdir (str): working directory path
Returns:
TinyDB: a TinyDB instance containing all dataset objects
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return TinyDB(storage=MemoryStorage)
db = TinyDB(storage=MemoryStorage)
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
db.insert(data)
except Exception as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return db
if __name__ == "__main__":
# working directory path
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# load all datasets
datasets = get_all_dataset(workdir)
# print the results
print(f"Found {len(datasets)} datasets:")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")

db/init_db.py

@@ -1,79 +0,0 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# module-level cache for the database engine instance
_engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine:
"""
Get the database engine instance, creating a new one if it does not exist yet.
Args:
workdir (str): working directory path, used to locate the database file.
Returns:
Engine: the SQLAlchemy database engine instance.
"""
global _engine
if not _engine:
# create the db directory if it does not exist
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# database file path
db_path = os.path.join(db_dir, "db.sqlite")
# database URL
db_url = f"sqlite:///{db_path}"
# create the engine
_engine = create_engine(db_url)
return _engine
def initialize_sqlite_db(engine: Engine) -> None:
"""
Initialize the database: create all tables and insert seed data if none exists.
Args:
engine (Engine): the SQLAlchemy database engine instance.
"""
# create all tables
SQLModel.metadata.create_all(engine)
# load environment variables
load_dotenv()
# read API settings from environment variables
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
# insert seed data only if all required environment variables are present
if api_key and base_url and model_id:
with Session(engine) as session:
# check whether an APIProvider record already exists
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
# insert a new APIProvider record if none exists
if not existing_provider:
provider = APIProvider(
base_url=base_url,
model_id=model_id,
api_key=api_key
)
session.add(provider)
session.commit()
if __name__ == "__main__":
# working directory path
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# get the database engine
engine = get_sqlite_engine(workdir)
# initialize the database
initialize_sqlite_db(engine)

db/prompt_store.py

@@ -1,62 +0,0 @@
import os
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# module-level cache for the TinyDB instance
_db_instance: Optional[TinyDB] = None
# custom storage class for formatting JSON data
def get_prompt_tinydb(workdir: str) -> TinyDB:
"""
Get the TinyDB instance, creating a new one if it does not exist yet.
Args:
workdir (str): working directory path, used to locate the database file.
Returns:
TinyDB: the TinyDB database instance
"""
global _db_instance
if not _db_instance:
# create the db directory if it does not exist
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# database file path
db_path = os.path.join(db_dir, "prompts.json")
# create the TinyDB instance
_db_instance = TinyDB(db_path)
return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
"""
Initialize the prompt template store
Args:
db (TinyDB): the TinyDB database instance
"""
# check whether the database is empty
if not db.all(): # no data yet
db.insert(promptTempleta(
id=0,
name="default",
description="默认提示词模板",
content="""项目名为:{ project_name }
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{ content }""").model_dump())
# if the database already has data, skip the insert
if __name__ == "__main__":
# working directory path
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# get the database instance
db = get_prompt_tinydb(workdir)
# initialize the prompt store
initialize_prompt_store(db)

frontend/__init__.py

@@ -1,7 +0,0 @@
from .train_page import *
from .chat_page import *
from .setting_page import *
from .model_manage_page import *
from .dataset_manage_page import *
from .dataset_generate_page import *
from .prompt_manage_page import *

frontend/chat_page.py

@@ -1,97 +0,0 @@
import gradio as gr
import sys
from pathlib import Path
from threading import Thread # Thread is needed for background generation
from transformers import TextIteratorStreamer # use TextIteratorStreamer for streaming output
# assume global_var.py lives in the parent directory
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer # assume these return the loaded model and tokenizer
def chat_page():
with gr.Blocks() as demo:
# chat area
gr.Markdown("## 对话")
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(type="messages", label="聊天机器人")
msg = gr.Textbox(label="输入消息")
with gr.Column(scale=1):
# hyperparameter inputs
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
clear = gr.Button("清除对话")
def user(user_message, history: list):
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
model = get_model()
tokenizer = get_tokenizer()
if not history:
yield history
return
if model is None or tokenizer is None:
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
yield history
return
try:
inputs = tokenizer.apply_chat_template(
history,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# cast the hyperparameters to numeric types
generation_kwargs = dict(
input_ids=inputs,
streamer=streamer,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
repetition_penalty=float(repetition_penalty),
do_sample=True,
use_cache=False
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
history.append({"role": "assistant", "content": ""})
for new_text in streamer:
if new_text:
history[-1]["content"] += new_text
yield history
except Exception as e:
import traceback
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
history[-1]["content"] = error_message
else:
history.append({"role": "assistant", "content": error_message})
yield history
# pass the hyperparameters through to the bot function via .then()
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
)
clear.click(lambda: [], None, chatbot, queue=False)
return demo
if __name__ == "__main__":
from model_manage_page import model_manage_page
# mount the two pages
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

frontend/dataset_generate_page.py

@@ -1,42 +0,0 @@
import gradio as gr
from tools import scan_docs_directory
from global_var import get_docs, scan_docs_directory, get_prompt_store
def dataset_generate_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集生成")
with gr.Row():
with gr.Column():
# fetch the docs list and set the initial value
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
choices=docs_list,
value=initial_doc, # initial selection
label="选择文档",
allow_custom_value=True,
interactive=True
)
doc_state = gr.State(value=initial_doc) # initialize state with the initial doc
with gr.Column():
# fetch the template list and set the initial value
prompts = get_prompt_store().all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None
prompt_dropdown = gr.Dropdown(
choices=prompt_choices,
value=initial_prompt, # initial selection
label="选择模板",
allow_custom_value=True,
interactive=True
)
prompt_state = gr.State(value=initial_prompt) # initialize state with the initial template
# bind events (keep the original logic so state updates on interaction)
doc_dropdown.change(lambda x: x, inputs=doc_dropdown, outputs=doc_state)
prompt_dropdown.change(lambda x: x, inputs=prompt_dropdown, outputs=prompt_state)
return demo

frontend/dataset_manage_page.py

@@ -1,55 +0,0 @@
import gradio as gr
from global_var import get_datasets
from tinydb import Query
def dataset_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集管理")
with gr.Row():
# fetch the dataset list and set the initial value
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # initial selection
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# dataset display component
qa_dataset = gr.Dataset(
components=["text", "text"],
label="问答数据",
headers=["问题", "答案"],
samples=[["示例问题", "示例答案"]],
samples_per_page=20,
)
def update_qa_display(dataset_name):
if not dataset_name:
return {"samples": [], "__type__": "update"}
# fetch the dataset from the database
Dataset = Query()
ds = get_datasets().get(Dataset.name == dataset_name)
if not ds:
return {"samples": [], "__type__": "update"}
# collect all Q_A pairs
qa_list = []
for item in ds["dataset_items"]:
for qa in item["message"]:
qa_list.append([qa["question"], qa["answer"]])
return {"samples": qa_list, "__type__": "update"}
# bind the event that refreshes the QA display
dataset_dropdown.change(
update_qa_display,
inputs=dataset_dropdown,
outputs=qa_dataset
)
return demo

frontend/model_manage_page.py

@@ -1,107 +0,0 @@
import gradio as gr
import os # os is used to scan folders
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name
def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
models_dir = os.path.join(workdir, "models")
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # scan all subfolders under models
with gr.Blocks() as demo:
gr.Markdown("## 模型管理")
state_output = gr.Label(label="当前状态",value="当前未加载模型") # a Label instead of a Textbox
with gr.Row():
with gr.Column(scale=3):
model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # expose the subfolder list as a selectable Dropdown
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
with gr.Column(scale=1):
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
refresh_button = gr.Button("刷新模型列表", variant="secondary") # refresh button
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
# if a model is already loaded, unload it first
if get_model() is not None:
unload_model()
model_path = os.path.join(models_dir, selected_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载"
except Exception as e:
return f"加载模型时出错: {str(e)}"
load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)
def unload_model():
try:
# move the model to the CPU
model = get_model()
if model is not None:
model.cpu()
# clear the CUDA cache
torch.cuda.empty_cache()
# set the model and tokenizer to None
set_model(None)
set_tokenizer(None)
return "当前未加载模型"
except Exception as e:
return f"卸载模型时出错: {str(e)}"
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
def save_model(save_model_name):
try:
# the globals used here were never set; fetch from global_var instead
model = get_model()
tokenizer = get_tokenizer()
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
def refresh_model_list():
try:
nonlocal model_folders
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
return gr.Dropdown(choices=model_folders)
except Exception as e:
return f"刷新模型列表时出错: {str(e)}"
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)
return demo
if __name__ == "__main__":
demo = model_manage_page()
demo.queue()
demo.launch()

frontend/prompt_manage_page.py

@@ -1,125 +0,0 @@
import gradio as gr
import sys
from pathlib import Path
from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_prompt_store
from schema.prompt import promptTempleta
def prompt_manage_page():
def get_prompts() -> List[List[str]]:
try:
db = get_prompt_store()
prompts = db.all()
return [
[p["id"], p["name"], p["description"], p["content"]]
for p in prompts
]
except Exception as e:
raise gr.Error(f"获取提示词失败: {str(e)}")
def add_prompt(name, description, content):
try:
db = get_prompt_store()
new_prompt = promptTempleta(
name=name if name else "",
description=description if description else "",
content=content if content else ""
)
prompt_id = db.insert(new_prompt.model_dump())
# write the generated ID back
db.update({"id": prompt_id}, doc_ids=[prompt_id])
return get_prompts(), "", "", "" # 返回清空后的输入框值
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_prompt():
nonlocal selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
db = get_prompt_store()
db.update({
"name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "",
"content": selected_row[3] if selected_row[3] else ""
}, doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_prompt():
nonlocal selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
db = get_prompt_store()
db.remove(doc_ids=[selected_row[0]])
return get_prompts()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # holds the currently selected row
def select_record(evt: gr.SelectData):
nonlocal selected_row
selected_row = evt.row_value
with gr.Blocks() as demo:
gr.Markdown("## 提示词模板管理")
with gr.Row():
with gr.Column(scale=1):
name_input = gr.Textbox(label="模板名称")
description_input = gr.Textbox(label="模板描述")
content_input = gr.Textbox(label="模板内容", lines=10)
add_button = gr.Button("添加新模板", variant="primary")
with gr.Column(scale=3):
prompt_table = gr.DataFrame(
headers=["id", "名称", "描述", "内容"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_prompts(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_prompts,
outputs=[prompt_table],
queue=False
)
add_button.click(
fn=add_prompt,
inputs=[name_input, description_input, content_input],
outputs=[prompt_table, name_input, description_input, content_input]
)
prompt_table.select(select_record, [], [], show_progress="hidden")
edit_button.click(
fn=edit_prompt,
inputs=[],
outputs=[prompt_table]
)
delete_button.click(
fn=delete_prompt,
inputs=[],
outputs=[prompt_table]
)
return demo
if __name__ == "__main__":
demo = prompt_manage_page()
demo.queue()
demo.launch()

frontend/setting_page.py

@@ -1,126 +0,0 @@
import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import get_sql_engine
def setting_page():
def get_providers() -> List[List[str]]:
try: # error handling
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
return [
[p.id, p.model_id, p.base_url, p.api_key or ""]
for p in providers
]
except Exception as e:
raise gr.Error(f"获取数据失败: {str(e)}")
def add_provider(model_id, base_url, api_key):
try:
with Session(get_sql_engine()) as session:
new_provider = APIProvider(
model_id=model_id if model_id else None,
base_url=base_url if base_url else None,
api_key=api_key if api_key else None
)
session.add(new_provider)
session.commit()
session.refresh(new_provider)
return get_providers(), "", "", "" # return cleared input values
except Exception as e:
raise gr.Error(f"添加失败: {str(e)}")
def edit_provider():
nonlocal selected_row
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
provider.model_id = selected_row[1] if selected_row[1] else None
provider.base_url = selected_row[2] if selected_row[2] else None
provider.api_key = selected_row[3] if selected_row[3] else None
session.add(provider)
session.commit()
session.refresh(provider)
return get_providers()
except Exception as e:
raise gr.Error(f"编辑失败: {str(e)}")
def delete_provider():
nonlocal selected_row
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
session.delete(provider)
session.commit()
return get_providers()
except Exception as e:
raise gr.Error(f"删除失败: {str(e)}")
selected_row = None # holds the currently selected row
def select_record(evt: gr.SelectData):
nonlocal selected_row
selected_row = evt.row_value
with gr.Blocks() as demo:
gr.Markdown("## API Provider 管理")
with gr.Row():
with gr.Column(scale=1):
model_id_input = gr.Textbox(label="Model ID")
base_url_input = gr.Textbox(label="Base URL")
api_key_input = gr.Textbox(label="API Key")
add_button = gr.Button("添加新API", variant="primary")
with gr.Column(scale=3):
provider_table = gr.DataFrame(
headers=["id", "model id", "base URL", "API Key"],
datatype=["number", "str", "str", "str"],
interactive=True,
value=get_providers(),
wrap=True,
col_count=(4, "auto")
)
with gr.Row():
refresh_button = gr.Button("刷新数据", variant="secondary")
edit_button = gr.Button("编辑选中行", variant="primary")
delete_button = gr.Button("删除选中行", variant="stop")
refresh_button.click(
fn=get_providers,
outputs=[provider_table],
queue=False # refresh immediately without queuing
)
add_button.click(
fn=add_provider,
inputs=[model_id_input, base_url_input, api_key_input],
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # also clears the input boxes
)
provider_table.select(select_record, [], [], show_progress="hidden")
edit_button.click(
fn=edit_provider,
inputs=[],
outputs=[provider_table]
)
delete_button.click(
fn=delete_provider,
inputs=[],
outputs=[provider_table]
)
return demo

frontend/train_page.py

@@ -1,219 +0,0 @@
import unsloth
import gradio as gr
import sys
import torch
import pandas as pd
from tinydb import Query
from pathlib import Path
from datasets import Dataset as HFDataset
from transformers import TrainerCallback
from unsloth import FastLanguageModel
from trl import SFTTrainer # trainer for supervised fine-tuning
from transformers import TrainingArguments,DataCollatorForSeq2Seq # training argument configuration
from unsloth import is_bfloat16_supported # check whether bfloat16 training is supported
from unsloth.chat_templates import get_chat_template, train_on_responses_only
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import formatting_prompts_func
def train_page():
with gr.Blocks() as demo:
gr.Markdown("## 微调")
# fetch the dataset list and set the initial value
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # initial selection
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# hyperparameter inputs
learning_rate_input = gr.Number(value=2e-4, label="学习率")
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # save-steps input
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # LoRA rank input
train_button = gr.Button("开始微调")
# training status output
output = gr.Textbox(label="训练日志", interactive=False)
# loss curve display
loss_plot = gr.LinePlot(
x="step",
y="loss",
title="训练Loss曲线",
interactive=True,
width=600,
height=300
)
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# use the dynamically passed-in hyperparameters
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # save-steps parameter
lora_rank = int(lora_rank) # LoRA rank parameter
# model configuration
dtype = None # None = auto-select dtype
load_in_4bit = False # whether to load the model in 4-bit quantization to save VRAM
# fetch the already loaded model and tokenizer
model = get_model()
tokenizer = get_tokenizer()
model = FastLanguageModel.get_peft_model(
# the base model
model,
# LoRA rank: controls the dimension of the low-rank matrices; larger values mean
# more trainable parameters and potentially better quality, at higher training cost
# suggested range: 8-32
r=lora_rank, # use the dynamically passed-in LoRA rank
# modules to apply LoRA to
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention layers
"gate_proj", "up_proj", "down_proj", # FFN layers
],
# LoRA scaling factor: controls how strongly the LoRA update affects the model
lora_alpha=16,
# dropout rate for the LoRA layers; 0 disables dropout
# for small datasets, around 0.1 is recommended
lora_dropout=0,
# whether to fine-tune bias parameters
# none: no bias fine-tuning; all: fine-tune all; lora_only: only LoRA biases
bias="none",
# gradient checkpointing to save VRAM, using unsloth's optimized version
# slightly slows training but significantly reduces memory use
use_gradient_checkpointing="unsloth",
# random seed for reproducibility
random_state=3407,
# rank-stabilized LoRA, not used here
use_rslora=False,
# LoftQ config, not used here; further compresses the model
loftq_config=None,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
# load the dataset
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
dataset = HFDataset.from_list(dataset)
dataset = dataset.map(formatting_prompts_func,
fn_kwargs={"tokenizer": tokenizer},
batched=True)
# callback class that streams loss into Gradio
class GradioLossCallback(TrainerCallback):
def __init__(self):
self.loss_data = []
self.log_text = ""
self.last_output = {"text": "", "plot": None}
def on_log(self, args, state, control, logs=None, **kwargs):
print(f"on_log called with logs: {logs}") # 调试输出
if "loss" in logs:
print(f"Recording loss: {logs['loss']} at step {state.global_step}") # 调试输出
self.loss_data.append({
"step": state.global_step,
"loss": float(logs["loss"]) # 确保转换为float
})
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
df = pd.DataFrame(self.loss_data)
print(f"DataFrame created: {df}") # 调试输出
self.last_output = {
"text": self.log_text,
"plot": df
}
return control
# set up the callback
callback = GradioLossCallback()
# set up the SFT trainer
trainer = SFTTrainer(
model=model, # the model to train
tokenizer=tokenizer, # tokenizer
train_dataset=dataset, # training dataset
dataset_text_field="text", # name of the text field in the dataset
max_seq_length=model.max_seq_length, # maximum sequence length
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=1, # number of parallel workers for dataset processing
packing=False,
args=TrainingArguments(
per_device_train_batch_size=per_device_train_batch_size, # per-GPU batch size
gradient_accumulation_steps=4, # gradient accumulation to simulate a larger batch size
warmup_steps=int(epoch * 0.1), # warmup steps, gradually ramping up the learning rate
learning_rate=learning_rate, # learning rate
lr_scheduler_type="linear", # linear learning-rate scheduler
max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # max training steps (one step = one batch)
fp16=not is_bfloat16_supported(), # fall back to fp16 if bf16 is unsupported
bf16=is_bfloat16_supported(), # use bf16 when supported
logging_steps=1, # log every step
optim="adamw_8bit", # 8-bit AdamW saves VRAM with little effect on quality
weight_decay=0.01, # weight decay for regularization, reduces overfitting
seed=114514, # random seed
output_dir=get_workdir() + "/checkpoint/", # checkpoints and training logs
save_strategy="steps", # save intermediate weights by step
save_steps=save_steps, # use the dynamically passed-in save interval
# report_to="tensorboard", # send metrics to tensorboard
),
)
trainer.add_callback(callback)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",
response_part = "<|im_start|>assistant\n",
)
# start training
trainer_stats = trainer.train(resume_from_checkpoint=False)
return callback.last_output
def wrapped_train_model(*args):
print("Starting training...") # 调试输出
result = train_model(*args)
print(f"Training completed with result: {result}") # 调试输出
# 确保返回格式正确
if result and "text" in result and "plot" in result:
return result["text"], result["plot"]
return "", pd.DataFrame() # 返回默认值
train_button.click(
fn=wrapped_train_model,
inputs=[
dataset_dropdown,
learning_rate_input,
per_device_train_batch_size_input,
epoch_input,
save_steps_input,
lora_rank_input
],
outputs=[output, loss_plot]
)
return demo
if __name__ == "__main__":
from global_var import init_global_var
from model_manage_page import model_manage_page
init_global_var("workdir")
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"]) # second tab is the fine-tune page, not chat
demo.queue()
demo.launch()

global_var.py

@@ -1,62 +0,0 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
_prompt_store = None
_sql_engine = None
_docs = None
_datasets = None
_model = None
_tokenizer = None
_workdir = None
def init_global_var(workdir="workdir"):
global _prompt_store, _sql_engine, _docs, _datasets, _workdir
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)
_datasets = get_all_dataset(workdir)
_workdir = workdir
def get_workdir():
return _workdir
def get_prompt_store():
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
return _docs
def set_docs(new_docs):
global _docs
_docs = new_docs
def get_datasets():
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer

29
main.py

@@ -1,29 +0,0 @@
import gradio as gr
from frontend.setting_page import setting_page
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store
if __name__ == "__main__":
init_global_var()
initialize_sqlite_db(get_sql_engine())
initialize_prompt_store(get_prompt_store())
with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs():
with gr.TabItem("模型管理"):
model_manage_page()
with gr.TabItem("模型推理"):
chat_page()
with gr.TabItem("模型微调"):
train_page()
with gr.TabItem("数据集生成"):
dataset_generate_page()
with gr.TabItem("数据集管理"):
dataset_manage_page()
with gr.TabItem("提示词模板管理"):
prompt_manage_page()
with gr.TabItem("设置"):
setting_page()
app.launch()

25
prompt/base.py Normal file

@@ -0,0 +1,25 @@
class create_dataset:
"""用于生成微调数据集模板的类"""
template = """
项目名为:{}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{}
按照如下json格式返回{}
"""
@staticmethod
def create(*args: any) -> str:
"""根据提供的任意数量参数生成数据集模板
Args:
*args: 任意数量的参数,将按顺序填充到模板中
Returns:
格式化后的模板字符串
"""
return create_dataset.template.format(*args)
if __name__=="__main__":
print(create_dataset.create("a", "b", "c"))

requirements.txt

@@ -1,9 +1,2 @@
openai>=1.0.0
python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=5.0.0
langchain>=0.3
tinydb>=4.0.0
unsloth>=2025.3.19
sqlmodel>=0.0.24
jinja2>=3.1.0

schema/__init__.py

@@ -1,4 +0,0 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta

schema/dataset.py

@@ -1,30 +1,9 @@
from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime, timezone
from pydantic import BaseModel, RootModel
from typing import List
class doc(BaseModel):
id: Optional[int] = Field(default=None, description="文档ID")
name: str = Field(default="", description="文档名称")
path: str = Field(default="", description="文档路径")
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
version: Optional[str] = Field(default="", description="文档版本")
class QAPair(BaseModel):
question: str
response: str
class Q_A(BaseModel):
question: str = Field(default="", min_length=1,description="问题")
answer: str = Field(default="", min_length=1, description="答案")
class dataset_item(BaseModel):
id: Optional[int] = Field(default=None, description="数据集项ID")
message: list[Q_A] = Field(description="数据集项内容")
class dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")
class QAArray(RootModel):
root: List[QAPair]
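As a sketch of the new, slimmed-down schema in use (sample data made up; pydantic v2 RootModel validation):

from schema.dataset import QAPair, QAArray

raw = [{"question": "What does this library do?", "response": "It generates fine-tuning corpora."}]
qa_array = QAArray.model_validate(raw)
print(qa_array.root[0].question)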

schema/dataset_generation.py

@@ -1,51 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
base_url: str = Field(...,min_length=1,description="API的基础URL不能为空")
model_id: str = Field(...,min_length=1,description="API使用的模型ID不能为空")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
class LLMResponse(SQLModel):
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳"
)
response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: {
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")

schema/md_doc.py

@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class MarkdownNode(BaseModel):
level: int = Field(default=0, description="节点层级")
title: str = Field(default="Root", description="节点标题")
content: Optional[str] = Field(default=None, description="节点内容")
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
class Config:
arbitrary_types_allowed = True
MarkdownNode.model_rebuild()
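Because MarkdownNode references itself, model_rebuild() resolves the forward reference; a small construction sketch (sample values made up):

from schema.md_doc import MarkdownNode

tree = MarkdownNode(
    title="Root",
    children=[MarkdownNode(level=1, title="Intro", content="hello")],
)
print(tree.children[0].title)  # "Intro"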

schema/prompt.py

@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime, timezone
class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID")
name: Optional[str] = Field(default="", description="模板名称")
description: Optional[str] = Field(default="", description="模板描述")
content: str = Field(default="", min_length=1, description="模板内容")
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间"
)

tools/__init__.py

@@ -1,4 +0,0 @@
from .parse_markdown import parse_markdown
from .scan_doc_dir import *
from .json_example import generate_example_json
from .model import *

tools/convert_json_to_dataset.py

@@ -1,35 +0,0 @@
from typing import List
from schema.dataset import dataset, dataset_item, Q_A
import json
def convert_json_to_dataset(json_data: List[dict]) -> dataset:
# convert the JSON data into the dataset schema
dataset_items = []
item_id = 1 # auto-incrementing ID counter
for item in json_data:
qa = Q_A(question=item["question"], answer=item["answer"])
dataset_item_obj = dataset_item(id=item_id, message=[qa])
dataset_items.append(dataset_item_obj)
item_id += 1 # increment the ID
# build the dataset object
result_dataset = dataset(
name="Converted Dataset",
model_id=None,
description="Dataset converted from JSON",
dataset_items=dataset_items
)
return result_dataset
# example: read JSON from a file and convert it
if __name__ == "__main__":
# assume the JSON data is stored in a file
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
json_data = json.load(file)
# convert to the dataset schema
converted_dataset = convert_json_to_dataset(json_data)
# write the result to a file
with open("output.json", "w", encoding="utf-8") as file:
file.write(converted_dataset.model_dump_json(indent=4))

tools/json_example.py

@@ -1,63 +0,0 @@
from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, get_args, get_origin
import json
from datetime import datetime, date
def generate_example_json(model: type[BaseModel]) -> str:
"""
Generate an example JSON structure from a Pydantic V2 model.
"""
def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type)
args = get_args(field_type)
if origin is list or origin is List:
if args:
return [_generate_example(args[0])]
else:
return []
elif origin is dict or origin is Dict:
if len(args) == 2 and args[0] is str:
return {"key": _generate_example(args[1])}
else:
return {}
elif origin is Optional or origin is type(None):
if args:
return _generate_example(args[0])
else:
return None
elif field_type is str:
return "string"
elif field_type is int:
return 0
elif field_type is float:
return 0.0
elif field_type is bool:
return True
elif field_type is datetime:
return datetime.now().isoformat()
elif field_type is date:
return date.today().isoformat()
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_example_json(field_type)
else:
return "unknown" # 对于未知类型返回 "unknown"
example_data = {}
for field_name, field in model.model_fields.items():
example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str)
if __name__ == "__main__":
import sys
from pathlib import Path
# add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Q_A
class Q_A_list(BaseModel):
Q_As: List[Q_A]
print("示例 JSON:")
print(generate_example_json(Q_A_list))

tools/model.py

@@ -1,26 +0,0 @@
import os
def formatting_prompts_func(examples,tokenizer):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# format the conversations with tokenizer.apply_chat_template
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
def get_model_name(model):
return os.path.basename(model.name_or_path)
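A usage sketch with toy rows; "tokenizer" is assumed to exist in scope, already configured via get_chat_template:

from datasets import Dataset as HFDataset
from tools.model import formatting_prompts_func

rows = {"question": ["What is 1+1?"], "answer": ["2"]}
ds = HFDataset.from_dict(rows)
# tokenizer: assumed to be configured already (e.g. via get_chat_template)
ds = ds.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True)
print(ds[0]["text"])  # chat-templated training text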

69
tools/openai_api.py Normal file

@@ -0,0 +1,69 @@
import json
from openai import OpenAI
def generate_json_via_llm(
prompt: str,
base_url: str,
api_key: str,
model_id: str
) -> str:
client = OpenAI(
api_key=api_key,
base_url=base_url
)
try:
response = client.chat.completions.create(
model=model_id,
messages=[
{
"role": "user",
"content": prompt
}
],
response_format={
'type': 'json_object'
}
)
return response.choices[0].message.content
except Exception as e:
raise RuntimeError(f"API请求失败: {e}")
if __name__ == "__main__":
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.llm import load_config
# (the sys.path line above adds the project root so config.llm can be imported)
# example usage
try:
config = load_config()
print(config)
result = generate_json_via_llm(
prompt="""测试随便生成点什么返回json格式的字符串,格式如下
{
"dataset":[
{
"question":"",
"answer":""
},
{
"question":"",
"answer":""
}
......
]
}
""",
base_url=config["openai"]["base_url"],
api_key=config["openai"]["api_key"],
model_id=config["openai"]["model_id"],
)
print(result)
except Exception as e:
print(f"错误: {e}")

tools/parse_markdown.py

@@ -1,45 +1,28 @@
import re
import sys
from pathlib import Path
# add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import MarkdownNode
class MarkdownNode:
def __init__(self, level=0, title="Root"):
self.level = level
self.title = title
self.content = "" # 使用字符串存储合并后的内容
self.children = []
def process_markdown_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
def __repr__(self):
return f"({self.level}) {self.title}"
root = parse_markdown(content)
results = []
def add_child(self, child):
self.children.append(child)
def traverse(node, parent_titles):
current_titles = parent_titles.copy()
current_titles.append(node.title)
if not node.children: # leaf node
if node.content:
full_text = ' -> '.join(current_titles) + '\n' + node.content
results.append(full_text)
else:
for child in node.children:
traverse(child, current_titles)
traverse(root, [])
return results
def add_child(parent, child):
parent.children.append(child)
def print_tree(node, indent=0):
prefix = "" * (indent - 1) + "└─ " if indent > 0 else ""
print(f"{prefix}{node.title}")
if node.content:
content_prefix = "" * indent + "├─ [内容]"
print(content_prefix)
for line in node.content.split('\n'):
print("" * indent + "" + line)
for child in node.children:
print_tree(child, indent + 1)
def print_tree(self, indent=0):
prefix = "" * (indent - 1) + "└─ " if indent > 0 else ""
print(f"{prefix}{self.title}")
if self.content:
content_prefix = "" * indent + "├─ [内容]"
print(content_prefix)
for line in self.content.split('\n'):
print("" * indent + "" + line)
for child in self.children:
child.print_tree(indent + 1)
def parse_markdown(markdown):
lines = markdown.split('\n')
@@ -68,10 +51,10 @@ def parse_markdown(markdown):
if match:
level = len(match.group(1))
title = match.group(2)
node = MarkdownNode(level=level, title=title, content="", children=[])
node = MarkdownNode(level, title)
while stack[-1].level >= level:
stack.pop()
add_child(stack[-1], node)
stack[-1].add_child(node)
stack.append(node)
else:
if stack[-1].content:
@@ -81,13 +64,10 @@ def parse_markdown(markdown):
return root
if __name__=="__main__":
# # read the Markdown content from a file
# with open("workdir/example.md", "r", encoding="utf-8") as f:
# markdown = f.read()
# read the Markdown content from a file
with open("example.md", "r", encoding="utf-8") as f:
markdown = f.read()
# # parse the Markdown and print the tree
# root = parse_markdown(markdown)
# print_tree(root)
for i in process_markdown_file("workdir/example.md"):
print("~"*20)
print(i)
# parse the Markdown and print the tree
root = parse_markdown(markdown)
root.print_tree()

tools/scan_doc_dir.py

@@ -1,32 +0,0 @@
import sys
import os
from pathlib import Path
# add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import doc
def scan_docs_directory(workdir: str):
docs_dir = os.path.join(workdir, "docs")
doc_list = os.listdir(docs_dir)
to_return = []
for doc_name in doc_list:
doc_path = os.path.join(docs_dir, doc_name)
if os.path.isdir(doc_path):
markdown_files = []
for root, dirs, files in os.walk(doc_path):
for file in files:
if file.endswith(".md"):
markdown_files.append(os.path.join(root, file))
to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
return to_return
# test code
if __name__ == "__main__":
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
docs = scan_docs_directory(workdir)
print(docs)

1534
train.ipynb Normal file

File diff suppressed because it is too large

70
trainer.py Normal file

@@ -0,0 +1,70 @@
from unsloth import FastLanguageModel
import torch
# basic configuration
max_seq_length = 4096 # maximum sequence length
dtype = None # auto-detect dtype
load_in_4bit = True # 4-bit quantization to reduce memory use
# load the pretrained model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "workdir\model\Qwen2.5-3B-Instruct-bnb-4bit", # the Qwen2.5 3B instruct model
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
model = FastLanguageModel.get_peft_model(
model,
r = 64, # LoRA rank, controls the number of trainable parameters
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",], # target modules to train
lora_alpha = 64, # LoRA scaling factor
lora_dropout = 0, # LoRA dropout rate
bias = "none", # whether to train bias terms
use_gradient_checkpointing = "unsloth", # gradient checkpointing to save VRAM
random_state = 114514, # random seed
use_rslora = False, # whether to use rank-stabilized LoRA
loftq_config = None, # LoftQ config
)
from unsloth.chat_templates import get_chat_template
# configure the tokenizer with the qwen-2.5 chat template
tokenizer = get_chat_template(
tokenizer,
chat_template="qwen-2.5",
)
def formatting_prompts_func(examples):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
# combine question and answer into chat-format conversations
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# format the conversations with tokenizer.apply_chat_template
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
from unsloth.chat_templates import standardize_sharegpt
# load the dataset
from datasets import load_dataset
dataset = load_dataset("json", data_files="workdir\dataset\dataset.json")
dataset = dataset.map(formatting_prompts_func, batched = True)
# load_dataset returns a DatasetDict, so index into the "train" split
print(dataset["train"][5])
print(dataset["train"][5]["text"])
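The script ends after dataset preparation; a hedged sketch of the training step that would follow, mirroring the SFTTrainer setup in frontend/train_page.py above (the step budget is a placeholder):

from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        max_steps=60,  # placeholder budget
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        output_dir="workdir/checkpoint",
    ),
)
trainer.train()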