Compare commits
2 Commits
Author | SHA1 | Date
---|---|---
 | d35475d9e8 | 
 | f882b82e57 | 
8  .gitignore (vendored)

@@ -11,7 +11,6 @@ env/
# IDE
.vscode/
.idea/
.roo

# Environment files
.env
@@ -28,9 +27,4 @@ Thumbs.db
workdir/

# cache
unsloth_compiled_cache

# Test and reference code
test.ipynb
test.py
refer/
unsloth_compiled_cache
92  README.md

@@ -1,91 +1,15 @@
# Document-Driven Adaptive Fine-Tuning Framework for Code LLMs

## Introduction

This is my graduation project; this is the mvp branch (MVP: Minimum Viable Product). The remaining features live on the master branch.

### Project Overview
This project is a document-driven adaptive fine-tuning framework for code-oriented large language models. It deep-parses the documentation and other resources of a private library to generate an instruction-style corpus, which is then used to fine-tune a large language model for that library.

### Core Features
- Document parsing and corpus generation
- Efficient fine-tuning of large language models
- Interactive training and inference interface
- Visual monitoring of the training process

## Technical Architecture
### Planned Technologies

### System Architecture
```
[Frontend UI] -> [Model Fine-Tuning] -> [Data Storage]
      ↑                  ↑                    ↑
      │                  │                    │
  [Gradio]        [unsloth/QLoRA]       [SQLite/TinyDB]
```
* QLoRA fine-tuning of large language models on GPU via the unsloth framework
* LangChain workflows for batch generation of the fine-tuning corpus
* Data persistence with TinyDB and SQLite
* Frontend built with Gradio

### Tech Stack
- **Frontend**: interactive web UI built with Gradio
- **Fine-tuning**: efficient QLoRA fine-tuning on top of unsloth
- **Data storage**: SQLite (structured data) + TinyDB (unstructured data)
- **Workflow engine**: LangChain for document parsing and corpus generation
- **Training monitoring**: TensorBoard integration

## Modules

### 1. Model Management
- Load large language models in multiple formats
- Inspect model information and manage model state
- Unload models and free memory

### 2. Model Inference
- Conversational chat interface
- Streaming response output
- Chat history management

### 3. Model Fine-Tuning
- Training hyperparameter configuration (learning rate, batch size, etc.)
- LoRA parameter configuration (rank, alpha, etc.)
- Real-time monitoring of the training process
- Interrupt and resume training

### 4. Dataset Generation
- Document parsing and cleaning
- Instruction-response pair generation
- Dataset quality assessment
- Dataset version management

## Fine-Tuning Technique

### How QLoRA Works
QLoRA (Quantized Low-Rank Adaptation) is an efficient fine-tuning technique for large models. Its core ideas:
1. **4-bit quantization**: quantize the pretrained model to a 4-bit representation, sharply reducing VRAM usage
2. **Low-rank adaptation**: parameter-efficient weight updates via low-rank matrix decomposition (LoRA)
3. **Memory optimization**: gradient checkpointing and similar techniques lower VRAM requirements further

### Parameter Guidelines
- **Learning rate**: 2e-5 to 2e-4 recommended
- **LoRA rank**: controls adapter capacity (16-64 recommended)
- **LoRA alpha**: controls the magnitude of adapter updates (usually 1-2x the rank)
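As a minimal sketch of these settings, here is how a 4-bit load plus LoRA adapter could look with the unsloth API used elsewhere in this repo; the model name and target modules are placeholder assumptions, not fixed choices of this project:

```python
from unsloth import FastLanguageModel

# Load the base model quantized to 4 bits (QLoRA's first ingredient).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-7B-bnb-4bit",  # placeholder model
    max_seq_length=4096,
    load_in_4bit=True,
)

# Attach low-rank adapters; rank and alpha follow the guidelines above.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,           # LoRA rank: adapter capacity
    lora_alpha=32,  # about 2x the rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed projection layers
    use_gradient_checkpointing=True,  # QLoRA's memory optimization
)
```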
### Training Monitoring
- **TensorBoard integration**: view loss curves, learning rate, and other metrics in real time
- **Logging**: detailed logs of the training run are saved
- **Checkpoints**: intermediate weights saved at regular intervals
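To make the integration concrete, here is a hedged sketch of pointing Hugging Face `TrainingArguments` at a TensorBoard log directory; it illustrates the general pattern rather than the exact arguments this project's `train_model` uses:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="workdir/training/1",        # checkpoints land here
    logging_dir="workdir/training/1/logs",  # TensorBoard event files
    report_to=["tensorboard"],              # enable the TensorBoard callback
    logging_steps=10,
    save_steps=20,                          # periodic checkpointing
    learning_rate=2e-4,
)
# Then inspect the run with: tensorboard --logdir workdir/training/1/logs
```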
## Quick Start

1. Install the dependencies:
```bash
pip install -r requirements.txt
```

2. Start the application:
```bash
python main.py
```

3. Open the web UI:
```
http://localhost:7860
```

## License
MIT License

**Work in progress......**
15  config/llm.py (new file)

@@ -0,0 +1,15 @@
import os
from dotenv import load_dotenv
from typing import Dict, Any

def load_config() -> Dict[str, Any]:
    """Load configuration from the .env file."""
    load_dotenv()

    return {
        "openai": {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "base_url": os.getenv("OPENAI_BASE_URL"),
            "model_id": os.getenv("OPENAI_MODEL_ID")
        }
    }
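For reference, the three variables read above would sit in a `.env` file at the project root; the values below are placeholders, not real credentials or a prescribed provider:

```
OPENAI_API_KEY=sk-xxxxxxxx
OPENAI_BASE_URL=https://api.example.com/v1
OPENAI_MODEL_ID=some-model-id
```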
94  dataset_generator.py (new file)

@@ -0,0 +1,94 @@
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm

def process_markdown_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    root = parse_markdown(content)
    results = []

    def traverse(node, parent_titles):
        current_titles = parent_titles.copy()
        current_titles.append(node.title)

        if not node.children:  # leaf node
            if node.content:
                full_text = ' -> '.join(current_titles) + '\n' + node.content
                results.append(full_text)
        else:
            for child in node.children:
                traverse(child, current_titles)

    traverse(root, [])
    return results

def find_markdown_files(directory):
    markdown_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.md'):
                markdown_files.append(os.path.join(root, file))
    return markdown_files

def process_all_markdown(doc_dir):
    all_results = []

    markdown_files = find_markdown_files(doc_dir)
    for file_path in markdown_files:
        results = process_markdown_file(file_path)
        all_results.extend(results)

    return all_results

def save_dataset(dataset, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'dataset.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Parse the markdown documentation
    results = process_all_markdown('workdir/my_docs')

    # Load the LLM configuration
    config = load_config()

    dataset = []
    # Wrap the outer loop with tqdm to show a progress bar
    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
        for _ in range(3):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,          # document content
                """{
    "dataset":[
        {
            "question":"",
            "answer":""
        }
    ]
}"""
            )

            # Call the LLM to generate JSON
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"]
                )
                items = json.loads(result)["dataset"]  # parse once, reuse below
                print(items)
                dataset.extend(items)
            except Exception as e:
                print(f"生成数据集时出错: {e}")

    # Save the dataset
    save_dataset(dataset, 'workdir/dataset2')
    print(f"数据集已生成,共{len(dataset)}条数据")
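For orientation, the saved `workdir/dataset2/dataset.json` is simply the accumulated list of question-answer objects pulled from each `"dataset"` array the model returns; its shape is roughly the following (entries illustrative, values elided):

```json
[
  {
    "question": "...",
    "answer": "..."
  },
  {
    "question": "...",
    "answer": "..."
  }
]
```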
@@ -1,12 +0,0 @@
from .init_db import load_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset, save_dataset

__all__ = [
    "load_sqlite_engine",
    "initialize_sqlite_db",
    "get_prompt_tinydb",
    "initialize_prompt_store",
    "get_all_dataset",
    "save_dataset"
]
@@ -1,81 +0,0 @@
import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage

# Add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import Dataset, DatasetItem, Q_A

def get_all_dataset(workdir: str) -> TinyDB:
    """
    Scan all json files under workdir/dataset and load them as dataset objects.

    Args:
        workdir (str): path to the working directory

    Returns:
        TinyDB: a TinyDB instance holding every dataset object
    """
    dataset_dir = os.path.join(workdir, "dataset")
    if not os.path.exists(dataset_dir):
        return TinyDB(storage=MemoryStorage)

    db = TinyDB(storage=MemoryStorage)
    for filename in os.listdir(dataset_dir):
        if filename.endswith(".json"):
            filepath = os.path.join(dataset_dir, filename)
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    data = json.load(f)
                    db.insert(data)
            except Exception as e:
                print(f"Error loading dataset file {filename}: {str(e)}")
                continue

    return db


def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
    """
    Save the datasets held in TinyDB as individual json files.

    Args:
        db (TinyDB): TinyDB instance holding the dataset objects
        workdir (str): path to the working directory
        name (str, optional): name of the dataset to save; None saves all
    """
    dataset_dir = os.path.join(workdir, "dataset")
    os.makedirs(dataset_dir, exist_ok=True)

    datasets = db.all() if name is None else db.search(Query().name == name)

    for dataset in datasets:
        try:
            filename = f"{dataset['name']}.json"
            filepath = os.path.join(dataset_dir, filename)

            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(dataset, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")

if __name__ == "__main__":
    # Working directory path
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    # Load every dataset
    datasets = get_all_dataset(workdir)
    # Print the result
    print(f"Found {len(datasets)} datasets:")
    for ds in datasets.all():
        print(f"- {ds['name']} (ID: {ds['id']})")

    # Ask which dataset to save
    name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None

    # Save the datasets to files
    save_dataset(datasets, workdir, name)
    print(f"Datasets {'all' if name is None else name} saved to json files")
@@ -1,79 +0,0 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy.engine import Engine

# Add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider

# Global variable holding the database engine instance
_engine: Optional[Engine] = None

def load_sqlite_engine(workdir: str) -> Engine:
    """
    Get the database engine instance, creating a new one if it does not exist yet.

    Args:
        workdir (str): working directory that determines where the database file lives.

    Returns:
        Engine: the SQLAlchemy database engine instance.
    """
    global _engine
    if not _engine:
        # Create the database directory if needed
        db_dir = os.path.join(workdir, "db")
        os.makedirs(db_dir, exist_ok=True)
        # Database file path
        db_path = os.path.join(db_dir, "db.sqlite")
        # Database URL
        db_url = f"sqlite:///{db_path}"
        # Create the engine
        _engine = create_engine(db_url)
    return _engine


def initialize_sqlite_db(engine: Engine) -> None:
    """
    Initialize the database: create all tables and insert seed data if missing.

    Args:
        engine (Engine): the SQLAlchemy database engine instance.
    """
    # Create all tables
    SQLModel.metadata.create_all(engine)
    # Load environment variables
    load_dotenv()
    # Read the API settings from the environment
    api_key = os.getenv("API_KEY")
    base_url = os.getenv("BASE_URL")
    model_id = os.getenv("MODEL_ID")

    # Insert seed data only if every required variable is present
    if api_key and base_url and model_id:
        with Session(engine) as session:
            # Check whether an APIProvider record already exists
            statement = select(APIProvider).limit(1)
            existing_provider = session.exec(statement).first()

            # Insert a new APIProvider record if none exists
            if not existing_provider:
                provider = APIProvider(
                    base_url=base_url,
                    model_id=model_id,
                    api_key=api_key
                )
                session.add(provider)
                session.commit()

if __name__ == "__main__":
    # Working directory path
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    # Get the database engine
    engine = load_sqlite_engine(workdir)
    # Initialize the database
    initialize_sqlite_db(engine)
@@ -1,62 +0,0 @@
import os
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage

# Add the project root to sys.path so other project modules can be imported
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta

# Global variable holding the TinyDB instance
_db_instance: Optional[TinyDB] = None

def get_prompt_tinydb(workdir: str) -> TinyDB:
    """
    Get the TinyDB instance, creating a new one if it does not exist yet.

    Args:
        workdir (str): working directory that determines where the database file lives.

    Returns:
        TinyDB: the TinyDB database instance
    """
    global _db_instance
    if not _db_instance:
        # Create the database directory if needed
        db_dir = os.path.join(workdir, "db")
        os.makedirs(db_dir, exist_ok=True)
        # Database file path
        db_path = os.path.join(db_dir, "prompts.json")
        # Create the TinyDB instance
        _db_instance = TinyDB(db_path)
    return _db_instance


def initialize_prompt_store(db: TinyDB) -> None:
    """
    Initialize the prompt template store.

    Args:
        db (TinyDB): the TinyDB database instance
    """
    # Seed the store only when it is empty
    if not db.all():
        db.insert(promptTempleta(
            id=1,
            name="default",
            description="默认提示词模板",
            content="""项目名为:{project_name}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{document_slice}""").model_dump())
    # If the database already has data, skip the insert


if __name__ == "__main__":
    # Working directory path
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    # Get the database instance
    db = get_prompt_tinydb(workdir)
    # Initialize the prompt store
    initialize_prompt_store(db)
@@ -1,7 +0,0 @@
from .train_page import *
from .chat_page import *
from .setting_page import *
from .model_manage_page import *
from .dataset_manage_page import *
from .dataset_generate_page import *
from .prompt_manage_page import *
@@ -1,97 +0,0 @@
import gradio as gr
import sys
from pathlib import Path
from threading import Thread  # needed to run generation in the background
from transformers import TextIteratorStreamer  # streams tokens as they are generated

# Assume global_var.py lives in the parent directory
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer  # assumed to return the loaded model and tokenizer

def chat_page():
    with gr.Blocks() as demo:
        # Chat box
        gr.Markdown("## 对话")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(type="messages", label="聊天机器人")
                msg = gr.Textbox(label="输入消息")

            with gr.Column(scale=1):
                # Generation hyperparameter inputs
                max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
                temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
                top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
                repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
                clear = gr.Button("清除对话")

        def user(user_message, history: list):
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
            model = get_model()
            tokenizer = get_tokenizer()
            if not history:
                yield history
                return

            if model is None or tokenizer is None:
                history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
                yield history
                return

            try:
                inputs = tokenizer.apply_chat_template(
                    history,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(model.device)

                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

                # Cast the hyperparameters to numeric types
                generation_kwargs = dict(
                    input_ids=inputs,
                    streamer=streamer,
                    max_new_tokens=int(max_new_tokens),
                    temperature=float(temperature),
                    top_p=float(top_p),
                    repetition_penalty=float(repetition_penalty),
                    do_sample=True,
                    use_cache=False
                )
                thread = Thread(target=model.generate, kwargs=generation_kwargs)
                thread.start()

                history.append({"role": "assistant", "content": ""})

                for new_text in streamer:
                    if new_text:
                        history[-1]["content"] += new_text
                        yield history

            except Exception as e:
                import traceback
                error_message = f"生成回复时出错:\n{traceback.format_exc()}"
                if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
                    history[-1]["content"] = error_message
                else:
                    history.append({"role": "assistant", "content": error_message})
                yield history

        # Pass the hyperparameters through to bot in the .then() call
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
        )

        clear.click(lambda: [], None, chatbot, queue=False)

    return demo

if __name__ == "__main__":
    from model_manage_page import model_manage_page
    # Mount both pages
    demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
    demo.queue()
    demo.launch()
@@ -1,189 +0,0 @@
import gradio as gr
import sys
import json
from tinydb import Query
from pathlib import Path
from langchain.prompts import PromptTemplate
from sqlmodel import Session, select
from schema import Dataset, DatasetItem, Q_A
from db.dataset_store import get_all_dataset

sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
from db import save_dataset
from tools import call_openai_api, process_markdown_file, generate_json_example
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir

def dataset_generate_page():
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集生成")
        with gr.Row():
            with gr.Column(scale=1):
                docs_list = [str(doc.name) for doc in get_docs()]
                initial_doc = docs_list[0] if docs_list else None
                prompts = get_prompt_store().all()
                prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                initial_prompt = prompt_list[0] if prompt_list else None

                # Initial value for the Dataframe
                initial_dataframe_value = []
                if initial_prompt:
                    selected_prompt_id = int(initial_prompt.split(" ")[0])
                    prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
                    prompt_content = prompt_data["content"]
                    prompt_template = PromptTemplate.from_template(prompt_content)
                    input_variables = prompt_template.input_variables
                    input_variables.remove("document_slice")
                    initial_dataframe_value = [[var, ""] for var in input_variables]
                # Fetch the API provider list from the database
                with Session(get_sql_engine()) as session:
                    providers = session.exec(select(APIProvider)).all()
                api_list = [f"{p.id} {p.model_id}" for p in providers]
                initial_api = api_list[0] if api_list else None

                api_dropdown = gr.Dropdown(
                    choices=api_list,
                    value=initial_api,
                    label="选择API",
                    interactive=True
                )
                doc_dropdown = gr.Dropdown(
                    choices=docs_list,
                    value=initial_doc,
                    label="选择文档",
                    interactive=True
                )
                prompt_dropdown = gr.Dropdown(
                    choices=prompt_list,
                    value=initial_prompt,
                    label="选择模板",
                    interactive=True
                )
                rounds_input = gr.Number(
                    value=1,
                    label="生成轮次",
                    minimum=1,
                    maximum=100,
                    step=1,
                    interactive=True
                )
                concurrency_input = gr.Number(
                    value=1,
                    label="并发数",
                    minimum=1,
                    maximum=10,
                    step=1,
                    interactive=True,
                    visible=False
                )
                dataset_name_input = gr.Textbox(
                    label="数据集名称",
                    placeholder="输入数据集保存名称",
                    interactive=True
                )
                prompt_choice = gr.State(value=initial_prompt)
                generate_button = gr.Button("生成数据集", variant="primary")
                doc_choice = gr.State(value=initial_doc)
                output_text = gr.Textbox(label="生成结果", interactive=False)
                api_choice = gr.State(value=initial_api)
            with gr.Column(scale=2):
                variables_dataframe = gr.Dataframe(
                    headers=["变量名", "变量值"],
                    datatype=["str", "str"],
                    interactive=True,
                    label="变量列表",
                    value=initial_dataframe_value  # set the initial data
                )

        def on_doc_change(selected_doc):
            return selected_doc

        def on_api_change(selected_api):
            return selected_api

        def on_prompt_change(selected_prompt):
            if not selected_prompt:
                return None, []
            selected_prompt_id = int(selected_prompt.split(" ")[0])
            prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
            prompt_content = prompt_data["content"]
            prompt_template = PromptTemplate.from_template(prompt_content)
            input_variables = prompt_template.input_variables
            input_variables.remove("document_slice")
            dataframe_value = [[var, ""] for var in input_variables]
            return selected_prompt, dataframe_value

        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
            dataset_db = get_datasets()
            if dataset_db.search(Query().name == dataset_name):  # reject duplicate dataset names
                raise gr.Error("数据集名称已存在")

            doc = [i for i in get_docs() if i.name == doc_state][0]
            doc_files = doc.markdown_files
            document_slice_list = [process_markdown_file(doc) for doc in doc_files]
            prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
            prompt = PromptTemplate.from_template(prompt["content"])
            with Session(get_sql_engine()) as session:
                api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()

            variables_dict = {}
            for _, row in variables_dataframe.iterrows():
                var_name = row['变量名'].strip()
                var_value = row['变量值'].strip()
                if var_name:
                    variables_dict[var_name] = var_value

            prompt = prompt.partial(**variables_dict)

            dataset = Dataset(
                name=dataset_name,
                model_id=[api_provider.model_id],
                source_doc=doc,
                dataset_items=[]
            )

            total_slices = len(document_slice_list)
            for i, document_slice in enumerate(document_slice_list):
                progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
                request = LLMRequest(api_provider=api_provider,
                                     prompt=prompt.format(document_slice=document_slice),
                                     format=generate_json_example(DatasetItem))
                call_openai_api(request, rounds)

                for resp in request.response:
                    try:
                        content = json.loads(resp.content)
                        dataset_item = DatasetItem(
                            message=[Q_A(
                                question=content.get("question", ""),
                                answer=content.get("answer", "")
                            )]
                        )
                        dataset.dataset_items.append(dataset_item)
                    except json.JSONDecodeError as e:
                        print(f"Failed to parse response: {e}")

            # Persist the dataset to TinyDB
            dataset_db.insert(dataset.model_dump())

            save_dataset(dataset_db, get_workdir(), dataset_name)

            return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"

        doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
        prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
        api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)

        generate_button.click(
            on_generate_click,
            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
            outputs=output_text
        )

    return demo

if __name__ == "__main__":
    from global_var import init_global_var
    init_global_var("workdir")
    demo = dataset_generate_page()
    demo.launch()
@@ -1,55 +0,0 @@
import gradio as gr
from global_var import get_datasets
from tinydb import Query

def dataset_manage_page():
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集管理")
        with gr.Row():
            # Fetch the dataset list and set the initial value
            datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
            initial_dataset = datasets_list[0] if datasets_list else None

            dataset_dropdown = gr.Dropdown(
                choices=datasets_list,
                value=initial_dataset,  # initial selection
                label="选择数据集",
                allow_custom_value=True,
                interactive=True
            )

        # Dataset display component
        qa_dataset = gr.Dataset(
            components=["text", "text"],
            label="问答数据",
            headers=["问题", "答案"],
            samples=[["示例问题", "示例答案"]],
            samples_per_page=20,
        )

        def update_qa_display(dataset_name):
            if not dataset_name:
                return {"samples": [], "__type__": "update"}

            # Fetch the dataset from the database
            Dataset = Query()
            ds = get_datasets().get(Dataset.name == dataset_name)
            if not ds:
                return {"samples": [], "__type__": "update"}

            # Collect every Q_A pair
            qa_list = []
            for item in ds["dataset_items"]:
                for qa in item["message"]:
                    qa_list.append([qa["question"], qa["answer"]])

            return {"samples": qa_list, "__type__": "update"}

        # Update the QA display when the selection changes
        dataset_dropdown.change(
            update_qa_display,
            inputs=dataset_dropdown,
            outputs=qa_dataset
        )

    return demo
@@ -1,103 +0,0 @@
import gradio as gr
import os  # needed to scan folders
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch

sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from train import get_model_name

def model_manage_page():
    workdir = "workdir"  # assume workdir is a folder under the current working directory
    models_dir = os.path.join(workdir, "models")
    model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]  # scan every subfolder under models

    with gr.Blocks() as demo:
        gr.Markdown("## 模型管理")
        state_output = gr.Label(label="当前状态", value="当前未加载模型")
        with gr.Row():
            with gr.Column(scale=3):
                model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True)  # offer the subfolders as selectable models
                max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
                load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
            with gr.Column(scale=1):
                load_button = gr.Button("加载模型", variant="primary")
                unload_button = gr.Button("卸载模型", variant="stop")
                refresh_button = gr.Button("刷新模型列表", variant="secondary")

        with gr.Row():
            with gr.Column(scale=3):
                save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
                save_method_dropdown = gr.Dropdown(
                    choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
                    label="保存模式",
                    value="default"
                )
            with gr.Column(scale=1):
                save_button = gr.Button("保存模型", variant="secondary")

        def load_model(selected_model, max_seq_length, load_in_4bit):
            try:
                # If a model is already loaded, unload it first
                if get_model() is not None:
                    unload_model()

                model_path = os.path.join(models_dir, selected_model)
                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=model_path,
                    max_seq_length=max_seq_length,
                    load_in_4bit=load_in_4bit,
                )
                set_model(model)
                set_tokenizer(tokenizer)
                return f"模型 {get_model_name(model)} 已加载"
            except Exception as e:
                return f"加载模型时出错: {str(e)}"

        load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)

        def unload_model():
            try:
                # Move the model to the CPU
                model = get_model()
                if model is not None:
                    model.cpu()

                # Clear the CUDA cache
                torch.cuda.empty_cache()

                # Reset the model and tokenizer to None
                set_model(None)
                set_tokenizer(None)

                return "当前未加载模型"
            except Exception as e:
                return f"卸载模型时出错: {str(e)}"

        unload_button.click(fn=unload_model, inputs=None, outputs=state_output)

        from train.save_model import save_model_to_dir

        def save_model(save_model_name, save_method):
            return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)

        save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)

        def refresh_model_list():
            try:
                nonlocal model_folders
                model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
                return gr.Dropdown(choices=model_folders)
            except Exception as e:
                return f"刷新模型列表时出错: {str(e)}"

        refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)

    return demo

if __name__ == "__main__":
    demo = model_manage_page()
    demo.queue()
    demo.launch()
@@ -1,130 +0,0 @@
import gradio as gr
import sys
from pathlib import Path
from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_prompt_store
from schema.prompt import promptTempleta

def prompt_manage_page():
    selected_row = None  # holds the currently selected table row

    def get_prompts() -> List[List[str]]:
        try:
            db = get_prompt_store()
            prompts = db.all()
            return [
                [p["id"], p["name"], p["description"], p["content"]]
                for p in prompts
            ]
        except Exception as e:
            raise gr.Error(f"获取提示词失败: {str(e)}")

    def add_prompt(name, description, content):
        try:
            db = get_prompt_store()
            new_prompt = promptTempleta(
                name=name if name else "",
                description=description if description else "",
                content=content if content else ""
            )
            prompt_id = db.insert(new_prompt.model_dump())
            # Write the assigned ID back into the record
            db.update({"id": prompt_id}, doc_ids=[prompt_id])
            return get_prompts(), "", "", ""  # also clear the input boxes
        except Exception as e:
            raise gr.Error(f"添加失败: {str(e)}")

    def edit_prompt():
        nonlocal selected_row
        if not selected_row:
            raise gr.Error("请先选择要编辑的行")
        try:
            db = get_prompt_store()
            db.update({
                "name": selected_row[1] if selected_row[1] else "",
                "description": selected_row[2] if selected_row[2] else "",
                "content": selected_row[3] if selected_row[3] else ""
            }, doc_ids=[selected_row[0]])
            return get_prompts()
        except Exception as e:
            raise gr.Error(f"编辑失败: {str(e)}")

    def delete_prompt():
        nonlocal selected_row
        if not selected_row:
            raise gr.Error("请先选择要删除的行")
        try:
            db = get_prompt_store()
            db.remove(doc_ids=[selected_row[0]])
            return get_prompts()
        except Exception as e:
            raise gr.Error(f"删除失败: {str(e)}")

    def select_record(dataFrame, evt: gr.SelectData):
        nonlocal selected_row
        selected_row = dataFrame.iloc[evt.index[0]].tolist()
        selected_row[0] = int(selected_row[0])
        print(selected_row)

    with gr.Blocks() as demo:
        gr.Markdown("## 提示词模板管理")

        with gr.Row():
            with gr.Column(scale=1):
                name_input = gr.Textbox(label="模板名称")
                description_input = gr.Textbox(label="模板描述")
                content_input = gr.Textbox(label="模板内容", lines=10)
                add_button = gr.Button("添加新模板", variant="primary")

            with gr.Column(scale=3):
                prompt_table = gr.DataFrame(
                    headers=["id", "名称", "描述", "内容"],
                    datatype=["number", "str", "str", "str"],
                    interactive=True,
                    value=get_prompts(),
                    wrap=True,
                    col_count=(4, "auto")
                )

        with gr.Row():
            refresh_button = gr.Button("刷新数据", variant="secondary")
            edit_button = gr.Button("编辑选中行", variant="primary")
            delete_button = gr.Button("删除选中行", variant="stop")

        refresh_button.click(
            fn=get_prompts,
            outputs=[prompt_table],
            queue=False
        )

        add_button.click(
            fn=add_prompt,
            inputs=[name_input, description_input, content_input],
            outputs=[prompt_table, name_input, description_input, content_input]
        )

        prompt_table.select(fn=select_record,
                            inputs=[prompt_table],
                            outputs=[],
                            show_progress="hidden")

        edit_button.click(
            fn=edit_prompt,
            inputs=[],
            outputs=[prompt_table]
        )

        delete_button.click(
            fn=delete_prompt,
            inputs=[],
            outputs=[prompt_table]
        )

    return demo


if __name__ == "__main__":
    demo = prompt_manage_page()
    demo.queue()
    demo.launch()
@@ -1,131 +0,0 @@
import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import get_sql_engine

def setting_page():
    selected_row = None  # holds the currently selected table row

    def get_providers() -> List[List[str]]:
        try:
            with Session(get_sql_engine()) as session:
                providers = session.exec(select(APIProvider)).all()
            return [
                [p.id, p.model_id, p.base_url, p.api_key or ""]
                for p in providers
            ]
        except Exception as e:
            raise gr.Error(f"获取数据失败: {str(e)}")

    def add_provider(model_id, base_url, api_key):
        try:
            with Session(get_sql_engine()) as session:
                new_provider = APIProvider(
                    model_id=model_id if model_id else None,
                    base_url=base_url if base_url else None,
                    api_key=api_key if api_key else None
                )
                session.add(new_provider)
                session.commit()
                session.refresh(new_provider)
            return get_providers(), "", "", ""  # also clear the input boxes
        except Exception as e:
            raise gr.Error(f"添加失败: {str(e)}")

    def edit_provider():
        nonlocal selected_row
        if not selected_row:
            raise gr.Error("请先选择要编辑的行")
        try:
            with Session(get_sql_engine()) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                provider.model_id = selected_row[1] if selected_row[1] else None
                provider.base_url = selected_row[2] if selected_row[2] else None
                provider.api_key = selected_row[3] if selected_row[3] else None
                session.add(provider)
                session.commit()
                session.refresh(provider)
            return get_providers()
        except Exception as e:
            raise gr.Error(f"编辑失败: {str(e)}")

    def delete_provider():
        nonlocal selected_row
        if not selected_row:
            raise gr.Error("请先选择要删除的行")
        try:
            with Session(get_sql_engine()) as session:
                provider = session.get(APIProvider, selected_row[0])
                if not provider:
                    raise gr.Error("找不到选中的记录")
                session.delete(provider)
                session.commit()
            return get_providers()
        except Exception as e:
            raise gr.Error(f"删除失败: {str(e)}")

    def select_record(dataFrame, evt: gr.SelectData):
        nonlocal selected_row
        selected_row = dataFrame.iloc[evt.index[0]].tolist()
        selected_row[0] = int(selected_row[0])
        print(selected_row)

    with gr.Blocks() as demo:
        gr.Markdown("## API Provider 管理")

        with gr.Row():
            with gr.Column(scale=1):
                model_id_input = gr.Textbox(label="Model ID")
                base_url_input = gr.Textbox(label="Base URL")
                api_key_input = gr.Textbox(label="API Key")
                add_button = gr.Button("添加新API", variant="primary")

            with gr.Column(scale=3):
                provider_table = gr.DataFrame(
                    headers=["id", "model id", "base URL", "API Key"],
                    datatype=["number", "str", "str", "str"],
                    interactive=True,
                    value=get_providers(),
                    wrap=True,
                    col_count=(4, "auto")
                )

        with gr.Row():
            refresh_button = gr.Button("刷新数据", variant="secondary")
            edit_button = gr.Button("编辑选中行", variant="primary")
            delete_button = gr.Button("删除选中行", variant="stop")

        refresh_button.click(
            fn=get_providers,
            outputs=[provider_table],
            queue=False  # an immediate refresh does not need the queue
        )

        add_button.click(
            fn=add_provider,
            inputs=[model_id_input, base_url_input, api_key_input],
            outputs=[provider_table, model_id_input, base_url_input, api_key_input]  # outputs also clear the inputs
        )

        provider_table.select(fn=select_record,
                              inputs=[provider_table],
                              outputs=[],
                              show_progress="hidden")

        edit_button.click(
            fn=edit_provider,
            inputs=[],
            outputs=[provider_table]
        )

        delete_button.click(
            fn=delete_provider,
            inputs=[],
            outputs=[provider_table]
        )

    return demo
@@ -1,113 +0,0 @@
import subprocess
import os
import gradio as gr
import sys
from tinydb import Query
from pathlib import Path
from transformers import TrainerCallback

sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import find_available_port
from train import train_model

def train_page():
    with gr.Blocks() as demo:
        gr.Markdown("## 微调")
        # Fetch the dataset list and set the initial value
        datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
        initial_dataset = datasets_list[0] if datasets_list else None
        with gr.Row():
            with gr.Column(scale=1):
                dataset_dropdown = gr.Dropdown(
                    choices=datasets_list,
                    value=initial_dataset,  # initial selection
                    label="选择数据集",
                    allow_custom_value=True,
                    interactive=True
                )
                # Hyperparameter inputs
                learning_rate_input = gr.Number(value=2e-4, label="学习率")
                per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
                epoch_input = gr.Number(value=1, label="epoch", precision=0)
                save_steps_input = gr.Number(value=20, label="保存步数", precision=0)  # checkpoint interval
                lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)  # LoRA rank

                train_button = gr.Button("开始微调")

                # Training status output
                output = gr.Textbox(label="训练状态", interactive=False)
            with gr.Column(scale=3):
                # TensorBoard iframe
                tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")

        def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            # Cast the hyperparameters passed in from the UI
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)
            lora_rank = int(lora_rank)

            # Load the dataset
            dataset = get_datasets().get(Query().name == dataset_name)
            dataset = [ds["message"][0] for ds in dataset["dataset_items"]]

            # Scan the training folder and create the next numbered run directory
            training_dir = get_workdir() + "/training"
            os.makedirs(training_dir, exist_ok=True)  # make sure the training folder exists
            existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
            next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
            new_training_dir = os.path.join(training_dir, str(next_dir_number))

            tensorboard_port = find_available_port(6006)  # probe from the default port 6006
            print(f"TensorBoard 将使用端口: {tensorboard_port}")

            tensorboard_logdir = os.path.join(new_training_dir, "logs")
            os.makedirs(tensorboard_logdir, exist_ok=True)  # make sure the log directory exists
            tensorboard_process = subprocess.Popen(
                ["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )
            print("TensorBoard 已启动,日志目录:", tensorboard_logdir)

            # Build the TensorBoard iframe dynamically
            tensorboard_url = f"http://localhost:{tensorboard_port}"
            tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
            yield "训练开始...", tensorboard_iframe_value  # two values: the textbox and the html

            try:
                train_model(get_model(), get_tokenizer(),
                            dataset, new_training_dir,
                            learning_rate, per_device_train_batch_size, epoch,
                            save_steps, lora_rank)
            except Exception as e:
                raise gr.Error(str(e))
            finally:
                # Always terminate the TensorBoard subprocess when training ends
                tensorboard_process.terminate()
                print("TensorBoard 子进程已终止")

        train_button.click(
            fn=start_training,
            inputs=[
                dataset_dropdown,
                learning_rate_input,
                per_device_train_batch_size_input,
                epoch_input,
                save_steps_input,
                lora_rank_input
            ],
            outputs=[output, tensorboard_iframe]  # include the iframe in the outputs
        )

    return demo

if __name__ == "__main__":
    from global_var import init_global_var
    from model_manage_page import model_manage_page
    init_global_var("workdir")
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"])
    demo.queue()
    demo.launch()
@@ -1,52 +0,0 @@
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory

_prompt_store = None
_sql_engine = None
_datasets = None
_model = None
_tokenizer = None
_workdir = None

def init_global_var(workdir="workdir"):
    global _prompt_store, _sql_engine, _datasets, _workdir
    _prompt_store = get_prompt_tinydb(workdir)
    _sql_engine = load_sqlite_engine(workdir)
    _datasets = get_all_dataset(workdir)
    _workdir = workdir

def get_workdir():
    return _workdir

def get_prompt_store():
    return _prompt_store

def set_prompt_store(new_prompt_store):
    global _prompt_store
    _prompt_store = new_prompt_store

def get_sql_engine():
    return _sql_engine

def set_sql_engine(new_sql_engine):
    global _sql_engine
    _sql_engine = new_sql_engine

def get_docs():
    return scan_docs_directory(_workdir)

def get_datasets():
    return _datasets

def get_model():
    return _model

def set_model(new_model):
    global _model
    _model = new_model

def get_tokenizer():
    return _tokenizer

def set_tokenizer(new_tokenizer):
    global _tokenizer
    _tokenizer = new_tokenizer
29  main.py

@@ -1,29 +0,0 @@
import gradio as gr
import train
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store

if __name__ == "__main__":
    init_global_var()
    initialize_sqlite_db(get_sql_engine())
    initialize_prompt_store(get_prompt_store())
    with gr.Blocks() as app:
        gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
        with gr.Tabs():
            with gr.TabItem("模型管理"):
                model_manage_page()
            with gr.TabItem("模型推理"):
                chat_page()
            with gr.TabItem("模型微调"):
                train_page()
            with gr.TabItem("数据集生成"):
                dataset_generate_page()
            with gr.TabItem("数据集管理"):
                dataset_manage_page()
            with gr.TabItem("提示词模板管理"):
                prompt_manage_page()
            with gr.TabItem("设置"):
                setting_page()

    app.launch(server_name="0.0.0.0")
25  prompt/base.py (new file)

@@ -0,0 +1,25 @@
class create_dataset:
    """Builds the fine-tuning dataset prompt from a template."""

    template = """
项目名为:{}
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
文档节选:{}
按照如下json格式返回:{}
"""

    @staticmethod
    def create(*args) -> str:
        """Fill the template with any number of positional arguments.

        Args:
            *args: arguments inserted into the template in order

        Returns:
            The formatted template string
        """
        return create_dataset.template.format(*args)


if __name__ == "__main__":
    print(create_dataset.create("a", "b", "c"))
@@ -1,11 +1,2 @@
openai>=1.0.0
python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=5.25.0
langchain>=0.3
tinydb>=4.0.0
unsloth>=2025.3.19
sqlmodel>=0.0.24
jinja2>=3.1.0
tensorboardX>=2.6.2.2
tensorboard>=2.19.0
python-dotenv>=1.0.0
@@ -1,4 +0,0 @@
from .dataset import *
from .dataset_generation import *
from .md_doc import MarkdownNode
from .prompt import promptTempleta
@@ -1,30 +1,9 @@
from typing import Optional
from pydantic import BaseModel, Field
from datetime import datetime, timezone
from pydantic import BaseModel, RootModel
from typing import List

class Doc(BaseModel):
    id: Optional[int] = Field(default=None, description="文档ID")
    name: str = Field(default="", description="文档名称")
    path: str = Field(default="", description="文档路径")
    markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
    version: Optional[str] = Field(default="", description="文档版本")
class QAPair(BaseModel):
    question: str
    response: str

class Q_A(BaseModel):
    question: str = Field(default="", min_length=1, description="问题")
    answer: str = Field(default="", min_length=1, description="答案")

class DatasetItem(BaseModel):
    id: Optional[int] = Field(default=None, description="数据集项ID")
    message: list[Q_A] = Field(description="数据集项内容")

class Dataset(BaseModel):
    id: Optional[int] = Field(default=None, description="数据集ID")
    name: str = Field(default="", description="数据集名称")
    model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
    source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
    description: Optional[str] = Field(default="", description="数据集描述")
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="记录创建时间"
    )
    dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
class QAArray(RootModel):
    root: List[QAPair]
@@ -1,47 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field

class APIProvider(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True, allow_mutation=False)
    base_url: str = Field(..., min_length=1, description="API的基础URL,不能为空")
    model_id: str = Field(..., min_length=1, description="API使用的模型ID,不能为空")
    api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="记录创建时间"
    )

class LLMParameters(SQLModel):
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    seed: Optional[int] = None

class TokensUsage(SQLModel):
    prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
    completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
    prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
    prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")

class LLMResponse(SQLModel):
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="响应的时间戳"
    )
    response_id: str = Field(..., description="响应的唯一ID")
    tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
    content: str = Field(default="", description="API响应的内容")
    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
    llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")

class LLMRequest(SQLModel):
    prompt: str = Field(..., description="发送给API的提示词")
    api_provider: APIProvider = Field(..., description="API提供者的信息")
    format: Optional[str] = Field(default=None, description="API响应的格式")
    response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
    error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
    total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
@@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Optional

class MarkdownNode(BaseModel):
    level: int = Field(default=0, description="节点层级")
    title: str = Field(default="Root", description="节点标题")
    content: Optional[str] = Field(default=None, description="节点内容")
    children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")

    class Config:
        arbitrary_types_allowed = True

MarkdownNode.model_rebuild()
@@ -1,20 +0,0 @@
from pydantic import BaseModel, Field, field_validator
from typing import Optional
from datetime import datetime, timezone
from langchain.prompts import PromptTemplate

class promptTempleta(BaseModel):
    id: Optional[int] = Field(default=None, description="模板ID")
    name: Optional[str] = Field(default="", description="模板名称")
    description: Optional[str] = Field(default="", description="模板描述")
    content: str = Field(default="", min_length=1, description="模板内容")
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="记录创建时间"
    )

    @field_validator('content')
    def validate_content(cls, value):
        if "document_slice" not in PromptTemplate.from_template(value).input_variables:
            raise ValueError("模板变量缺少 document_slice")
        return value
@@ -1,5 +0,0 @@
from .parse_markdown import *
from .document import *
from .json_example import generate_json_example
from .port import *
from .reasoning import call_openai_api
@@ -1,35 +0,0 @@
from typing import List
from schema.dataset import Dataset, DatasetItem, Q_A
import json

def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
    # Convert the JSON data into the dataset format
    dataset_items = []
    item_id = 1  # auto-incrementing ID counter
    for item in json_data:
        qa = Q_A(question=item["question"], answer=item["answer"])
        dataset_item_obj = DatasetItem(id=item_id, message=[qa])
        dataset_items.append(dataset_item_obj)
        item_id += 1  # increment the ID

    # Build the dataset object
    result_dataset = Dataset(
        name="Converted Dataset",
        model_id=None,
        description="Dataset converted from JSON",
        dataset_items=dataset_items
    )
    return result_dataset

# Example: read JSON from a file and convert it
if __name__ == "__main__":
    # Assume the JSON data lives in a file
    with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
        json_data = json.load(file)

    # Convert to the dataset format
    converted_dataset = convert_json_to_dataset(json_data)

    # Write the result to a file
    with open("output.json", "w", encoding="utf-8") as file:
        file.write(converted_dataset.model_dump_json(indent=4))
@@ -1,32 +0,0 @@
import sys
import os
from pathlib import Path

# Add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import Doc

def scan_docs_directory(workdir: str):
    docs_dir = os.path.join(workdir, "docs")

    doc_list = os.listdir(docs_dir)

    to_return = []

    for doc_name in doc_list:
        doc_path = os.path.join(docs_dir, doc_name)
        if os.path.isdir(doc_path):
            markdown_files = []
            for root, dirs, files in os.walk(doc_path):
                for file in files:
                    if file.endswith(".md"):
                        markdown_files.append(os.path.join(root, file))
            to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))

    return to_return

# Test code
if __name__ == "__main__":
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    docs = scan_docs_directory(workdir)
    print(docs)
@@ -1,67 +0,0 @@
from pydantic import BaseModel, create_model
from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json
from datetime import datetime, date

def generate_json_example(model: type[BaseModel], include_optional: bool = False, list_length=2) -> str:
    """
    Generate an example JSON data structure from a Pydantic V2 model.
    """
    def _generate_example(field_type: Any) -> Any:
        origin = get_origin(field_type)
        args = get_args(field_type)

        if origin is list or origin is List:
            # Generate list_length elements, followed by an ellipsis marker
            result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
            result.append("......")
            return result
        elif origin is dict or origin is Dict:
            if len(args) == 2:
                return {"key": _generate_example(args[1])}
            return {}
        elif origin is Union:
            # Handle Optional types (Union[T, None])
            non_none_args = [arg for arg in args if arg is not type(None)]
            return _generate_example(non_none_args[0]) if non_none_args else None
        elif field_type is str:
            return "string"
        elif field_type is int:
            return 0
        elif field_type is float:
            return 0.0
        elif field_type is bool:
            return True
        elif field_type is datetime:
            return datetime.now().isoformat()
        elif field_type is date:
            return date.today().isoformat()
        elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
            return json.loads(generate_json_example(field_type, include_optional))
        else:
            # Handle direct (non-generic) type annotations
            if field_type is type(None):
                return None
            try:
                if issubclass(field_type, BaseModel):
                    return json.loads(generate_json_example(field_type, include_optional))
            except TypeError:
                pass
            return "unknown"

    example_data = {}
    for field_name, field in model.model_fields.items():
        if include_optional or not isinstance(field.default, type(None)):
            example_data[field_name] = _generate_example(field.annotation)

    return json.dumps(example_data, indent=2, default=str)

if __name__ == "__main__":
    import sys
    from pathlib import Path
    # Add the project root to sys.path
    sys.path.append(str(Path(__file__).resolve().parent.parent))
    from schema import Dataset

    print("示例 JSON:")
    print(generate_json_example(Dataset))
69
tools/openai_api.py
Normal file
69
tools/openai_api.py
Normal file
@@ -0,0 +1,69 @@
import json
from openai import OpenAI


def generate_json_via_llm(
    prompt: str,
    base_url: str,
    api_key: str,
    model_id: str
) -> str:
    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )

    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            response_format={
                'type': 'json_object'
            }
        )
        return response.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"API request failed: {e}")


if __name__ == "__main__":
    import sys
    import os
    # Add the project root to sys.path
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    from config.llm import load_config

    # Example usage
    try:
        config = load_config()
        print(config)
        result = generate_json_via_llm(
            prompt="""Test: generate something arbitrary and return it as a JSON string in the following format
{
    "dataset": [
        {
            "question": "",
            "answer": ""
        },
        {
            "question": "",
            "answer": ""
        }
        ......
    ]
}
""",
            base_url=config["openai"]["base_url"],
            api_key=config["openai"]["api_key"],
            model_id=config["openai"]["model_id"],
        )
        print(result)
    except Exception as e:
        print(f"Error: {e}")
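One caveat on this new file: with `response_format={'type': 'json_object'}`, OpenAI-compatible endpoints generally require the word "JSON" to appear somewhere in the messages (the test prompt above satisfies this), and callers should still validate the returned string before using it. A minimal sketch of such a guard:

```python
import json

def parse_llm_json(raw: str) -> dict:
    """Validate the string returned by generate_json_via_llm before downstream use."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Model returned invalid JSON: {e}") from e
```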
@@ -1,45 +1,28 @@
import re
import sys
from pathlib import Path

# Add the project root to sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import MarkdownNode
class MarkdownNode:
    def __init__(self, level=0, title="Root"):
        self.level = level
        self.title = title
        self.content = ""  # store the merged content as a single string
        self.children = []

def process_markdown_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    root = parse_markdown(content)
    results = []

    def traverse(node, parent_titles):
        current_titles = parent_titles.copy()
        current_titles.append(node.title)

        if not node.children:  # leaf node
            if node.content:
                full_text = ' -> '.join(current_titles) + '\n' + node.content
                results.append(full_text)
        else:
            for child in node.children:
                traverse(child, current_titles)

    traverse(root, [])
    return results
def add_child(parent, child):
    parent.children.append(child)
    def __repr__(self):
        return f"({self.level}) {self.title}"

def print_tree(node, indent=0):
    prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
    print(f"{prefix}{node.title}")
    if node.content:
        content_prefix = "│ " * indent + "├─ [content]"
        print(content_prefix)
        for line in node.content.split('\n'):
            print("│ " * indent + "│ " + line)
    for child in node.children:
        print_tree(child, indent + 1)
    def add_child(self, child):
        self.children.append(child)

    def print_tree(self, indent=0):
        prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
        print(f"{prefix}{self.title}")
        if self.content:
            content_prefix = "│ " * indent + "├─ [content]"
            print(content_prefix)
            for line in self.content.split('\n'):
                print("│ " * indent + "│ " + line)
        for child in self.children:
            child.print_tree(indent + 1)

def parse_markdown(markdown):
    lines = markdown.split('\n')
@@ -68,10 +51,10 @@ def parse_markdown(markdown):
        if match:
            level = len(match.group(1))
            title = match.group(2)
            node = MarkdownNode(level=level, title=title, content="", children=[])
            node = MarkdownNode(level, title)
            while stack[-1].level >= level:
                stack.pop()
            add_child(stack[-1], node)
            stack[-1].add_child(node)
            stack.append(node)
        else:
            if stack[-1].content:
@@ -81,13 +64,10 @@ def parse_markdown(markdown):
    return root

if __name__ == "__main__":
    # # Read the Markdown content from a file
    # with open("workdir/example.md", "r", encoding="utf-8") as f:
    #     markdown = f.read()
    # Read the Markdown content from a file
    with open("example.md", "r", encoding="utf-8") as f:
        markdown = f.read()

    # # Parse the Markdown and print the tree structure
    # root = parse_markdown(markdown)
    # print_tree(root)
    for i in process_markdown_file("workdir/example.md"):
        print("~"*20)
        print(i)
    # Parse the Markdown and print the tree structure
    root = parse_markdown(markdown)
    root.print_tree()
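As a concrete illustration of the two code paths in this hunk: the removed `process_markdown_file` flattens each leaf section into a breadcrumb-plus-body string, while the surviving `root.print_tree()` renders the hierarchy. For an input like the sketch below, the `__main__` loop would print roughly the commented output:

```python
sample = """# Guide
## Install
pip install foo
## Usage
Run foo --help
"""
# process_markdown_file-style output, one string per leaf section,
# separated by the "~"*20 line from __main__:
#   Root -> Guide -> Install
#   pip install foo
#   ~~~~~~~~~~~~~~~~~~~~
#   Root -> Guide -> Usage
#   Run foo --help
```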
@@ -1,15 +0,0 @@
import socket


# Port probe added before launching the TensorBoard subprocess
def find_available_port(start_port):
    port = start_port
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', port)) != 0:  # port is free
                return port
        port += 1  # port is taken, try the next one


if __name__ == "__main__":
    start_port = 6006  # starting port number
    available_port = find_available_port(start_port)
    print(f"Available port: {available_port}")
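The deleted helper was presumably consumed by whatever code spawns TensorBoard; a sketch of such a call site, with the log directory path assumed for illustration:

```python
import subprocess

port = find_available_port(6006)
# Launch TensorBoard on the first free port at or above 6006
proc = subprocess.Popen(
    ["tensorboard", "--logdir", "workdir/train/logs", "--port", str(port)]
)
```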
@@ -1,123 +0,0 @@
import sys
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters


async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
    start_time = datetime.now(timezone.utc)
    client = openai.AsyncOpenAI(
        api_key=llm_request.api_provider.api_key,
        base_url=llm_request.api_provider.base_url
    )

    total_duration = 0.0
    total_tokens = TokensUsage()
    prompt = llm_request.prompt
    round_start = datetime.now(timezone.utc)
    if llm_request.format:
        prompt += "\nPlease return the result in JSON format" + llm_request.format

    for i in range(rounds):
        round_start = datetime.now(timezone.utc)
        try:
            messages = [{"role": "user", "content": prompt}]
            create_args = {
                "model": llm_request.api_provider.model_id,
                "messages": messages,
                "temperature": llm_parameters.temperature if llm_parameters else None,
                "max_tokens": llm_parameters.max_tokens if llm_parameters else None,
                "top_p": llm_parameters.top_p if llm_parameters else None,
                "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
                "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
                "seed": llm_parameters.seed if llm_parameters else None
            }

            # Handle the format parameter
            if llm_request.format:
                create_args["response_format"] = {"type": "json_object"}

            response = await client.chat.completions.create(**create_args)

            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration

            # Handle cache token fields that may be absent
            usage = response.usage
            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)

            tokens_usage = TokensUsage(
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                prompt_cache_hit_tokens=cache_hit,
                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
            )

            # Accumulate the total token usage
            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
            total_tokens.completion_tokens += tokens_usage.completion_tokens
            if tokens_usage.prompt_cache_hit_tokens:
                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
            if tokens_usage.prompt_cache_miss_tokens:
                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens

            llm_request.response.append(LLMResponse(
                response_id=response.id,
                tokens_usage=tokens_usage,
                content=response.choices[0].message.content,
                total_duration=duration,
                llm_parameters=llm_parameters
            ))
        except Exception as e:
            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration

            llm_request.response.append(LLMResponse(
                response_id=f"error-round-{i+1}",
                content={"error": str(e)},
                total_duration=duration
            ))
            if llm_request.error is None:
                llm_request.error = []
            llm_request.error.append(str(e))

    # Update the total duration and total token usage
    llm_request.total_duration = total_duration
    llm_request.total_tokens_usage = total_tokens

    return llm_request


if __name__ == "__main__":
    from json_example import generate_json_example
    from sqlmodel import Session, select
    from global_var import get_sql_engine, init_global_var
    from schema import DatasetItem

    init_global_var("workdir")
    api_state = "1 deepseek-chat"
    with Session(get_sql_engine()) as session:
        api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
    llm_request = LLMRequest(
        prompt="Test: say anything",
        api_provider=api_provider,
        format=generate_json_example(DatasetItem)
    )

    # # Single-call example
    # result = asyncio.run(call_openai_api(llm_request))
    # print(f"\nSingle-call result - number of responses: {len(result.response)}")
    # for i, resp in enumerate(result.response, 1):
    #     print(f"Response {i}: {resp.response_content}")

    # Multi-call example
    params = LLMParameters(temperature=0.7, max_tokens=100)
    result = asyncio.run(call_openai_api(llm_request, 3, params))
    print(f"\nResult of 3 calls - total duration: {result.total_duration:.2f}s")
    print(f"Total token usage: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
    for i, resp in enumerate(result.response, 1):
        print(f"Response {i}: {resp.content}")
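Since `call_openai_api` is a coroutine, corpus-generation jobs can be fanned out concurrently with plain `asyncio`; a minimal sketch, assuming a list of prepared `LLMRequest` objects:

```python
import asyncio

async def run_batch(requests, rounds=1):
    """Fire all LLM requests concurrently and collect the updated request objects."""
    return await asyncio.gather(*(call_openai_api(r, rounds) for r in requests))

# results = asyncio.run(run_batch(llm_requests))
```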
1534
train.ipynb
Normal file
File diff suppressed because it is too large
@@ -1 +0,0 @@
from .model import *
134
train/model.py
@@ -1,134 +0,0 @@
import os
from datasets import Dataset as HFDataset
from unsloth import FastLanguageModel
from trl import SFTTrainer  # trainer for supervised fine-tuning
from transformers import TrainingArguments, DataCollatorForSeq2Seq  # training configuration
from unsloth import is_bfloat16_supported  # check whether bfloat16 training is supported
from unsloth.chat_templates import get_chat_template, train_on_responses_only

def get_model_name(model):
    return os.path.basename(model.name_or_path)

def formatting_prompts(examples, tokenizer):
    """Format conversation data.

    Args:
        examples: dict containing lists of conversations
    Returns:
        dict containing the formatted text
    """
    questions = examples["question"]
    answer = examples["answer"]

    convos = [
        [{"role": "user", "content": q}, {"role": "assistant", "content": r}]
        for q, r in zip(questions, answer)
    ]

    # Format the conversations with tokenizer.apply_chat_template
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]

    return {"text": texts}


def train_model(
    model,
    tokenizer,
    dataset: list,
    train_dir: str,
    learning_rate: float,
    per_device_train_batch_size: int,
    epoch: int,
    save_steps: int,
    lora_rank: int,
    trainer_callback=None
) -> None:
    # Model configuration
    dtype = None  # data type; None means auto-detect
    load_in_4bit = False  # whether to load the model with 4-bit quantization to save VRAM

    model = FastLanguageModel.get_peft_model(
        # base model
        model,
        # LoRA rank: controls the dimension of the low-rank matrices; larger values mean
        # more trainable parameters and potentially better quality at higher training cost.
        # Recommended range: 8-32.
        r=lora_rank,  # LoRA rank passed in dynamically
        # modules to apply LoRA to
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
            "gate_proj", "up_proj", "down_proj",  # FFN layers
        ],
        # LoRA scaling factor: controls the magnitude of the LoRA updates;
        # larger values give the adapter more influence.
        lora_alpha=16,
        # dropout rate for the LoRA layers, used to prevent overfitting; 0 disables dropout.
        # For small datasets, around 0.1 is recommended.
        lora_dropout=0,
        # whether to fine-tune bias parameters:
        # none: do not fine-tune biases;
        # all: fine-tune all parameters;
        # lora_only: fine-tune only the LoRA parameters.
        bias="none",
        # use gradient checkpointing (unsloth's optimized version) to save VRAM;
        # slightly slows training but significantly reduces memory usage
        use_gradient_checkpointing="unsloth",
        # random seed for reproducibility
        random_state=3407,
        # whether to use rank-stabilized LoRA (not used here)
        use_rslora=False,
        # LoftQ configuration (not used here); a quantization technique for further compression
        loftq_config=None,
    )

    tokenizer = get_chat_template(
        tokenizer,
        chat_template="qwen-2.5",
    )

    train_dataset = HFDataset.from_list(dataset)
    train_dataset = train_dataset.map(formatting_prompts,
                                      fn_kwargs={"tokenizer": tokenizer},
                                      batched=True)

    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,  # model to train
        tokenizer=tokenizer,  # tokenizer
        train_dataset=train_dataset,  # training dataset
        dataset_text_field="text",  # name of the text field in the dataset
        max_seq_length=model.max_seq_length,  # maximum sequence length
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        dataset_num_proc=1,  # number of parallel processes for dataset preprocessing
        packing=False,
        args=TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,  # per-GPU batch size
            gradient_accumulation_steps=4,  # gradient accumulation to simulate a larger batch size
            warmup_steps=int(epoch * 0.1),  # warmup steps to ramp up the learning rate
            learning_rate=learning_rate,  # learning rate
            lr_scheduler_type="linear",  # linear learning-rate scheduler
            max_steps=int(epoch * len(train_dataset) / per_device_train_batch_size),  # total training steps (one step = one batch)
            fp16=not is_bfloat16_supported(),  # fall back to fp16 if bf16 is unsupported
            bf16=is_bfloat16_supported(),  # use bf16 when supported
            logging_steps=1,  # log every step
            optim="adamw_8bit",  # 8-bit AdamW optimizer to save VRAM with negligible quality impact
            weight_decay=0.01,  # weight decay for regularization against overfitting
            seed=114514,  # random seed
            output_dir=train_dir + "/checkpoints",  # where to save checkpoints and training logs
            save_strategy="steps",  # save intermediate weights by step count
            save_steps=save_steps,  # save interval passed in dynamically
            logging_dir=train_dir + "/logs",  # log directory
            report_to="tensorboard",  # log to TensorBoard
        ),
    )

    if trainer_callback is not None:
        trainer.add_callback(trainer_callback)

    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )

    # Start training
    trainer_stats = trainer.train(resume_from_checkpoint=False)
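A note on the step arithmetic above: `max_steps` is derived from epochs and batch size but not from `gradient_accumulation_steps`, so with accumulation enabled the run consumes more epochs' worth of samples than the `epoch` argument suggests. A quick worked example with assumed numbers:

```python
epoch = 3
dataset_len = 1000
per_device_train_batch_size = 4
gradient_accumulation_steps = 4

max_steps = int(epoch * dataset_len / per_device_train_batch_size)  # 750 optimizer steps
samples_seen = max_steps * per_device_train_batch_size * gradient_accumulation_steps
print(max_steps, samples_seen)  # 750, 12000 -> 12 passes over 1000 samples, not 3
```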
@@ -1,48 +0,0 @@
import os
from global_var import get_model, get_tokenizer


def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
    """
    Save a model to the given directory.
    :param save_model_name: name of the model to save
    :param models_dir: base directory for saved models
    :param model: model to save
    :param tokenizer: tokenizer to save
    :param save_method: save mode, one of
        - "default": default save
        - "merged_16bit": merge to 16-bit
        - "merged_4bit": merge to 4-bit
        - "lora": LoRA adapters only
        - "gguf": GGUF format
        - "gguf_q4_k_m": q4_k_m GGUF format
        - "gguf_f16": 16-bit GGUF format
    :return: result message or error message
    """
    try:
        if model is None:
            return "No loaded model to save"

        save_path = os.path.join(models_dir, save_model_name)
        os.makedirs(save_path, exist_ok=True)

        if save_method == "default":
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
        elif save_method == "merged_16bit":
            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
        elif save_method == "merged_4bit":
            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
        elif save_method == "lora":
            model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
        elif save_method == "gguf":
            model.save_pretrained_gguf(save_path, tokenizer)
        elif save_method == "gguf_q4_k_m":
            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
        elif save_method == "gguf_f16":
            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
        else:
            return f"Unsupported save mode: {save_method}"

        return f"Model saved to {save_path} (mode: {save_method})"
    except Exception as e:
        return f"Error while saving the model: {str(e)}"
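A typical call site for the deleted helper, using the accessors it imports from `global_var` (the model name and directory below are assumed for illustration):

```python
msg = save_model_to_dir(
    save_model_name="qwen2.5-3b-finetuned",  # assumed name
    models_dir="workdir/models",             # assumed directory
    model=get_model(),
    tokenizer=get_tokenizer(),
    save_method="gguf_q4_k_m",
)
print(msg)
```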
70
trainer.py
Normal file
@@ -0,0 +1,70 @@
from unsloth import FastLanguageModel
import torch

# Basic configuration
max_seq_length = 4096  # maximum sequence length
dtype = None  # auto-detect data type
load_in_4bit = True  # use 4-bit quantization to reduce memory usage

# Load the pretrained model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "workdir/model/Qwen2.5-3B-Instruct-bnb-4bit",  # Qwen2.5 3B instruct model
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 64,  # LoRA rank, controls the number of trainable parameters
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],  # target modules to train
    lora_alpha = 64,  # LoRA scaling factor
    lora_dropout = 0,  # LoRA dropout rate
    bias = "none",  # whether to train bias terms
    use_gradient_checkpointing = "unsloth",  # gradient checkpointing to save VRAM
    random_state = 114514,  # random seed
    use_rslora = False,  # whether to use rank-stabilized LoRA
    loftq_config = None,  # LoftQ configuration
)

from unsloth.chat_templates import get_chat_template
# Configure the tokenizer with the qwen-2.5 chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen-2.5",
)

def formatting_prompts_func(examples):
    """Format conversation data.

    Args:
        examples: dict containing lists of conversations
    Returns:
        dict containing the formatted text
    """
    questions = examples["question"]
    answer = examples["answer"]

    # Combine each question and response into a conversation
    convos = [
        [{"role": "user", "content": q}, {"role": "assistant", "content": r}]
        for q, r in zip(questions, answer)
    ]

    # Format the conversations with tokenizer.apply_chat_template
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]

    return {"text": texts}

from unsloth.chat_templates import standardize_sharegpt

# Load the dataset
from datasets import load_dataset
dataset = load_dataset("json", data_files="workdir/dataset/dataset.json")
dataset = dataset.map(formatting_prompts_func, batched = True)

print(dataset["train"][5])
print(dataset["train"][5]["text"])
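After a script like this loads the LoRA-wrapped model, a quick generation smoke test is the natural next step; a sketch using unsloth's inference switch (the prompt text is illustrative only):

```python
FastLanguageModel.for_inference(model)  # enable unsloth's faster inference mode

messages = [{"role": "user", "content": "How do I call the private library's init API?"}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids=inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```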