Compare commits
251 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0521c2bc5a | ||
![]() |
aba990d8b9 | ||
![]() |
1e9762b631 | ||
![]() |
457efaeaf1 | ||
![]() |
2e72d1e1bd | ||
![]() |
26629d68c0 | ||
![]() |
148471e93c | ||
![]() |
5d6217c04f | ||
![]() |
e00ff8a216 | ||
![]() |
a8cc793e43 | ||
![]() |
0b2f1d6c9f | ||
![]() |
1d2583c128 | ||
![]() |
9202135d3b | ||
![]() |
db51abe066 | ||
![]() |
a44102a8c2 | ||
![]() |
9aff1875e2 | ||
![]() |
50a33561f4 | ||
![]() |
f91a80a124 | ||
![]() |
6e62902867 | ||
![]() |
71484745d1 | ||
![]() |
64de52a428 | ||
![]() |
c16329d421 | ||
![]() |
d2828fb1f8 | ||
![]() |
67c5ed310a | ||
![]() |
4785985b6f | ||
![]() |
47fd47dab1 | ||
![]() |
eaf3d5f4ca | ||
![]() |
d5652beb01 | ||
![]() |
fa0f95f21d | ||
![]() |
b6cba12aa0 | ||
![]() |
76c494d086 | ||
![]() |
122cdcfd9c | ||
![]() |
0f338db935 | ||
![]() |
ce8e0b9b6a | ||
![]() |
8219b0f68f | ||
![]() |
5fb6faa59a | ||
![]() |
e3843a1853 | ||
![]() |
1955a4d76b | ||
![]() |
3898290e16 | ||
![]() |
228b3b9e64 | ||
![]() |
aa4758af4e | ||
![]() |
d7a3c0cfc0 | ||
![]() |
26b3d2d02e | ||
![]() |
ddbb9f9e34 | ||
![]() |
d5507c4082 | ||
![]() |
8120b66422 | ||
![]() |
4c27351fc1 | ||
![]() |
e63e815c5b | ||
![]() |
0dc9c89e5a | ||
![]() |
904fce220d | ||
![]() |
694f37b38a | ||
![]() |
0a291e0330 | ||
![]() |
833c6b94b3 | ||
![]() |
55508a55b9 | ||
![]() |
fa4b1500de | ||
![]() |
278fc67f18 | ||
![]() |
2af5316ff2 | ||
![]() |
af006ebcc7 | ||
![]() |
b61a5f7a75 | ||
![]() |
39a0ec48b1 | ||
![]() |
9054be273a | ||
![]() |
f2ecbca569 | ||
![]() |
d0917e134e | ||
![]() |
d35074038e | ||
![]() |
5543150a9a | ||
![]() |
80f12c66c8 | ||
![]() |
46e82361d4 | ||
![]() |
7c3aa12cd4 | ||
![]() |
342f180639 | ||
![]() |
f6beeb8788 | ||
![]() |
ef8b1a5115 | ||
![]() |
b4da95e81e | ||
![]() |
728c38dd2d | ||
![]() |
083542069c | ||
![]() |
7dad189b8c | ||
![]() |
8edae2b862 | ||
![]() |
e7af584e9a | ||
![]() |
7d6444c6c8 | ||
![]() |
c273bb3399 | ||
![]() |
c3a04c4958 | ||
![]() |
41c1a2fce4 | ||
![]() |
1ff958e641 | ||
![]() |
aac8c9d334 | ||
![]() |
e4919ecf69 | ||
![]() |
2cf8b1b6fd | ||
![]() |
55784b3775 | ||
![]() |
0ed43a963c | ||
![]() |
e87d721f47 | ||
![]() |
7907b96baa | ||
![]() |
7bd6d2feab | ||
![]() |
d632029724 | ||
![]() |
c6cc73f813 | ||
![]() |
47014a6d43 | ||
![]() |
ccfc81c1f1 | ||
![]() |
e8107519ed | ||
![]() |
2362e43d38 | ||
![]() |
6dbe9514fb | ||
![]() |
585f470403 | ||
![]() |
a94e14831e | ||
![]() |
4141e09494 | ||
![]() |
634ce8fff8 | ||
![]() |
75bc9a76d0 | ||
![]() |
4a67c20b70 | ||
![]() |
d210ddcca9 | ||
![]() |
97ee546bdf | ||
![]() |
2af26560b3 | ||
![]() |
0977002c06 | ||
![]() |
fa83e06346 | ||
![]() |
03221547bb | ||
![]() |
7cc26feaa9 | ||
![]() |
9494c9c913 | ||
![]() |
ff31213aa8 | ||
![]() |
aa2f75f67f | ||
![]() |
4d61a4cffd | ||
![]() |
3e520604ba | ||
![]() |
959ff2033d | ||
![]() |
d653e05a61 | ||
![]() |
6685b742ed | ||
![]() |
7a4388c928 | ||
![]() |
6338706967 | ||
![]() |
3718c75cee | ||
![]() |
905658073a | ||
![]() |
9806334517 | ||
![]() |
0a4efa5641 | ||
![]() |
994d600221 | ||
![]() |
d5774eee0c | ||
![]() |
87501c9353 | ||
![]() |
5fc3b4950b | ||
![]() |
c28e4819d9 | ||
![]() |
e7cf51d662 | ||
![]() |
4c9caff668 | ||
![]() |
9236f49b36 | ||
![]() |
868fcd45ba | ||
![]() |
5a21c8598a | ||
![]() |
1e829c9268 | ||
![]() |
9fc3ab904b | ||
![]() |
d827f9758f | ||
![]() |
ff1e9731bc | ||
![]() |
90fde639ff | ||
![]() |
5fc90903fb | ||
![]() |
81c2ad4a2d | ||
![]() |
314434951d | ||
![]() |
e16882953d | ||
![]() |
86bcf90c66 | ||
![]() |
961a017f19 | ||
![]() |
5a386d6401 | ||
![]() |
feaea1fb64 | ||
![]() |
7242a2ce03 | ||
![]() |
db6e2271dc | ||
![]() |
d764537143 | ||
![]() |
8c35a38c47 | ||
![]() |
7ee751c88f | ||
![]() |
b715b36a5f | ||
![]() |
8023233bb2 | ||
![]() |
2a86b3b5b0 | ||
![]() |
ca1505304e | ||
![]() |
df9260e918 | ||
![]() |
df9aba0c6e | ||
![]() |
6b87dcb58f | ||
![]() |
d0aebd17fa | ||
![]() |
d9abf08184 | ||
![]() |
a27a1ab079 | ||
![]() |
aa758e3c2a | ||
![]() |
664944f0c5 | ||
![]() |
9298438f98 | ||
![]() |
4f7926aec6 | ||
![]() |
148f4afb25 | ||
![]() |
11a3039775 | ||
![]() |
a4289815ba | ||
![]() |
088067d335 | ||
![]() |
9fb31c46c8 | ||
![]() |
4f09823123 | ||
![]() |
1a2ca3e244 | ||
![]() |
bb1d8fbd38 | ||
![]() |
4558929c52 | ||
![]() |
0722748997 | ||
![]() |
e08f0059bb | ||
![]() |
6d1fecbdac | ||
![]() |
79d3eb153c | ||
![]() |
80dae7c6e2 | ||
![]() |
2d39b91764 | ||
![]() |
5094febcb4 | ||
![]() |
eeee68dbd1 | ||
![]() |
539e14d39c | ||
![]() |
9784f2aed3 | ||
![]() |
611904cef9 | ||
![]() |
8a9a080745 | ||
![]() |
a23ad88769 | ||
![]() |
83427aaaba | ||
![]() |
61672021ef | ||
![]() |
fb6157af05 | ||
![]() |
f655936741 | ||
![]() |
ab7897351a | ||
![]() |
216bfe39ae | ||
![]() |
0fa2b51a79 | ||
![]() |
cbb3a09dd8 | ||
![]() |
2e552c186d | ||
![]() |
1b3f546669 | ||
![]() |
402bc73dce | ||
![]() |
bb5851f800 | ||
![]() |
a407fa1f76 | ||
![]() |
4b465ec917 | ||
![]() |
e7cc03297b | ||
![]() |
051d1a7535 | ||
![]() |
97172f9596 | ||
![]() |
f582820443 | ||
![]() |
8fb9f785b9 | ||
![]() |
2c8e54bb1e | ||
![]() |
932d1e2687 | ||
![]() |
202d4c44df | ||
![]() |
4d77c429bd | ||
![]() |
41447c5ed4 | ||
![]() |
84fe78243a | ||
![]() |
4d8754aad2 | ||
![]() |
541d37c674 | ||
![]() |
6a00699472 | ||
![]() |
ff8162890d | ||
![]() |
daddcd34da | ||
![]() |
5c7ced30df | ||
![]() |
9741ce6b92 | ||
![]() |
67281fe06a | ||
![]() |
2d905a0270 | ||
![]() |
374b124cf8 | ||
![]() |
74ae5e1426 | ||
![]() |
0a6ae7a4ee | ||
![]() |
faf72d1e99 | ||
![]() |
cce5e4e114 | ||
![]() |
293f63017f | ||
![]() |
2e31f4f57c | ||
![]() |
967133162e | ||
![]() |
dc28c25c65 | ||
![]() |
70b64dc3d3 | ||
![]() |
b52ca9b1af | ||
![]() |
46b4453ccd | ||
![]() |
d5b528d375 | ||
![]() |
475cd033d9 | ||
![]() |
3970a67df3 | ||
![]() |
286db405ca | ||
![]() |
d40f5b1f24 | ||
![]() |
7a77f61ee6 | ||
![]() |
841e14a093 | ||
![]() |
2ff077bb1c | ||
![]() |
513b639bce | ||
![]() |
f93f213a31 | ||
![]() |
10b4c29bda | ||
![]() |
b1e98ca913 | ||
![]() |
2d5a5277ae | ||
![]() |
519a5f3773 | ||
![]() |
1f4d491694 | ||
![]() |
8ce4f1e373 | ||
![]() |
3395b860e4 |
6
.gitignore
vendored
@ -11,6 +11,7 @@ env/
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
.roo
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
@ -28,3 +29,8 @@ workdir/
|
||||
|
||||
# cache
|
||||
unsloth_compiled_cache
|
||||
|
||||
# 测试和参考代码
|
||||
test.ipynb
|
||||
test.py
|
||||
refer/
|
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 C-a-r-r-y
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
109
README.md
Normal file
@ -0,0 +1,109 @@
|
||||
# 基于文档驱动的自适应编码大模型微调框架
|
||||
## 项目背景
|
||||
本项目是广州大学计算机科学与网络工程学院计算机科学与技术专业某2025届本科生的毕业设计。
|
||||
|
||||
## 论文与答辩资源
|
||||
Latex编译要求:
|
||||
- 操作系统:Windows 11
|
||||
- TeX发行版:TeXLive 2023+
|
||||
- 编译引擎:XeLaTeX
|
||||
- 字体要求:方正楷体简体(需预先安装)
|
||||
|
||||
相关latex资源:
|
||||
- 论文LaTeX源码位于`paper/`目录,
|
||||
- 答辩PPT相关资源位于`report/`目录
|
||||
|
||||
## 项目简介
|
||||
|
||||
### 项目概述
|
||||
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
|
||||
|
||||
### 核心功能
|
||||
- 文档解析与语料生成
|
||||
- 大语言模型高效微调
|
||||
- 交互式训练与推理界面
|
||||
- 训练过程可视化监控
|
||||
|
||||
## 技术架构
|
||||
|
||||
### 系统原理
|
||||
```
|
||||
[前端界面] -> [数据集生成] -> [模型微调] -> [数据存储]
|
||||
↑ ↑ ↑ ↑
|
||||
│ │ │ │
|
||||
[Gradio] [LangChain] [unsloth/QLoRA] [SQLite/TinyDB]
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
- **前端界面**: Gradio构建的交互式Web界面
|
||||
- **模型微调**: 基于unsloth框架的QLoRA高效微调
|
||||
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
|
||||
- **工作流引擎**: LangChain实现文档解析与语料生成
|
||||
- **训练监控**: TensorBoard集成
|
||||
|
||||
## 功能模块
|
||||
|
||||
### 1. 模型管理
|
||||
- 支持多种格式的大语言模型加载
|
||||
- 模型信息查看与状态管理
|
||||
- 模型卸载与内存释放
|
||||
|
||||
### 2. 模型推理
|
||||
- 对话式交互界面
|
||||
- 流式响应输出
|
||||
- 历史对话管理
|
||||
|
||||
### 3. 模型微调
|
||||
- 训练参数配置(学习率、batch size等)
|
||||
- LoRA参数配置(秩、alpha等)
|
||||
- 训练过程实时监控
|
||||
- 训练中断与恢复
|
||||
|
||||
### 4. 数据集生成
|
||||
- 文档解析与清洗
|
||||
- 指令-响应对生成
|
||||
- 数据集质量评估
|
||||
- 数据集版本管理
|
||||
|
||||
## 微调技术
|
||||
|
||||
### QLoRA原理
|
||||
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
|
||||
1. **4-bit量化**: 将预训练模型量化为4-bit表示,大幅减少显存占用
|
||||
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
|
||||
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
|
||||
|
||||
### 参数配置
|
||||
- **学习率**: 建议2e-5到2e-4
|
||||
- **LoRA秩**: 控制适配器复杂度(建议16-64)
|
||||
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
|
||||
|
||||
### 训练监控
|
||||
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
|
||||
- **日志记录**: 训练过程详细日志保存
|
||||
- **模型检查点**: 定期保存中间权重
|
||||
|
||||
## 快速开始
|
||||
|
||||
0. 实验要求:
|
||||
* 必须要使用linux系统,尽管unsloth框架提供了Windows版本
|
||||
* 必须使用NVIDIA GPU,unsloth框架依赖NVIDIA GPU
|
||||
* 实验的显存需求不高,本人使用6GB显存即可完成实验,若显存不足6GB可更换3B以下参数量的模型实验
|
||||
|
||||
1. 安装依赖:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 启动应用:
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
3. 访问Web界面:
|
||||
```
|
||||
http://localhost:7860
|
||||
```
|
||||
|
||||
## 许可证
|
||||
MIT License
|
12
db/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||
from .dataset_store import get_all_dataset, save_dataset
|
||||
|
||||
__all__ = [
|
||||
"load_sqlite_engine",
|
||||
"initialize_sqlite_db",
|
||||
"get_prompt_tinydb",
|
||||
"initialize_prompt_store",
|
||||
"get_all_dataset",
|
||||
"save_dataset"
|
||||
]
|
81
db/dataset_store.py
Normal file
@ -0,0 +1,81 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tinydb import TinyDB, Query
|
||||
from tinydb.storages import MemoryStorage
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset import Dataset, DatasetItem, Q_A
|
||||
|
||||
def get_all_dataset(workdir: str) -> TinyDB:
|
||||
"""
|
||||
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径
|
||||
|
||||
Returns:
|
||||
TinyDB: 包含所有数据集对象的TinyDB对象
|
||||
"""
|
||||
dataset_dir = os.path.join(workdir, "dataset")
|
||||
if not os.path.exists(dataset_dir):
|
||||
return TinyDB(storage=MemoryStorage)
|
||||
|
||||
db = TinyDB(storage=MemoryStorage)
|
||||
for filename in os.listdir(dataset_dir):
|
||||
if filename.endswith(".json"):
|
||||
filepath = os.path.join(dataset_dir, filename)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
db.insert(data)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
print(f"Error loading dataset file {filename}: {str(e)}")
|
||||
continue
|
||||
|
||||
return db
|
||||
|
||||
|
||||
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
|
||||
"""
|
||||
将TinyDB中的数据集保存为单独的json文件
|
||||
|
||||
Args:
|
||||
db (TinyDB): 包含数据集对象的TinyDB实例
|
||||
workdir (str): 工作目录路径
|
||||
name (str, optional): 要保存的数据集名称,None表示保存所有
|
||||
"""
|
||||
dataset_dir = os.path.join(workdir, "dataset")
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
datasets = db.all() if name is None else db.search(Query().name == name)
|
||||
|
||||
for dataset in datasets:
|
||||
try:
|
||||
filename = f"{dataset.get(dataset['name'])}.json"
|
||||
filepath = os.path.join(dataset_dir, filename)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(dataset, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取所有数据集
|
||||
datasets = get_all_dataset(workdir)
|
||||
# 打印结果
|
||||
print(f"Found {len(datasets)} datasets:")
|
||||
for ds in datasets.all():
|
||||
print(f"- {ds['name']} (ID: {ds['id']})")
|
||||
|
||||
# 询问要保存的数据集名称
|
||||
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
|
||||
|
||||
# 保存数据集到文件
|
||||
save_dataset(datasets, workdir, name)
|
||||
print(f"Datasets {'all' if name is None else name} saved to json files")
|
79
db/init_db.py
Normal file
@ -0,0 +1,79 @@
|
||||
import os
|
||||
import sys
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from sqlmodel import select
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset_generation import APIProvider
|
||||
|
||||
# 全局变量,用于存储数据库引擎实例
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
def load_sqlite_engine(workdir: str) -> Engine:
|
||||
"""
|
||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
Engine: SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
global _engine
|
||||
if not _engine:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "db.sqlite")
|
||||
# 创建数据库URL
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
# 创建数据库引擎
|
||||
_engine = create_engine(db_url)
|
||||
return _engine
|
||||
|
||||
def initialize_sqlite_db(engine: Engine) -> None:
|
||||
"""
|
||||
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||||
|
||||
Args:
|
||||
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
# 创建所有表结构
|
||||
SQLModel.metadata.create_all(engine)
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# 从环境变量中获取API相关配置
|
||||
api_key = os.getenv("API_KEY")
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model_id = os.getenv("MODEL_ID")
|
||||
|
||||
# 如果所有必要的环境变量都存在,则插入初始数据
|
||||
if api_key and base_url and model_id:
|
||||
with Session(engine) as session:
|
||||
# 查询是否已存在APIProvider记录
|
||||
statement = select(APIProvider).limit(1)
|
||||
existing_provider = session.exec(statement).first()
|
||||
|
||||
# 如果不存在,则插入新的APIProvider记录
|
||||
if not existing_provider:
|
||||
provider = APIProvider(
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
api_key=api_key
|
||||
)
|
||||
session.add(provider)
|
||||
session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库引擎
|
||||
engine = load_sqlite_engine(workdir)
|
||||
# 初始化数据库
|
||||
initialize_sqlite_db(engine)
|
62
db/prompt_store.py
Normal file
@ -0,0 +1,62 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from tinydb import TinyDB, Query
|
||||
from tinydb.storages import JSONStorage
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.prompt import promptTempleta
|
||||
|
||||
# 全局变量,用于存储TinyDB实例
|
||||
_db_instance: Optional[TinyDB] = None
|
||||
# 自定义存储类,用于格式化JSON数据
|
||||
def get_prompt_tinydb(workdir: str) -> TinyDB:
|
||||
"""
|
||||
获取TinyDB实例。如果实例尚未创建,则创建一个新的并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
TinyDB: TinyDB数据库实例
|
||||
"""
|
||||
global _db_instance
|
||||
if not _db_instance:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "prompts.json")
|
||||
# 创建TinyDB实例
|
||||
_db_instance = TinyDB(db_path)
|
||||
return _db_instance
|
||||
|
||||
def initialize_prompt_store(db: TinyDB) -> None:
|
||||
"""
|
||||
初始化prompt模板存储
|
||||
|
||||
Args:
|
||||
db (TinyDB): TinyDB数据库实例
|
||||
"""
|
||||
# 检查数据库是否为空
|
||||
if not db.all(): # 如果数据库中没有数据
|
||||
db.insert(promptTempleta(
|
||||
id=1,
|
||||
name="default",
|
||||
description="默认提示词模板",
|
||||
content="""项目名为:{project_name}
|
||||
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
|
||||
文档节选:{document_slice}""").model_dump())
|
||||
# 如果数据库中已有数据,则跳过插入
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库实例
|
||||
db = get_prompt_tinydb(workdir)
|
||||
# 初始化prompt存储
|
||||
initialize_prompt_store(db)
|
7
frontend/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from .train_page import *
|
||||
from .chat_page import *
|
||||
from .setting_page import *
|
||||
from .model_manage_page import *
|
||||
from .dataset_manage_page import *
|
||||
from .dataset_generate_page import *
|
||||
from .prompt_manage_page import *
|
97
frontend/chat_page.py
Normal file
@ -0,0 +1,97 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread # 需要导入 Thread
|
||||
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
|
||||
|
||||
# 假设 global_var.py 在父目录
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
|
||||
|
||||
def chat_page():
|
||||
with gr.Blocks() as demo:
|
||||
# 聊天框
|
||||
gr.Markdown("## 对话")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
chatbot = gr.Chatbot(type="messages", label="聊天机器人")
|
||||
msg = gr.Textbox(label="输入消息")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
# 新增超参数输入框
|
||||
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
|
||||
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
|
||||
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
|
||||
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
|
||||
clear = gr.Button("清除对话")
|
||||
|
||||
def user(user_message, history: list):
|
||||
return "", history + [{"role": "user", "content": user_message}]
|
||||
|
||||
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
|
||||
model = get_model()
|
||||
tokenizer = get_tokenizer()
|
||||
if not history:
|
||||
yield history
|
||||
return
|
||||
|
||||
if model is None or tokenizer is None:
|
||||
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
|
||||
yield history
|
||||
return
|
||||
|
||||
try:
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
history,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# 将超参数转换为数值类型
|
||||
generation_kwargs = dict(
|
||||
input_ids=inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=int(max_new_tokens),
|
||||
temperature=float(temperature),
|
||||
top_p=float(top_p),
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
do_sample=True,
|
||||
use_cache=False
|
||||
)
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
|
||||
for new_text in streamer:
|
||||
if new_text:
|
||||
history[-1]["content"] += new_text
|
||||
yield history
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
|
||||
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
|
||||
history[-1]["content"] = error_message
|
||||
else:
|
||||
history.append({"role": "assistant", "content": error_message})
|
||||
yield history
|
||||
|
||||
# 更新 .then() 调用,将超参数传递给 bot 函数
|
||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
|
||||
)
|
||||
|
||||
clear.click(lambda: [], None, chatbot, queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from model_manage_page import model_manage_page
|
||||
# 装载两个页面
|
||||
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
|
||||
demo.queue()
|
||||
demo.launch()
|
189
frontend/dataset_generate_page.py
Normal file
@ -0,0 +1,189 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
import json
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from langchain.prompts import PromptTemplate
|
||||
from sqlmodel import Session, select
|
||||
from schema import Dataset, DatasetItem, Q_A
|
||||
from db.dataset_store import get_all_dataset
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
|
||||
from db import save_dataset
|
||||
from tools import call_openai_api, process_markdown_file, generate_json_example
|
||||
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
|
||||
|
||||
def dataset_generate_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集生成")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
docs_list = [str(doc.name) for doc in get_docs()]
|
||||
initial_doc = docs_list[0] if docs_list else None
|
||||
prompts = get_prompt_store().all()
|
||||
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
||||
initial_prompt = prompt_list[0] if prompt_list else None
|
||||
|
||||
# 初始化Dataframe的值
|
||||
initial_dataframe_value = []
|
||||
if initial_prompt:
|
||||
selected_prompt_id = int(initial_prompt.split(" ")[0])
|
||||
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||
# 从数据库获取API Provider列表
|
||||
with Session(get_sql_engine()) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
api_list = [f"{p.id} {p.model_id}" for p in providers]
|
||||
initial_api = api_list[0] if api_list else None
|
||||
|
||||
api_dropdown = gr.Dropdown(
|
||||
choices=api_list,
|
||||
value=initial_api,
|
||||
label="选择API",
|
||||
interactive=True
|
||||
)
|
||||
doc_dropdown = gr.Dropdown(
|
||||
choices=docs_list,
|
||||
value=initial_doc,
|
||||
label="选择文档",
|
||||
interactive=True
|
||||
)
|
||||
prompt_dropdown = gr.Dropdown(
|
||||
choices=prompt_list,
|
||||
value=initial_prompt,
|
||||
label="选择模板",
|
||||
interactive=True
|
||||
)
|
||||
rounds_input = gr.Number(
|
||||
value=1,
|
||||
label="生成轮次",
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
interactive=True
|
||||
)
|
||||
concurrency_input = gr.Number(
|
||||
value=1,
|
||||
label="并发数",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
interactive=True,
|
||||
visible=False
|
||||
)
|
||||
dataset_name_input = gr.Textbox(
|
||||
label="数据集名称",
|
||||
placeholder="输入数据集保存名称",
|
||||
interactive=True
|
||||
)
|
||||
prompt_choice = gr.State(value=initial_prompt)
|
||||
generate_button = gr.Button("生成数据集",variant="primary")
|
||||
doc_choice = gr.State(value=initial_doc)
|
||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||
api_choice = gr.State(value=initial_api)
|
||||
with gr.Column(scale=2):
|
||||
variables_dataframe = gr.Dataframe(
|
||||
headers=["变量名", "变量值"],
|
||||
datatype=["str", "str"],
|
||||
interactive=True,
|
||||
label="变量列表",
|
||||
value=initial_dataframe_value # 设置初始化数据
|
||||
)
|
||||
def on_doc_change(selected_doc):
|
||||
return selected_doc
|
||||
|
||||
def on_api_change(selected_api):
|
||||
return selected_api
|
||||
|
||||
def on_prompt_change(selected_prompt):
|
||||
if not selected_prompt:
|
||||
return None, []
|
||||
selected_prompt_id = int(selected_prompt.split(" ")[0])
|
||||
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
dataframe_value = [] if input_variables is None else input_variables
|
||||
dataframe_value = [[var, ""] for var in input_variables]
|
||||
return selected_prompt, dataframe_value
|
||||
|
||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
|
||||
dataset_db = get_datasets()
|
||||
if not dataset_db.search(Query().name == dataset_name):
|
||||
raise gr.Error("数据集名称已存在")
|
||||
|
||||
doc = [i for i in get_docs() if i.name == doc_state][0]
|
||||
doc_files = doc.markdown_files
|
||||
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
|
||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||
prompt = PromptTemplate.from_template(prompt["content"])
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
|
||||
variables_dict = {}
|
||||
for _, row in variables_dataframe.iterrows():
|
||||
var_name = row['变量名'].strip()
|
||||
var_value = row['变量值'].strip()
|
||||
if var_name:
|
||||
variables_dict[var_name] = var_value
|
||||
|
||||
prompt = prompt.partial(**variables_dict)
|
||||
|
||||
dataset = Dataset(
|
||||
name=dataset_name,
|
||||
model_id=[api_provider.model_id],
|
||||
source_doc=doc,
|
||||
dataset_items=[]
|
||||
)
|
||||
|
||||
total_slices = len(document_slice_list)
|
||||
for i, document_slice in enumerate(document_slice_list):
|
||||
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
|
||||
request = LLMRequest(api_provider=api_provider,
|
||||
prompt=prompt.format(document_slice=document_slice),
|
||||
format=generate_json_example(DatasetItem))
|
||||
call_openai_api(request, rounds)
|
||||
|
||||
for resp in request.response:
|
||||
try:
|
||||
content = json.loads(resp.content)
|
||||
dataset_item = DatasetItem(
|
||||
message=[Q_A(
|
||||
question=content.get("question", ""),
|
||||
answer=content.get("answer", "")
|
||||
)]
|
||||
)
|
||||
dataset.dataset_items.append(dataset_item)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse response: {e}")
|
||||
|
||||
# 保存数据集到TinyDB
|
||||
dataset_db.insert(dataset.model_dump())
|
||||
|
||||
save_dataset(dataset_db,get_workdir(),dataset_name)
|
||||
|
||||
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
|
||||
|
||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
||||
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
||||
|
||||
generate_button.click(
|
||||
on_generate_click,
|
||||
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
|
||||
outputs=output_text
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from global_var import init_global_var
|
||||
init_global_var("workdir")
|
||||
demo = dataset_generate_page()
|
||||
demo.launch()
|
55
frontend/dataset_manage_page.py
Normal file
@ -0,0 +1,55 @@
|
||||
import gradio as gr
|
||||
from global_var import get_datasets
|
||||
from tinydb import Query
|
||||
|
||||
def dataset_manage_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集管理")
|
||||
with gr.Row():
|
||||
# 获取数据集列表并设置初始值
|
||||
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
|
||||
initial_dataset = datasets_list[0] if datasets_list else None
|
||||
|
||||
dataset_dropdown = gr.Dropdown(
|
||||
choices=datasets_list,
|
||||
value=initial_dataset, # 设置初始选中项
|
||||
label="选择数据集",
|
||||
allow_custom_value=True,
|
||||
interactive=True
|
||||
)
|
||||
|
||||
# 添加数据集展示组件
|
||||
qa_dataset = gr.Dataset(
|
||||
components=["text", "text"],
|
||||
label="问答数据",
|
||||
headers=["问题", "答案"],
|
||||
samples=[["示例问题", "示例答案"]],
|
||||
samples_per_page=20,
|
||||
)
|
||||
|
||||
def update_qa_display(dataset_name):
|
||||
if not dataset_name:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 从数据库获取数据集
|
||||
Dataset = Query()
|
||||
ds = get_datasets().get(Dataset.name == dataset_name)
|
||||
if not ds:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 提取所有Q_A数据
|
||||
qa_list = []
|
||||
for item in ds["dataset_items"]:
|
||||
for qa in item["message"]:
|
||||
qa_list.append([qa["question"], qa["answer"]])
|
||||
|
||||
return {"samples": qa_list, "__type__": "update"}
|
||||
|
||||
# 绑定事件,更新QA数据显示
|
||||
dataset_dropdown.change(
|
||||
update_qa_display,
|
||||
inputs=dataset_dropdown,
|
||||
outputs=qa_dataset
|
||||
)
|
||||
|
||||
return demo
|
103
frontend/model_manage_page.py
Normal file
@ -0,0 +1,103 @@
|
||||
import gradio as gr
|
||||
import os # 导入os模块以便扫描文件夹
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unsloth import FastLanguageModel
|
||||
import torch
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
||||
from train import get_model_name
|
||||
|
||||
def model_manage_page():
|
||||
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
|
||||
models_dir = os.path.join(workdir, "models")
|
||||
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # 扫描models文件夹下的所有子文件夹
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 模型管理")
|
||||
state_output = gr.Label(label="当前状态",value="当前未加载模型") # 将 Textbox 改为 Label
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中,并设置为可选
|
||||
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
|
||||
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
|
||||
with gr.Column(scale=1):
|
||||
load_button = gr.Button("加载模型", variant="primary")
|
||||
unload_button = gr.Button("卸载模型", variant="stop")
|
||||
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
|
||||
save_method_dropdown = gr.Dropdown(
|
||||
choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
|
||||
label="保存模式",
|
||||
value="default"
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_button = gr.Button("保存模型", variant="secondary")
|
||||
|
||||
def load_model(selected_model, max_seq_length, load_in_4bit):
|
||||
try:
|
||||
# 判空操作,如果模型已加载,则先卸载
|
||||
if get_model() is not None:
|
||||
unload_model()
|
||||
|
||||
model_path = os.path.join(models_dir, selected_model)
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=model_path,
|
||||
max_seq_length=max_seq_length,
|
||||
load_in_4bit=load_in_4bit,
|
||||
)
|
||||
set_model(model)
|
||||
set_tokenizer(tokenizer)
|
||||
return f"模型 {get_model_name(model)} 已加载"
|
||||
except Exception as e:
|
||||
return f"加载模型时出错: {str(e)}"
|
||||
|
||||
load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)
|
||||
|
||||
def unload_model():
|
||||
try:
|
||||
# 将模型移动到 CPU
|
||||
model = get_model()
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
|
||||
# 清空 CUDA 缓存
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 将模型和tokenizer设置为 None
|
||||
set_model(None)
|
||||
set_tokenizer(None)
|
||||
|
||||
return "当前未加载模型"
|
||||
except Exception as e:
|
||||
return f"卸载模型时出错: {str(e)}"
|
||||
|
||||
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
|
||||
|
||||
from train.save_model import save_model_to_dir
|
||||
|
||||
def save_model(save_model_name, save_method):
|
||||
return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
|
||||
|
||||
save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
|
||||
|
||||
def refresh_model_list():
|
||||
try:
|
||||
nonlocal model_folders
|
||||
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
|
||||
return gr.Dropdown(choices=model_folders)
|
||||
except Exception as e:
|
||||
return f"刷新模型列表时出错: {str(e)}"
|
||||
|
||||
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = model_manage_page()
|
||||
demo.queue()
|
||||
demo.launch()
|
130
frontend/prompt_manage_page.py
Normal file
@ -0,0 +1,130 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_prompt_store
|
||||
from schema.prompt import promptTempleta
|
||||
def prompt_manage_page():
|
||||
def get_prompts() -> List[List[str]]:
|
||||
selected_row = None
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
prompts = db.all()
|
||||
return [
|
||||
[p["id"], p["name"], p["description"], p["content"]]
|
||||
for p in prompts
|
||||
]
|
||||
except Exception as e:
|
||||
raise gr.Error(f"获取提示词失败: {str(e)}")
|
||||
|
||||
def add_prompt(name, description, content):
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
new_prompt = promptTempleta(
|
||||
name=name if name else "",
|
||||
description=description if description else "",
|
||||
content=content if content else ""
|
||||
)
|
||||
prompt_id = db.insert(new_prompt.model_dump())
|
||||
# 更新ID
|
||||
db.update({"id": prompt_id}, doc_ids=[prompt_id])
|
||||
return get_prompts(), "", "", "" # 返回清空后的输入框值
|
||||
except Exception as e:
|
||||
raise gr.Error(f"添加失败: {str(e)}")
|
||||
|
||||
def edit_prompt():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要编辑的行")
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
db.update({
|
||||
"name": selected_row[1] if selected_row[1] else "",
|
||||
"description": selected_row[2] if selected_row[2] else "",
|
||||
"content": selected_row[3] if selected_row[3] else ""
|
||||
}, doc_ids=[selected_row[0]])
|
||||
return get_prompts()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"编辑失败: {str(e)}")
|
||||
|
||||
def delete_prompt():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要删除的行")
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
db.remove(doc_ids=[selected_row[0]])
|
||||
return get_prompts()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"删除失败: {str(e)}")
|
||||
|
||||
selected_row = None # 保存当前选中行的全局变量
|
||||
|
||||
def select_record(dataFrame ,evt: gr.SelectData):
|
||||
global selected_row
|
||||
selected_row = dataFrame.iloc[evt.index[0]].tolist()
|
||||
selected_row[0] = int(selected_row[0])
|
||||
print(selected_row)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 提示词模板管理")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
name_input = gr.Textbox(label="模板名称")
|
||||
description_input = gr.Textbox(label="模板描述")
|
||||
content_input = gr.Textbox(label="模板内容", lines=10)
|
||||
add_button = gr.Button("添加新模板", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
prompt_table = gr.DataFrame(
|
||||
headers=["id", "名称", "描述", "内容"],
|
||||
datatype=["number", "str", "str", "str"],
|
||||
interactive=True,
|
||||
value=get_prompts(),
|
||||
wrap=True,
|
||||
col_count=(4, "auto")
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
refresh_button = gr.Button("刷新数据", variant="secondary")
|
||||
edit_button = gr.Button("编辑选中行", variant="primary")
|
||||
delete_button = gr.Button("删除选中行", variant="stop")
|
||||
|
||||
refresh_button.click(
|
||||
fn=get_prompts,
|
||||
outputs=[prompt_table],
|
||||
queue=False
|
||||
)
|
||||
|
||||
add_button.click(
|
||||
fn=add_prompt,
|
||||
inputs=[name_input, description_input, content_input],
|
||||
outputs=[prompt_table, name_input, description_input, content_input]
|
||||
)
|
||||
|
||||
prompt_table.select(fn=select_record,
|
||||
inputs=[prompt_table],
|
||||
outputs=[],
|
||||
show_progress="hidden")
|
||||
|
||||
edit_button.click(
|
||||
fn=edit_prompt,
|
||||
inputs=[],
|
||||
outputs=[prompt_table]
|
||||
)
|
||||
|
||||
delete_button.click(
|
||||
fn=delete_prompt,
|
||||
inputs=[],
|
||||
outputs=[prompt_table]
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = prompt_manage_page()
|
||||
demo.queue()
|
||||
demo.launch()
|
131
frontend/setting_page.py
Normal file
@ -0,0 +1,131 @@
|
||||
import gradio as gr
|
||||
from typing import List
|
||||
from sqlmodel import Session, select
|
||||
from schema import APIProvider
|
||||
from global_var import get_sql_engine
|
||||
|
||||
def setting_page():
|
||||
def get_providers() -> List[List[str]]:
|
||||
selected_row = None
|
||||
try: # 添加异常处理
|
||||
with Session(get_sql_engine()) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
return [
|
||||
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
||||
for p in providers
|
||||
]
|
||||
except Exception as e:
|
||||
raise gr.Error(f"获取数据失败: {str(e)}")
|
||||
|
||||
def add_provider(model_id, base_url, api_key):
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
new_provider = APIProvider(
|
||||
model_id=model_id if model_id else None,
|
||||
base_url=base_url if base_url else None,
|
||||
api_key=api_key if api_key else None
|
||||
)
|
||||
session.add(new_provider)
|
||||
session.commit()
|
||||
session.refresh(new_provider)
|
||||
return get_providers(), "", "", "" # 返回清空后的输入框值
|
||||
except Exception as e:
|
||||
raise gr.Error(f"添加失败: {str(e)}")
|
||||
|
||||
def edit_provider():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要编辑的行")
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
provider = session.get(APIProvider, selected_row[0])
|
||||
if not provider:
|
||||
raise gr.Error("找不到选中的记录")
|
||||
provider.model_id = selected_row[1] if selected_row[1] else None
|
||||
provider.base_url = selected_row[2] if selected_row[2] else None
|
||||
provider.api_key = selected_row[3] if selected_row[3] else None
|
||||
session.add(provider)
|
||||
session.commit()
|
||||
session.refresh(provider)
|
||||
return get_providers()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"编辑失败: {str(e)}")
|
||||
|
||||
def delete_provider():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要删除的行")
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
provider = session.get(APIProvider, selected_row[0])
|
||||
if not provider:
|
||||
raise gr.Error("找不到选中的记录")
|
||||
session.delete(provider)
|
||||
session.commit()
|
||||
return get_providers()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"删除失败: {str(e)}")
|
||||
|
||||
selected_row = None # 保存当前选中行的全局变量
|
||||
|
||||
def select_record(dataFrame ,evt: gr.SelectData):
|
||||
global selected_row
|
||||
selected_row = dataFrame.iloc[evt.index[0]].tolist()
|
||||
selected_row[0] = int(selected_row[0])
|
||||
print(selected_row)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## API Provider 管理")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
model_id_input = gr.Textbox(label="Model ID")
|
||||
base_url_input = gr.Textbox(label="Base URL")
|
||||
api_key_input = gr.Textbox(label="API Key")
|
||||
add_button = gr.Button("添加新API", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
provider_table = gr.DataFrame(
|
||||
headers=["id", "model id", "base URL", "API Key"],
|
||||
datatype=["number", "str", "str", "str"],
|
||||
interactive=True,
|
||||
value=get_providers(),
|
||||
wrap=True,
|
||||
col_count=(4, "auto")
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
refresh_button = gr.Button("刷新数据", variant="secondary")
|
||||
edit_button = gr.Button("编辑选中行", variant="primary")
|
||||
delete_button = gr.Button("删除选中行", variant="stop")
|
||||
|
||||
refresh_button.click(
|
||||
fn=get_providers,
|
||||
outputs=[provider_table],
|
||||
queue=False # 立即刷新不需要排队
|
||||
)
|
||||
|
||||
add_button.click(
|
||||
fn=add_provider,
|
||||
inputs=[model_id_input, base_url_input, api_key_input],
|
||||
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
|
||||
)
|
||||
|
||||
provider_table.select(fn=select_record,
|
||||
inputs=[provider_table],
|
||||
outputs=[],
|
||||
show_progress="hidden")
|
||||
|
||||
edit_button.click(
|
||||
fn=edit_provider,
|
||||
inputs=[],
|
||||
outputs=[provider_table]
|
||||
)
|
||||
|
||||
delete_button.click(
|
||||
fn=delete_provider,
|
||||
inputs=[],
|
||||
outputs=[provider_table]
|
||||
)
|
||||
|
||||
return demo
|
113
frontend/train_page.py
Normal file
@ -0,0 +1,113 @@
|
||||
import subprocess
|
||||
import os
|
||||
import gradio as gr
|
||||
import sys
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from transformers import TrainerCallback
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||||
from tools import find_available_port
|
||||
from train import train_model
|
||||
|
||||
def train_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 微调")
|
||||
# 获取数据集列表并设置初始值
|
||||
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
|
||||
initial_dataset = datasets_list[0] if datasets_list else None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
dataset_dropdown = gr.Dropdown(
|
||||
choices=datasets_list,
|
||||
value=initial_dataset, # 设置初始选中项
|
||||
label="选择数据集",
|
||||
allow_custom_value=True,
|
||||
interactive=True
|
||||
)
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="epoch", precision=0)
|
||||
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
|
||||
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
|
||||
|
||||
train_button = gr.Button("开始微调")
|
||||
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练状态", interactive=False)
|
||||
with gr.Column(scale=3):
|
||||
# 新增 TensorBoard iframe 显示框
|
||||
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
|
||||
|
||||
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
epoch = int(epoch)
|
||||
save_steps = int(save_steps) # 新增保存步数参数
|
||||
lora_rank = int(lora_rank) # 新增LoRA秩参数
|
||||
|
||||
# 加载数据集
|
||||
dataset = get_datasets().get(Query().name == dataset_name)
|
||||
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
|
||||
|
||||
# 扫描 training 文件夹并生成递增目录
|
||||
training_dir = get_workdir() + "/training"
|
||||
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
|
||||
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
|
||||
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
|
||||
new_training_dir = os.path.join(training_dir, str(next_dir_number))
|
||||
|
||||
tensorboard_port = find_available_port(6006) # 从默认端口 6006 开始检测
|
||||
print(f"TensorBoard 将使用端口: {tensorboard_port}")
|
||||
|
||||
tensorboard_logdir = os.path.join(new_training_dir, "logs")
|
||||
os.makedirs(tensorboard_logdir, exist_ok=True) # 确保日志目录存在
|
||||
tensorboard_process = subprocess.Popen(
|
||||
["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE
|
||||
)
|
||||
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
|
||||
|
||||
# 动态生成 TensorBoard iframe
|
||||
tensorboard_url = f"http://localhost:{tensorboard_port}"
|
||||
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
|
||||
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
|
||||
|
||||
try:
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, new_training_dir,
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank)
|
||||
except Exception as e:
|
||||
raise gr.Error(str(e))
|
||||
finally:
|
||||
# 确保训练结束后终止 TensorBoard 子进程
|
||||
tensorboard_process.terminate()
|
||||
print("TensorBoard 子进程已终止")
|
||||
|
||||
train_button.click(
|
||||
fn=start_training,
|
||||
inputs=[
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
epoch_input,
|
||||
save_steps_input,
|
||||
lora_rank_input
|
||||
],
|
||||
outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from global_var import init_global_var
|
||||
from model_manage_page import model_manage_page
|
||||
init_global_var("workdir")
|
||||
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
|
||||
demo.queue()
|
||||
demo.launch()
|
52
global_var.py
Normal file
@ -0,0 +1,52 @@
|
||||
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||
from tools import scan_docs_directory
|
||||
|
||||
_prompt_store = None
|
||||
_sql_engine = None
|
||||
_datasets = None
|
||||
_model = None
|
||||
_tokenizer = None
|
||||
_workdir = None
|
||||
def init_global_var(workdir="workdir"):
|
||||
global _prompt_store, _sql_engine, _datasets, _workdir
|
||||
_prompt_store = get_prompt_tinydb(workdir)
|
||||
_sql_engine = load_sqlite_engine(workdir)
|
||||
_datasets = get_all_dataset(workdir)
|
||||
_workdir = workdir
|
||||
|
||||
def get_workdir():
|
||||
return _workdir
|
||||
|
||||
def get_prompt_store():
|
||||
return _prompt_store
|
||||
|
||||
def set_prompt_store(new_prompt_store):
|
||||
global _prompt_store
|
||||
_prompt_store = new_prompt_store
|
||||
|
||||
def get_sql_engine():
|
||||
return _sql_engine
|
||||
|
||||
def set_sql_engine(new_sql_engine):
|
||||
global _sql_engine
|
||||
_sql_engine = new_sql_engine
|
||||
|
||||
def get_docs():
|
||||
global _workdir
|
||||
return scan_docs_directory(_workdir)
|
||||
def get_datasets():
|
||||
return _datasets
|
||||
|
||||
def get_model():
|
||||
return _model
|
||||
|
||||
def set_model(new_model):
|
||||
global _model
|
||||
_model = new_model
|
||||
|
||||
def get_tokenizer():
|
||||
return _tokenizer
|
||||
|
||||
def set_tokenizer(new_tokenizer):
|
||||
global _tokenizer
|
||||
_tokenizer = new_tokenizer
|
29
main.py
Normal file
@ -0,0 +1,29 @@
|
||||
import gradio as gr
|
||||
import train
|
||||
from frontend import *
|
||||
from db import initialize_sqlite_db, initialize_prompt_store
|
||||
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_global_var()
|
||||
initialize_sqlite_db(get_sql_engine())
|
||||
initialize_prompt_store(get_prompt_store())
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("模型管理"):
|
||||
model_manage_page()
|
||||
with gr.TabItem("模型推理"):
|
||||
chat_page()
|
||||
with gr.TabItem("模型微调"):
|
||||
train_page()
|
||||
with gr.TabItem("数据集生成"):
|
||||
dataset_generate_page()
|
||||
with gr.TabItem("数据集管理"):
|
||||
dataset_manage_page()
|
||||
with gr.TabItem("提示词模板管理"):
|
||||
prompt_manage_page()
|
||||
with gr.TabItem("设置"):
|
||||
setting_page()
|
||||
|
||||
app.launch(server_name="0.0.0.0")
|
21
paper/.gitignore
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# LaTeX临时文件
|
||||
*.aux
|
||||
*.log
|
||||
*.out
|
||||
*.toc
|
||||
*.synctex.gz
|
||||
*.bbl
|
||||
*.blg
|
||||
*.dvi
|
||||
*.fdb_latexmk
|
||||
*.fls
|
||||
*.lof
|
||||
*.lot
|
||||
*.idx
|
||||
*.ilg
|
||||
*.ind
|
||||
*.nav
|
||||
*.snm
|
||||
*.vrb
|
||||
*.xdv
|
||||
*.pdf
|
22
paper/chapters/abstract.tex
Normal file
@ -0,0 +1,22 @@
|
||||
% 摘要
|
||||
\begin{center}
|
||||
{\zihao{3}\textbf{\thesistitle}}\par
|
||||
{\zihao{-4}\songti \quad \classname \quad \studentname \par
|
||||
指导教师:\adviser}
|
||||
\end{center}
|
||||
|
||||
% 中文摘要
|
||||
\begin{onecolabstract}
|
||||
\noindent{}{\zihao{4}\textbf{摘要}}{\quad \songti \zihao{-4}大语言模型(Large Language Models,LLMs)在通用代码生成任务中表现出色,但在处理包含专有知识的企业私有代码库时,其性能往往受限。针对此问题,本文提出并实现了一个基于文档驱动的自适应编码大模型微调框架。该框架的核心创新在于:首先,通过深度解析技术文档(Markdown格式),自动抽取信息并结合预设模板生成高质量的指令微调(Supervised Fine-Tuning,SFT)训练语料;其次,利用参数高效微调技术(如量化低秩微调,Quantized Low-Rank Adaptation,QLoRA)对预训练的编码大模型(以qwen2.5为例)进行针对性优化,使其精准适配私有库特有的语法结构与编程范式;最后,整合了包括数据持久化(SQLite+TinyDB)、训练监控(TensorBoard)和交互式前端(Gradio)在内的完整工作流。实验结果表明,该框架能够有效提升大模型在私有库代码生成任务上的准确性和实用性,为实现企业级软件开发的智能化和高效化提供了一套自动化、可扩展的解决方案。
|
||||
}\par
|
||||
\noindent{}{\zihao{4}\textbf{关键词}}{\quad \zihao{-4}\songti 大语言模型;代码生成;模型微调;文档驱动
|
||||
}\par
|
||||
\end{onecolabstract}
|
||||
|
||||
% 英文摘要
|
||||
\begin{onecolabstract}
|
||||
\noindent{}{ \zihao{4} \textbf{ABSTRACT}}{\quad \zihao{-4}While Large Language Models (LLMs) demonstrate impressive performance in general code generation, their efficacy often diminishes when dealing with enterprise private code repositories that contain proprietary knowledge. To address this challenge, this paper proposes and implements a **document-driven adaptive fine-tuning framework for large code models**. The framework introduces several key innovations: Firstly, it employs deep parsing techniques on technical documentation (in Markdown format) to automatically extract information and combine it with predefined templates, thereby generating high-quality Supervised Fine-Tuning (SFT) training data. Secondly, it leverages parameter-efficient fine-tuning techniques, such as Quantized Low-Rank Adaptation (QLoRA), to specifically optimize a pre-trained large code model (exemplified by Qwen2.5). This optimization enables the model to accurately adapt to the unique syntax, structures, and programming paradigms inherent in private libraries. Finally, the framework integrates a comprehensive workflow encompassing data persistence (using SQLite and TinyDB), training monitoring (via TensorBoard), and an interactive frontend (built with Gradio). Experimental results indicate that this framework significantly enhances the accuracy and practical utility of LLMs in private library code generation tasks, offering an automated and scalable solution for intelligent and efficient enterprise software development.
|
||||
}\par
|
||||
\noindent{}{ \zihao{4}\textbf{KEYWORDS}}{\quad \zihao{-4}Large Language Models; Code Generation; Model Fine-tuning; Document-Driven
|
||||
}\par
|
||||
\end{onecolabstract}
|
18
paper/chapters/acknowledgement.tex
Normal file
@ -0,0 +1,18 @@
|
||||
% 致谢章节
|
||||
|
||||
\section*{致谢}
|
||||
\addcontentsline{toc}{section}{致谢\tiny{\quad.}}
|
||||
|
||||
当笔下最后一个句点落下时,忽然惊觉这段与文字相伴的时光已近尾声。在这场充满探索与成长的旅程中,无数温暖的力量始终环绕左右,虽难以尽述细节,却值得用最真挚的文字向所有给予我支持的人道一声感谢。
|
||||
|
||||
首先要感恩学术道路上的引路人。正是老师们以渊博的学识搭建起知识的阶梯,以严谨的治学态度传递为学之道,让我在迷茫时得以窥见真理的微光,在徘徊时能坚定前行的方向。每一次课堂上的启发、每一次交流中的点拨,都如同一束束光,照亮了我从懵懂到逐渐明晰的成长之路。
|
||||
|
||||
其次要感谢校园里的人文滋养。这方充满活力的天地,不仅用丰富的资源培育着求知的心灵,更以包容的氛围接纳着每一个跃动的梦想。无论是漫步于绿荫环绕的小径时迸发的灵感,还是坐在安静的图书馆里与文字对话的时光,都成为了我青春记忆中不可替代的注脚。
|
||||
|
||||
还要深深致谢生命里的温暖陪伴。家人始终是我最坚实的港湾,他们用无条件的爱构筑起勇气的壁垒,让我能心无旁骛地奔赴理想;朋友的笑容与鼓励如同一曲曲轻快的旋律,在压力袭来时吹散阴霾,让枯燥的学术时光也充满了欢声笑语。那些并肩走过的日子,早已成为岁月馈赠的珍贵礼物。
|
||||
|
||||
最后,要向这段全力以赴的时光致敬。论文写作的过程或许充满艰辛,但每一次与思维的博弈、每一次对完美的追求,都让我更深刻地理解了坚持的意义。此刻的终点亦是新的起点,未来的日子里,我将带着这份感恩与热忱,在更广阔的天地中继续书写属于自己的篇章。
|
||||
|
||||
致谢是终点亦是起点,愿所有给予我温暖的人,都能在各自的星辰大海中闪闪发光。
|
||||
|
||||
如果需要调整情感基调或补充特定场景的表达,欢迎随时告诉我,我可以进一步优化内容。
|
103
paper/chapters/conclusion.tex
Normal file
@ -0,0 +1,103 @@
|
||||
% 第六章 总结与展望
|
||||
|
||||
\section{总结与展望}
|
||||
|
||||
|
||||
\subsection{研究工作总结}
|
||||
本研究围绕基于大语言模型的自动化微调框架展开,旨在解决企业级软件开发中私有代码库智能辅助编码的挑战。
|
||||
通过系统化的设计与实现,本研究取得了以下主要成果:
|
||||
|
||||
\paragraph{创新性架构设计}
|
||||
本研究提出了一种文档驱动的自适应编码大模型微调框架,采用三层架构设计(表现层、业务逻辑层和数据访问层)实现了系统的高内聚低耦合。
|
||||
特别是在数据存储方面,创新性地采用了SQLite与TinyDB相结合的双数据库架构,针对不同类型数据(结构化配置数据与半结构化训练数据)实现了最优化的存储方案。
|
||||
这种设计在保障事务完整性的同时,充分兼顾了半结构化数据处理的敏捷性需求,为系统的稳定运行和灵活扩展奠定了基础。
|
||||
|
||||
\paragraph{高效语料生成技术}
|
||||
本研究开发了一套完整的技术文档处理与训练语料生成方案。
|
||||
通过基于栈结构的Markdown文档解析器,系统能够准确捕获文档的层级结构和内容关系;结合PromptTemplate动态模板技术,实现了提示词的参数化配置;采用基于异步IO架构的OpenAI API协程并发调用机制,显著提升了语料生成效率;同时,通过严格的JSON格式校验与数据验证流程,确保了生成语料的质量和一致性。
|
||||
这套技术方案有效解决了从技术文档到训练语料的自动化转换问题,为模型微调提供了高质量的数据基础。
|
||||
|
||||
\paragraph{参数高效微调实现}
|
||||
在模型训练方面,本研究采用了LoRA(Low-Rank Adaptation)参数高效微调方法,并结合Unsloth优化算子,实现了在有限计算资源条件下的高效模型适配。
|
||||
系统构建了完整的监督式微调(SFT)流程,包括训练数据准备与格式化、LoRA微调方法应用、训练配置优化以及结果保存等环节。
|
||||
通过这种方法,成功将通用大语言模型定向优化为具备私有库代码生成能力的专用模型,在保持模型通用能力的同时,显著提升了其在特定领域的表现。
|
||||
|
||||
\paragraph{交互式前端系统}
|
||||
基于Gradio框架,本研究构建了一个功能完备的交互式前端系统,包含模型管理、模型推理、模型微调、数据集生成、数据集管理、提示词模板管理和系统设置等七个功能模块。
|
||||
系统实现了全局状态管理、前后端数据流控制、流式响应与实时反馈等关键功能,为用户提供了直观、易用的操作界面。
|
||||
这种设计不仅降低了系统使用门槛,还通过可视化配置和实时反馈机制,提升了整体用户体验。
|
||||
|
||||
\paragraph{系统集成与工程实践}
|
||||
本研究成功将文档解析、语料生成、模型微调和交互界面等多个模块集成为一个完整的自动化框架,实现了从技术文档输入到定制化模型输出的全流程自动化。
|
||||
在工程实践层面,系统采用了模块化设计、异常处理机制、数据持久化策略和性能优化措施,确保了系统的稳定性、可扩展性和高效性。
|
||||
这种全面的系统集成为企业级软件开发中的智能编码辅助提供了可行的技术路径。
|
||||
|
||||
总体而言,本研究不仅在技术层面实现了大语言模型微调框架的创新设计与实现,还在应用层面为解决企业私有库代码生成问题提供了系统化解决方案。
|
||||
通过文档驱动的自适应微调方法,成功提升了大语言模型在特定领域的代码生成能力,为企业级软件开发的智能化与高效化提供了有力支持。
|
||||
|
||||
\subsection{不足与局限}
|
||||
基于对项目代码库和论文内容的深入分析,本项目虽然在大语言模型微调框架方面取得了一定成果,但仍存在以下几个方面的不足与局限性:
|
||||
|
||||
\paragraph{文档处理能力的局限性}
|
||||
当前系统在文档处理方面主要支持Markdown格式的技术文档解析,对其他格式文档(如PDF、Word、HTML等)的支持有限。
|
||||
这种单一格式的依赖在实际企业环境中可能造成应用障碍,因为企业技术文档通常以多种格式存在。
|
||||
此外,文档解析过程中缺乏对复杂结构(如嵌套表格、图表等)的有效处理机制,可能导致关键信息的丢失或误解。
|
||||
|
||||
\paragraph{训练语料质量不稳定}
|
||||
生成的训练语料质量高度依赖于原始文档的质量和大模型的能力。
|
||||
在实际应用中,如果原始文档存在描述不清、术语不一致或结构混乱等问题,将直接影响生成的训练语料质量。
|
||||
同时,系统缺乏对生成语料的自动化质量评估机制,难以在大规模语料生成过程中保证数据质量的一致性,这可能导致微调效果的不稳定。
|
||||
|
||||
\paragraph{微调技术的单一性}
|
||||
当前系统主要采用LoRA微调方法,虽然该方法在参数效率上有显著优势,但在处理特定领域深度知识或复杂语义理解任务时可能存在效果不佳的情况。
|
||||
系统未能提供多种微调方法(如P-Tuning、Prefix-Tuning等)的集成支持,限制了用户根据具体需求选择最适合的微调策略的灵活性。
|
||||
|
||||
\paragraph{超参数优化机制不足}
|
||||
微调过程中的超参数选择主要依靠经验设定,缺乏自动化优化机制。
|
||||
这种人工干预的方式不仅增加了用户的使用门槛,也难以保证在不同数据集和任务上获得最优的微调效果。
|
||||
系统未能实现如贝叶斯优化、网格搜索等自动化超参数调优方法,这在处理多样化的企业私有库时可能导致性能次优。
|
||||
|
||||
\paragraph{评估体系不完善}
|
||||
当前系统缺乏对微调后模型效果的全面评估机制,难以客观量化模型在特定领域的提升程度。
|
||||
评估指标单一,主要关注生成代码的语法正确性,而对代码的功能正确性、安全性、可维护性等多维度评估不足。
|
||||
这种评估体系的不完善使得用户难以全面了解微调效果,也为系统的持续优化和迭代带来了挑战。
|
||||
|
||||
\paragraph{多模态融合能力欠缺}
|
||||
系统主要处理文本形式的技术文档,缺乏对图表、UML图、流程图等非文本信息的理解和处理能力。
|
||||
在实际的软件开发文档中,这些非文本信息往往承载了重要的设计思想和架构信息,忽略这部分内容可能导致模型对代码结构和设计模式的理解不足,从而影响生成代码的质量。
|
||||
|
||||
\paragraph{安全性考虑不充分}
|
||||
在处理企业私有库和敏感技术文档时,系统对数据安全和隐私保护的考虑不够全面。
|
||||
缺乏对训练数据的脱敏处理机制,以及对生成模型可能泄露原始训练数据的防护措施。
|
||||
这在处理包含商业机密或敏感信息的企业私有库时,可能带来潜在的安全风险。
|
||||
|
||||
通过识别和分析这些不足与局限性,为未来研究提供了明确的改进方向,包括扩展文档处理能力、提高训练语料质量、丰富微调方法、实现超参数自动优化、降低资源需求、完善评估体系、增强多模态融合能力以及加强安全性保障等方面。
|
||||
这些改进将有助于构建更加完善、实用的大语言模型微调框架,更好地满足企业级软件开发的智能辅助需求。
|
||||
|
||||
\subsection{未来展望}
|
||||
基于当前研究基础和技术发展趋势,本研究框架的后续演进可从以下六个维度展开深度探索:
|
||||
\paragraph{边缘智能集成}
|
||||
研究模型轻量化与边缘计算融合技术,探索基于TensorRT、ONNX Runtime等推理引擎的异构加速方案。
|
||||
通过开发自适应模型切片技术,实现大模型在边缘设备(如Jetson系列)的分布式推理,构建端云协同的智能编码辅助体系,有效降低服务延迟并提升隐私保护等级。
|
||||
|
||||
\paragraph{动态自适应学习机制}
|
||||
设计基于强化学习的在线学习框架,建立代码评审反馈闭环系统。
|
||||
通过开发增量式微调算法(如AdaLoRA),使模型能够动态适应企业代码库的持续演进,形成``开发-训练-优化''的自我迭代生态,解决传统静态模型与动态代码库的版本错配问题。
|
||||
|
||||
\paragraph{智能化伦理安全框架}
|
||||
构建多层次安全防护体系,研发面向代码生成的差分隐私保护模块(DP-SGD)和模型水印技术。
|
||||
引入代码合规性验证层,集成SAST(静态应用安全测试)工具链,确保生成代码符合企业安全规范及行业监管要求,防范潜在的法律风险。
|
||||
|
||||
\paragraph{跨平台生态构建}
|
||||
开发统一的API网关和服务编排引擎,支持与主流IDE(VSCode/IntelliJ/PyCharm)深度集成。
|
||||
研究容器化微服务架构,实现模型服务在Kubernetes集群的弹性伸缩,构建跨Windows/Linux/macOS的多平台支持能力,提升框架的工程适用性。
|
||||
|
||||
\paragraph{开发者知识图谱构建}
|
||||
融合代码抽象语法树(AST)分析与文档实体识别技术,构建企业级开发知识图谱。
|
||||
通过图神经网络(GNN)实现编码规范、API调用关系、架构模式等隐性知识的可视化表达与推理,为开发者提供智能化的代码导航和架构决策支持。
|
||||
|
||||
\paragraph{CI/CD深入集成}
|
||||
建立基于Git版本流的自动化训练数据采集管道,开发代码变更敏感度分析模型。
|
||||
结合主动学习策略(Active Learning)构建智能数据筛选机制,实现训练样本的按需获取和高效标注,形成可持续进化的模型优化体系。
|
||||
|
||||
这些技术方向的突破将推动智能编码辅助系统从单一功能工具向开发全生命周期支持平台演进,最终形成具备自我进化能力的智能软件开发生态系统,为软件工程领域带来范式级变革。
|
36
paper/chapters/cover.tex
Normal file
@ -0,0 +1,36 @@
|
||||
% 封面部分
|
||||
\begin{titlepage}
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\includegraphics[scale=1]{pic//logo.jpg}
|
||||
\end{figure}
|
||||
\vspace{0.2cm}
|
||||
\centering
|
||||
|
||||
{\zihao{1}\CJKfamily{fzkaiti}\bfseries 本~科~毕~业~论~文(设~计)}
|
||||
|
||||
\vspace{2.5cm}
|
||||
|
||||
\begin{flushleft}
|
||||
{{\CJKfamily{zhkai} \zihao{-2} \hspace{3.5em}课题名称}\quad{\dlmu[8.5cm]{ \scalebox{0.8}[1.0]{ \zihao{3}\CJKfamily{zhkai}\thesistitle} }}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}学\qquad 院}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\department}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}专\qquad 业}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\major}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}班级名称}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\classname}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}学生姓名}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\studentname}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}学\qquad 号}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\studentid}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}指导老师}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\adviser}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\CJKfamily{zhkai}\zihao{-2} \hspace{3.5em}完成日期}\quad\dlmu[8.5cm]{\zihao{3}\CJKfamily{zhkai}\number\year 年\number\month 月\number\day 日}\par}
|
||||
\end{flushleft}
|
||||
|
||||
\vspace{4cm}
|
||||
|
||||
{\CJKfamily{zhkai} \zihao{3} 教\hspace{0.5em}务\hspace{0.5em}处\hspace{0.5em}制}
|
||||
|
||||
\end{titlepage}
|
304
paper/chapters/implementation.tex
Normal file
@ -0,0 +1,304 @@
|
||||
% 第四章:关键技术实现
|
||||
|
||||
\section{关键技术实现}
|
||||
|
||||
\subsection{系统架构设计}
|
||||
本系统采用经典的三层架构设计,分为表现层、业务逻辑层和数据访问层。
|
||||
|
||||
在表现层中,基于Gradio框架构建了一个用户友好的Web界面,包含7个功能模块:模型管理、模型推理、模型微调、数据集生成、数据集管理、提示词模板管理和系统设置。
|
||||
该界面采用响应式设计,支持流式输出和灵活的参数配置,以满足不同用户的交互需求。
|
||||
|
||||
业务逻辑层是系统的核心部分,负责处理具体的业务逻辑。
|
||||
其中,模型训练模块基于Unsloth和TRL库实现了高效的LoRA微调功能;模型推理模块支持流式生成,并允许用户配置多种采样参数;数据集生成模块则基于LangChain PromptTemplate处理模板,支持Markdown文档解析和结构化数据生成,采用异步调用提高生成效率。
|
||||
|
||||
数据访问层主要负责数据的存储与管理。
|
||||
系统使用SQLite存储系统配置,同时采用TinyDB内存数据库管理数据集,支持JSON格式的数据导入和导出。
|
||||
通过这种分层设计,各层之间明确分工,不仅提升了系统的可扩展性和可维护性,还为后续的功能扩展奠定了基础。
|
||||
|
||||
\input{figures/system_architecture.tex}
|
||||
|
||||
\subsection{数据库设计与实现}
|
||||
|
||||
\subsubsection{双数据库架构设计(SQLite + TinyDB)}
|
||||
本系统创新性地采用SQLite与TinyDB相结合的双数据库架构,以应对不同类型数据的管理需求。
|
||||
对于API提供者信息等结构化数据,系统选用SQLite作为核心数据库,并通过SQLModel这一ORM工具实现面向对象的数据操作,其内置的线程锁机制有效保障了多线程环境下的数据并发安全。
|
||||
SQLite数据库的实体文件持久化存储于\texttt{workdir/db/db.sqlite}路径,确保数据的可追溯性。
|
||||
|
||||
针对数据集信息(包含文档元数据及问答对集合)和提示词模板等半结构化数据的管理,系统则采用轻量级文档数据库TinyDB。
|
||||
数据集采用内存存储与JSON文件持久化相结合的混合模式,而提示词模板则直接通过JSON文件进行存储。
|
||||
TinyDB的无模式(Schema-free)特性为数据模型的灵活扩展提供了便利,其对原生JSON格式的处理能力显著提升了数据序列化与反序列化的效率。
|
||||
这种双数据库协同架构在保障事务完整性的同时,充分兼顾了半结构化数据处理的敏捷性需求,实现了数据存储方案的最优化配置。
|
||||
|
||||
\subsubsection{据模型定义与存储方案}
|
||||
本系统遵循领域驱动设计原则,并借助Pydantic框架构建了层次化的数据模型体系,以确保业务数据的完整性和一致性。
|
||||
在数据集建模方面,设计了四级递进模型结构:Doc模型用于描述文档的基础元数据(如名称、存储路径、版本号等);Q\_A模型封装了单个问答对的核心要素;DatasetItem模型聚合多个问答对,形成逻辑上的数据单元;最终,Dataset模型整合元数据与数据项集合,构成完整的数据集结构。
|
||||
|
||||
提示词模板模型则通过promptTemplate实体进行抽象,包含模板ID、名称、描述、内容体及创建时间等关键字段。
|
||||
系统预置了验证规则,强制要求模板内容必须包含document\_slice变量占位符,以确保模板在实际应用中具备上下文填充能力。
|
||||
|
||||
在存储实现层面,数据集采用了内存数据库与文件系统持久化相结合的双重保障机制,并利用TinyDB的临时文件特性实现原子写入操作。
|
||||
提示词模板则直接采用JSON文件存储方案,其良好的可读性便于人工维护。
|
||||
这种差异化的存储策略旨在保证数据完整性的同时,提升数据访问和管理的效率。
|
||||
|
||||
\subsubsection{数据库初始化与管理实现}
|
||||
本系统实施了分层且智能化的数据库初始化与管理策略。
|
||||
针对SQLite数据库,初始化阶段将自动检测并创建数据库文件,并通过SQLModel的元数据创建功能动态构建表结构,同时支持从环境变量注入初始数据集,从而实现部署环境的快速配置。
|
||||
对于TinyDB子系统,初始化时将执行自动化目录扫描,对workdir/dataset路径下的JSON文件进行格式校验和数据加载,建立内存与文件系统之间的双向同步机制。
|
||||
|
||||
为保障数据可靠性,系统采用了多维度管理策略:在访问控制层面,SQLite数据库利用线程级锁机制实现并发安全,TinyDB则通过文件锁保证写入操作的互斥性;在数据操作层面,系统集成了Pydantic模型验证框架,在数据持久化之前执行严格的类型校验;在容错机制方面,系统采用预写式日志(WAL)记录关键操作,并结合异常捕获机制实现故障的可追溯性。
|
||||
特别设计的原子写入模块通过临时文件交换技术,确保在任何异常情况下存储文件的完整性,从而有效防范数据损坏的风险。
|
||||
|
||||
\subsection{语料生成与处理技术}
|
||||
|
||||
\subsubsection{Markdown文档解析}
|
||||
该解析器采用树形结构组织Markdown文档内容,核心是通过栈结构维护标题层级关系。
|
||||
当遇到\#号开头的标题行时,解析器会根据\#号数量确定当前标题的层级,并通过栈结构维护父子关系。
|
||||
如果遇到比栈顶元素层级低的标题,会不断弹出栈顶元素直到找到合适的父节点。
|
||||
对于代码块内容,解析器会特殊处理以```或\textasciitilde{}\textasciitilde{}\textasciitilde{}开头的行,将其间的所有内容视为原始文本直接附加到当前节点,不进行任何解析。
|
||||
这种处理方式保证了代码块内的特殊字符不会被误解析为Markdown语法。
|
||||
文档内容的组织采用递归遍历方式。
|
||||
\texttt{process\_markdown\_file}函数会先构建完整的文档树,然后通过\texttt{traverse}函数递归遍历所有节点。
|
||||
对于叶子节点(没有子节点的节点),会将从根节点到该节点的所有标题用``-\textgreater{}''连接,并与节点内容组合输出,形成完整的上下文信息。
|
||||
解析器还提供了\texttt{print\_tree}函数用于可视化文档结构,可以清晰展示各层级标题的嵌套关系和内容分布。
|
||||
这种树形结构表示法特别适合处理具有复杂层级关系的长文档,能够准确反映文档的原始组织结构。
|
||||
|
||||
\subsubsection{prompt模板套用和提示词格式引导}
|
||||
通过PromptTemplate类构建动态提示词模板,前端界面支持选择预存模板并自动提取变量生成可编辑表格,实现提示词参数化;采用提示词追加JSON格式要求和API强制返回结构的双重保障机制确保输出结构化;支持多轮生成并记录详细耗时和token使用情况,同时具备异常处理能力;通过严格的数据验证流程将响应解析映射到数据模型,确保数据质量;特别实现了文档切片与模板变量的智能组合,有效支持从长文档生成结构化QA数据集,形成了一套完整的提示词模板应用与数据集生成解决方案。
|
||||
|
||||
\subsubsection{OpenAI API的协程并发语料生成}
|
||||
本系统的OpenAI API协程并发语料生成模块采用异步IO架构实现高并发处理,其核心逻辑体现在 \texttt{reasoning.py} 中的 \texttt{call\_openai\_api} 方法。
|
||||
该方法通过实例化 \texttt{openai.AsyncOpenAI} 异步客户端,支持多轮次(rounds参数)连续对话请求,自动解析JSON格式响应并记录完整的调用元数据。
|
||||
在并发控制方面,基于Python原生\texttt{asyncio}事件循环实现非阻塞式请求处理,通过await关键字异步等待API响应,这种设计理论上可扩展为使用\texttt{asyncio.gather}实现并行请求批处理。
|
||||
数据流设计采用 \texttt{dataset\_generation.py} 中定义的 \texttt{LLMRequest} 请求对象封装输入参数,生成 \texttt{LLMResponse} 响应列表。
|
||||
错误处理机制采用全异常捕获策略,在发生API超时或格式错误时保留错误上下文和\texttt{response\_id}追踪链,同时维护包含耗时统计(精确到毫秒)、prompt/completion tokens使用量及总资源消耗的性能监控体系。
|
||||
该模块通过 \texttt{dataset\_generate\_page.py} 集成到前端生成流程,实现文档切片处理、可配置并发参数(当前UI隐藏)和实时进度反馈的完整工作流。
|
||||
|
||||
\subsubsection{JSON格式校验、反序列化和持久化}
|
||||
本系统采用三层架构实现JSON数据处理全流程管理:在数据输入层,通过动态Schema绑定技术结合大语言模型的格式约束参数,构建双向校验机制,确保原始数据符合预定义结构规范;在数据处理层,设计基于异常传播模型的三级解析体系,通过语法验证、语义补全和类型强转实现安全反序列化,采用领域驱动设计模式将原始JSON映射为业务对象;在数据存储层,运用分层持久化策略,通过内存序列化缓存、文档数据库中间存储和文件系统冷备份三级存储机制,实现数据生命周期管理。
|
||||
系统通过管道过滤器模式串联各处理模块,建立数据校验→结构转换→持久存储的完整处理链路,各组件间通过标准接口解耦,形成高内聚低耦合的可扩展架构,有效提升复杂JSON数据处理的可靠性和可维护性。
|
||||
|
||||
\subsection{语言模型训练技术}
|
||||
|
||||
\subsubsection{训练数据准备与格式化}
|
||||
|
||||
语言模型的监督式微调效果高度依赖于训练数据的质量与组织形式。
|
||||
本节重点阐述数据预处理的核心逻辑,主要包括数据结构设计、对话模板转换和高效数据处理三个关键环节。
|
||||
|
||||
在数据组织结构层面,本研究采用``问题-答案''(question-answer)双字段结构作为基础数据单元。
|
||||
这种结构化设计源于对话型语言模型的训练需求,每个样本对应完整的对话轮次,其中用户提问(question)构成输入引导,助理解答(answer)作为目标输出。
|
||||
原始数据需经过严格的质量筛选与语义对齐处理,确保问答对具有明确的意图匹配性和逻辑连贯性,这是避免模型产生幻觉现象的重要基础。
|
||||
|
||||
对话模板的应用是数据格式化的核心步骤。
|
||||
通过预定义的qwen-2.5模板规范,系统将原始问答对转换为包含角色标识符(user/assistant)和特殊符号(\textless{}\textbar{}im\_start\textbar{}\textgreater{})的标准化对话序列。
|
||||
该转换过程遵循两阶段结构化原则:首先构建对话轮次列表,保持用户与助手消息的严格交替;其次通过分词器的模板解析功能,自动添加必要的系统提示符和消息分隔符。
|
||||
这种格式化处理不仅统一了不同来源数据的表达形式,更重要的是建立了模型预期的对话结构记忆,为后续监督学习提供稳定的模式识别基础。
|
||||
|
||||
针对对话数据的特性,本研究实施了响应聚焦的损失计算策略。
|
||||
在模板转换过程中,通过指令部分(instruction\_part)与响应部分(response\_part)的显式划分,系统仅在助手生成内容对应的token位置计算训练损失。
|
||||
这种选择性损失计算机制使模型专注于学习回答生成模式,有效避免了输入文本重复性对参数更新的干扰,同时降低了无关token对梯度传播的影响强度。
|
||||
|
||||
在数据处理技术实现层面,采用Hugging Face
|
||||
Datasets库构建高效数据管道。
|
||||
将原始Python列表转换为内存映射格式的HFDataset对象,该设计显著提升了大规模数据的加载效率。
|
||||
通过map操作实现批量化数据处理,配合多进程并行机制,在保证数据转换一致性的同时,实现了预处理速度与内存占用的优化平衡。
|
||||
这种工业化数据处理流程的确立,为后续高频次的模型训练迭代提供了可靠的基础设施支持。
|
||||
|
||||
\subsubsection{训练流程实现与配置}
|
||||
为了高效且便捷地进行大规模语言模型的监督式微调,本项目选用了一系列成熟且广泛应用的开源框架,核心依赖于Hugging Face的transformers库,该库提供了丰富的预训练模型、分词器以及用于模型训练的基础设施。
|
||||
在此基础上,结合使用了trl(Transformer Reinforcement Learning)库,特别是其提供的监督式微调训练器(SFTTrainer)。
|
||||
该训练器是专门为简化监督式微调任务而设计的,它在`transformers`的训练接口之上进行了封装和优化,使得研究者能够更专注于数据准备和模型配置,而无需处理底层复杂的训练循环逻辑,极大地提高了开发效率。
|
||||
这种框架组合提供了强大的功能性和灵活性,能够支持复杂模型的加载、PEFT技术的应用以及多样化的训练策略。
|
||||
|
||||
模型训练的效果很大程度上取决于训练参数的合理设置。
|
||||
在本项目中,通过配置一系列关键参数来控制训练过程的行为,这些参数包括但不限于:学习率(learning\allowbreak\_rate),它决定了模型在每次参数更新时的步长;每个设备的训练批次大小(per\allowbreak\_device\allowbreak\_train\allowbreak\_batch\allowbreak\_size),影响显存占用和梯度更新的稳定性;梯度累积步数(gradient\allowbreak\_accumulation\allowbreak\_steps),通过累积多个小批次的梯度来模拟使用更大的批次进行训练;训练的总步数或总轮数(max\_steps / epoch),定义了整个训练过程的长度;学习率调度器类型(lr\allowbreak\_scheduler\allowbreak\_type),控制学习率随训练进程的变化策略;权重衰减(weight\allowbreak\_decay),作为一种正则化手段,有助于防止模型过拟合;以及随机种子(seed),用于确保训练结果的可复现性。
|
||||
对这些参数的细致调整是获得高性能模型的关键环节。
|
||||
|
||||
训练大语言模型对计算资源要求极高,因此采用了多种优化技术来提升训练效率并降低资源消耗。
|
||||
首先是混合精度训练(mixed precision training),利用半精度浮点数(如fp16或bf16)进行计算和存储,相比于传统的全精度(FP32),可以显著减少显存占用并加速计算,同时通过配合少数全精度参数,可以保证训练的稳定性和模型的精度,本项目会根据硬件支持情况自动选择合适的半精度类型。
|
||||
其次,在优化器选择上,采用了诸如adamw\_8bit的8位量化版本,这种优化器能够大幅度减少优化器状态所需的显存,使得在相同硬件条件下可以训练更大的模型或使用更大的批次大小。
|
||||
此外,还采用了梯度检查点(use\_gradient\_checkpointing)技术,这项技术通过在反向传播时重新计算前向传播中的一些中间激活值来显著降低显存峰值占用,尤其在使用优化实现时,能更高效地平衡计算量和显存消耗。
|
||||
|
||||
在将准备好的训练数据输入模型之前,需要一个数据整理器(Data Collator)来处理一个批次内的样本。
|
||||
特别是在处理变长序列时,数据整理器的作用至关重要。
|
||||
本项目使用了针对序列设计的整理器,负责将批次内长度不一的文本序列进行填充(padding),使其达到批次内的最大长度或预设的最大长度,从而能够被模型以张量的形式统一处理。
|
||||
同时,数据整理器还会生成相应的注意力掩码(attention mask),告知模型哪些部分是真实的序列内容,确保模型不会在填充位置进行不必要的计算或注意力分配。
|
||||
对于监督式微调任务,它还需要协助处理标签的准备,配合生成适当的损失掩码(loss mask),确保损失计算仅发生在目标响应的token上,忽略输入提示部分的损失。
|
||||
|
||||
\subsubsection{模型训练执行与监控}
|
||||
|
||||
在完成语言模型微调所需的数据准备、模型配置和训练参数设置后,接下来便是训练流程的实际执行阶段。
|
||||
这一阶段的核心任务是将处理好的数据输入到配置好的模型中,通过优化算法不断调整模型参数,使其学习到预期的能力。
|
||||
训练的启动意味着计算资源被分配和调度,数据批次被送入模型进行前向传播,计算损失,并通过反向传播计算梯度,最终利用优化器更新模型权重。
|
||||
整个过程是一个迭代循环,直至达到预设的训练轮次或满足其他停止条件。
|
||||
|
||||
为了确保训练过程的稳定性和有效性,并对训练进度和效果进行实时跟踪与评估,模型训练的执行通常伴随着详尽的监控机制。
|
||||
监控是训练过程中不可或缺的一环,它允许研究人员和开发者观察关键指标的变化趋势,例如训练损失(Training Loss)、学习率(Learning Rate)以及其他可能的评估指标。
|
||||
通过监测这些指标,可以及时发现潜在问题,如模型不收敛、过拟合或欠拟合等,从而及时调整训练策略或参数。
|
||||
|
||||
训练过程中的重要组成部分是检查点的保存。
|
||||
检查点是指在训练进行到特定阶段时,将模型的当前参数、优化器状态、学习率调度器状态等完整信息保存下来。
|
||||
这具有多重意义:首先,它提供了一种容错机制,即使训练过程意外中断,也可以从最近的检查点恢复训练,避免从头开始;其次,通过保存多个检查点,可以在训练结束后选择性能最佳的模型版本,或者用于后续的进一步研究或部署;最后,检查点也为评估模型在不同训练程度下的表现提供了可能。
|
||||
检查点的保存策略(例如,按固定的步数或周期保存)和保存路径是训练配置中的重要考量。
|
||||
|
||||
除了检查点,详细的训练日志记录也是必不可少的。
|
||||
日志会记录训练过程中的各种事件和指标,例如每一步或每若干步的损失值、梯度范数、内存使用情况等。
|
||||
这些日志信息可以被保存到文件,供事后分析,也可以被实时导出到可视化工具中。
|
||||
目前,业界广泛使用诸如TensorBoard这类可视化工具来呈现训练过程中的曲线图、直方图等,使得复杂的训练数据变得直观易懂。
|
||||
通过这些可视化界面,研究人员可以清晰地看到损失如何随训练步数下降,学习率如何变化,权重或梯度的分布情况等,从而深入理解训练动态,辅助决策和优化。
|
||||
|
||||
总而言之,模型训练的执行是一个计算密集型的过程,而有效的监控系统则是确保这一过程高效、稳定并最终取得成功的关键。
|
||||
通过合理的检查点策略和详细的日志记录及可视化,可以全面掌握训练状态,及时调整策略,并为后续的模型评估和部署奠定基础。
|
||||
|
||||
|
||||
\subsubsection{模型保存与导出}
|
||||
|
||||
在语言模型训练完成后,将训练得到的模型参数和相关的配置信息进行持久化存储是至关重要的步骤。
|
||||
模型持久化的目的是为了能够在后续阶段加载模型进行推理、评估,或者进一步的迭代开发,而无需每次都重新训练。
|
||||
这一过程通常包括保存模型权重(即模型学习到的参数)以及与模型紧密关联的分词器(Tokenizer)的配置和词表。
|
||||
分词器负责文本的输入和输出预处理,其状态必须与模型保持一致才能确保模型能够正确理解输入并生成有效的输出。
|
||||
标准的模型保存方法会将模型权重和分词器信息分别存储在指定的文件或目录中,形成一个完整的模型资产包。
|
||||
|
||||
|
||||
针对采用参数高效微调(如LoRA)训练得到的模型,模型保存的方式会更加灵活。
|
||||
一种常见的做法是仅保存LoRA 层的权重。
|
||||
由于 LoRA只修改了基模型的小部分参数,这种方式保存的文件体积非常小,便于存储和传输。
|
||||
在进行推理时,需要将保存的 LoRA 权重与原始的基模型加载并合并使用。
|
||||
另一种方式是将训练好的 LoRA 权重与原始基模型的对应层权重进行合并,生成一个包含所有参数的完整模型。
|
||||
这种合并后的模型可以直接加载进行推理,无需额外步骤,适用于部署到不需要区分基模型和LoRA层的环境中。
|
||||
合并时可以选择不同的精度(如16位浮点或4位整数),以平衡模型大小和推理性能。
|
||||
|
||||
|
||||
除了标准的保存格式,为了适应不同的部署环境和推理框架,模型有时需要被导出为特定的格式。
|
||||
GGUF(GPT-Generated Unified Format)就是一种为 LLM 推理设计的格式,它支持多种量化方法,可以将模型参数压缩到更小的体积,同时优化在 CPU 或特定硬件上的推理性能。
|
||||
将模型导出为 GGUF 并选择合适的量化级别(如Q4\_K\_M, Q8\_0等),可以在保证一定推理精度的情况下,显著降低模型的资源消耗,使其更容易在终端设备或资源受限的环境中运行。
|
||||
|
||||
|
||||
此外,将训练好的模型发布到模型社区或平台(例如 HuggingFace)是实现模型共享和便捷部署的常用方式。
|
||||
通过将模型文件(包括合并后的模型、LoRA权重或特定格式如 GGUF的模型)推送到这些平台,其他用户可以轻松地下载和使用您的模型,同时也方便您自己从任何地方访问您的模型资产。
|
||||
发布时也可以选择包含多种量化版本的模型,以满足不同用户的需求。
|
||||
|
||||
|
||||
综上所述,模型保存与导出是语言模型训练流程中连接训练与应用的桥梁。
|
||||
选择合适的保存格式和方法取决于模型类型、微调策略以及预期的部署环境和性能需求,旨在实现模型的有效管理、便捷加载和高效推理。
|
||||
|
||||
|
||||
\subsection{前端交互系统实现}
|
||||
|
||||
\subsubsection{Gradio交互框架设计}
|
||||
|
||||
Gradio交互框架设计采用了模块化的架构思想,将复杂的大模型开发流程分解为七个功能明确的子模块。
|
||||
系统主界面通过gr.Blocks()构建容器框架,采用Tabs组件实现多页面导航,每个Tab对应一个独立功能模块的实现文件。
|
||||
这种设计既保持了界面风格统一,又实现了功能模块的高内聚。
|
||||
|
||||
|
||||
\subsubsection{全局状态管理机制}
|
||||
|
||||
本系统在前端交互层面构建了一套模块化的全局状态管理机制,核心在于通过
|
||||
\texttt{global\_var.py}
|
||||
模块实现一个基于单例模式的状态容器。
|
||||
此容器采用私有化变量(如\texttt{\_model}、\texttt{\_tokenizer}等)封装核心组件,并通过工厂模式支持大语言模型的动态加载。
|
||||
状态的读取与修改通过公有访问器方法(如\texttt{get\_model()} 和\texttt{set\_model()})进行受控管理,确保状态变更的可追踪性和安全性。
|
||||
具体实现上,模型对象在通过HuggingFace Transformers库加载后会缓存于内存,而分词器和数据集对象则采用惰性加载策略。
|
||||
数据集的版本化管理通过TinyDB 文档数据库实现。
|
||||
为保障并发环境下的线程安全性,系统利用 Python的全局解释器锁(GIL)机制,并对关键状态变更操作(如模型切换)采用原子性事务处理序列,确保操作的完整性,例如执行``卸载旧模型→ 清理显存 → 加载新模型''的原子操作。
|
||||
这种设计模式使得各功能模块,例如\texttt{train\_page.py}中的训练模块,能够通过统一接口获取系统实时状态,同时有效地降低了模块间的耦合度,为系统的可扩展性提供了标准化的接入点。
|
||||
|
||||
|
||||
系统的状态生命周期通过 \texttt{init\_global\_var()}初始化函数进行全面管理,该过程包含一个三阶段的控制流程。
|
||||
首先,系统会锚定工作目录,基于给定的路径参数创建标准化的存储目录结构,包括\texttt{models}、\texttt{datasets} 和 \texttt{training}三级子目录,并验证其可写权限。
|
||||
其次,系统建立双层持久化存储机制,利用SQLite 数据库对模型元数据进行关系型管理,同时借助 TinyDB完成非结构化数据集的文档存储。
|
||||
最后,执行环境预热步骤,包括预加载默认模型的分词器权重文件至显存以及初始化CUDA 计算上下文。
|
||||
这一初始化链式调用贯穿系统启动的 entireprocess,工作目录作为核心的路径解析基准,不仅确保了在不同环境(开发、生产)下的配置无缝切换,而且通过SQLite 关系数据库与 JSON文档数据库的混合存储模式,实现了结构化元数据与非结构化训练数据的有效隔离与管理。
|
||||
|
||||
|
||||
在跨组件通信方面,系统基于 \texttt{global\_var.py}模块构建了一个发布-订阅模式的状态同步机制。
|
||||
当模型管理页面(通过\texttt{model\_manage\_page.py})调用 \texttt{set\_model()}方法更新当前使用的模型时,系统会触发一个全局状态变更事件。
|
||||
订阅了该状态的组件,例如训练页面(通过\texttt{train\_page.py}),可以通过 \texttt{get\_model()}接口实时获取最新的模型实例(如第 21 行对 \texttt{get\_model()}的调用)。
|
||||
同样,数据集的更新操作(如新增训练样本,通过
|
||||
\texttt{get\_datasets().insert()})会自动广播到所有关联组件。
|
||||
这意味着训练页面中的数据集下拉列表(如第22 行 \texttt{datasets\_list}的构建)能够即时刷新以显示最新的数据集,从而实现多视图状态的无缝同步。
|
||||
通过接口隔离原则和事件驱动机制的应用,各功能模块无需感知彼此的内部实现细节,仅需通过标准接口进行交互,这在保证系统响应实时性的同时,将模块间的耦合度降低至函数调用级别。
|
||||
|
||||
|
||||
\subsubsection{前后端数据流设计}
|
||||
|
||||
Gradio框架的前后端数据流设计核心在于通过组件(Components)和事件(Events)实现用户界面与Python 后端逻辑的交互。
|
||||
当用户在 Gradio 构建的 Web界面(前端)中与输入组件(如文本框、滑块、文件上传等)进行互动或触发某个事件(如点击按钮)时,前端会将输入组件当前的数值或状态打包,通过HTTP 请求发送到运行在服务器端的 Python后端。
|
||||
后端接收到这些数据后,会根据您定义的处理函数(HandlerFunction),以这些前端数据作为函数的输入参数来执行相应的业务逻辑。
|
||||
函数执行完毕后,返回的结果数据会被Gradio 框架捕获,并通过 HTTP响应发送回前端。
|
||||
前端接收到后端返回的数据后,会根据您配置的输出组件(如文本框、图片展示、画廊等),自动更新界面以展示处理结果,从而完成一次完整的数据交互和展示流程。
|
||||
整个过程由Gradio框架内部负责序列化、传输和反序列化数据,极大地简化了开发者构建交互式 Web应用的复杂度。
|
||||
|
||||
|
||||
\subsubsection{流式响应与实时反馈}
|
||||
|
||||
在实现前端聊天系统的交互时,为了提供更佳的用户体验,特别是与大语言模型进行对话时,采用传统的``一次性等待全部生成再显示''的方式会让用户感受到明显的延迟。
|
||||
因此,流式响应和实时反馈技术变得至关重要。
|
||||
这项技术的目的是让用户能够像看到对方正在``打字''一样,文字内容可以随着模型的生成进度逐步显示在聊天界面上,而不是等到模型完全生成完毕才一次性出现。
|
||||
|
||||
|
||||
实现流式响应的核心在于后端如何将语言模型的输出分批、分步地发送给前端,以及前端如何接收并逐步更新显示。
|
||||
具体来说,当用户发送消息后,后端不再等待语言模型生成完整的回复文本,而是配置模型以``流''的形式进行输出。
|
||||
这意味着模型在生成过程中,会不断地吐出部分文本片段(通常是词或字),并通过一个特定的通道(比如一个流式生成器对象)进行传输。
|
||||
|
||||
|
||||
为了不阻塞处理用户请求的主进程或主线程,耗时的语言模型文本生成任务会在一个独立的线程中启动。
|
||||
这个独立线程负责调用模型进行生成,并将生成的文本片段源源不断地送入到前面提到的那个``流''中。
|
||||
与此同时,主线程则负责监听并读取这个``流''中的内容。
|
||||
每当从流中读取到新的文本片段时,主线程就将这部分内容附加到当前正在构建的回复文本后面,并立即将更新后的聊天历史(包含不完整的、正在增长的助手回复)发送给前端界面。
|
||||
|
||||
|
||||
前端界面接收到后端发送的带有最新文本片段的聊天历史后,会立即更新聊天框中对应的助手回复消息。
|
||||
由于这个更新过程是高频率进行的,用户在界面上看到的效果就是助手的回复文字正在一个词一个词、甚至一个字一个字地逐步``打''出来,形成了实时反馈的视觉效果。
|
||||
整个流程持续进行,直到语言模型完成全部生成并在流中发送结束信号,或者达到预设的生成长度限制。
|
||||
通过这种流式传输和逐步更新的方式,极大地提升了对话的实时性和用户感知到的系统响应速度。
|
||||
|
||||
|
||||
\subsubsection{异常处理与用户反馈}
|
||||
|
||||
本系统在异常处理方面构建了一套全面且用户友好的机制。
|
||||
首先,通过装饰器模式实现全局异常捕获,对所有API调用进行拦截,能够自动识别并处理模型加载失败、API请求超时及数据解析错误等问题。
|
||||
该机制进一步细化了异常类型,区分了可预见的业务异常(例如用户输入无效)和不可预见的系统异常,并建立了相应的错误代码体系,以便精确诊断问题。
|
||||
其次,为了提升用户体验,系统在聊天界面利用Gradio的Error组件实时展示错误摘要,并通过可折叠面板提供详细的错误堆栈信息,便于开发者调试。
|
||||
特别地,针对模型生成过程中出现的tokenization异常,系统能动态插入错误标记,同时维持对话历史的连贯性。
|
||||
最后,在输入端,系统建立了完善的参数验证体系,通过类型强制转换(如将字符串转为浮点数)和边界值检测(如限制温度参数范围)实现前端校验。
|
||||
对于检测到的非法输入,系统会高亮相应的参数框并显示动画提示,同时禁用提交按钮,从而有效防止无效请求的发送并引导用户正确操作。
|
||||
|
||||
|
||||
\subsubsection{基于子进程tensorboard的训练监控}
|
||||
|
||||
当前端页面上的用户配置好训练参数(如数据集、学习率、批次大小等)并点击``开始微调''按钮时,前端界面会触发一个对应的后端处理函数。
|
||||
这个函数首先会根据当前的日期或序列号等信息,为本次训练创建一个独立的、用于存放训练日志和模型检查点的目录,确保不同训练任务的日志不会混淆。
|
||||
|
||||
|
||||
接着,系统会扫描查找当前计算机上一个可用的网络端口,用于启动TensorBoard服务。
|
||||
找到合适的端口后,程序不会在当前主进程中直接运行TensorBoard,而是通过调用操作系统的命令,启动一个全新的、独立的TensorBoard进程(即子进程)。
|
||||
这个子进程被告知需要监控刚刚创建的训练日志目录,并在之前找到的可用端口上提供服务。
|
||||
由于是在单独的进程中运行,即使主训练过程非常耗时或发生其他情况,TensorBoard的服务也不会受到直接影响。
|
||||
|
||||
|
||||
TensorBoard子进程成功启动并开始监听指定端口后,后端处理函数会立即构建一个HTML的\texttt{\textless{}iframe\textgreater{}}标签。
|
||||
这个标签的作用就像网页中的一个``窗口'',它可以加载并显示另一个网页的内容。
|
||||
在这里,\texttt{\textless{}iframe\textgreater{}}的源地址(src属性)被设置为TensorBoard子进程正在提供服务的本地地址(例如
|
||||
\texttt{http://localhost:端口号})。
|
||||
这个生成的\texttt{\textless{}iframe\textgreater{}}标签会被发送回前端界面,更新页面上预留的显示区域,使得TensorBoard的界面直接呈现在用户眼前。
|
||||
|
||||
|
||||
随后,实际的模型训练过程才正式开始。
|
||||
在训练过程中,模型训练框架会按照预定的频率(例如每隔一定步数或每个epoch结束时),将当前的训练指标(如损失值、准确率等)记录并写入到之前为本次训练专门创建的日志目录中。
|
||||
TensorBoard子进程一直在监控这个目录中的日志文件变化。
|
||||
一旦检测到新的数据写入,TensorBoard会自动读取这些数据,更新其内部的图表和可视化内容。
|
||||
由于前端页面的\texttt{\textless{}iframe\textgreater{}}实时连接着TensorBoard的服务,这些更新的可视化结果也会同步反映在前端界面上,用户可以实时地看到模型的训练进度和性能变化。
|
||||
|
||||
|
||||
最后,无论模型训练是正常完成还是因为错误而中断,系统都会执行清理操作。
|
||||
在训练结束时(通过异常处理机制中的finally块保证执行),程序会发送终止信号给之前启动的TensorBoard子进程,确保该子进程被关闭,释放占用的系统资源和网络端口。
|
||||
这样就完成了一次基于子进程TensorBoard的训练监控的完整流程,既提供了实时可视化功能,又保持了主训练过程的独立性和稳定性。
|
||||
|
||||
|
||||
\subsection{扩展性实现}
|
||||
|
||||
本项目在扩展性设计方面采用了模块化的架构思想,通过清晰的目录结构将功能划分为前端、数据模型、工具和训练等独立模块,每个模块职责明确且相互解耦。
|
||||
在数据模型层面,采用Pydantic进行数据建模,支持数据验证和序列化,核心数据模型通过BaseModel继承实现可扩展性;工具系统采用插件化设计,通过统一的导出接口支持新工具的便捷添加;前端界面基于Gradio框架实现组件化设计,支持页面的灵活组织;配置管理方面使用global\_var模块统一管理全局变量,并支持环境变量配置;模型管理支持多种保存格式和可配置的加载参数;数据存储采用SQLModel和TinyDB提供抽象化的数据操作接口。
|
||||
此外,项目还实现了统一的异常处理机制和规范化的错误输出,并采用MIT开源协议支持代码的自由使用和修改。
|
||||
这些设计使得项目具有良好的可维护性和可扩展性,新功能可以通过添加新模块或扩展现有模块来实现,而无需大规模修改现有代码。
|
127
paper/chapters/introduction.tex
Normal file
@ -0,0 +1,127 @@
|
||||
% 第一章:绪论
|
||||
|
||||
\section{绪论}
|
||||
|
||||
\subsection{研究背景与意义}
|
||||
|
||||
在现代软件开发领域,程序员的编码工作日益依赖于先进的大语言模型支持,这些模型凭借其强大的能力,显著自动化了代码生成流程,有效减轻了开发者的工作负担,并大幅度提升了开发效率。
|
||||
然而,尽管这些模型在公开数据集与广泛使用的开源项目中展现出非凡的性能,但在处理企业内部高度专业化的私有库时,其局限性便显露无遗。
|
||||
核心原因在于,大语言模型往往基于广泛的通用数据集进行训练,缺乏对特定企业或项目中私有库内专有函数、类及其交互细节的深度理解和应用适应性。
|
||||
|
||||
相较于广泛采用的公开编码模型,针对私有库设计的专有模型显得尤为必要。
|
||||
公开模型虽强大,但在面对包含企业核心业务逻辑、技术秘密及高度定制化功能的私有库时,往往捉襟见肘。
|
||||
由于缺乏对私有库具体实现细节的认知,生成的代码往往无法精准引用库中的类、方法或属性,这不仅增加了后续人工调整的工作量,还可能引入潜在的安全风险。
|
||||
此外,企业间的私有库差异巨大,从架构设计到API接口各不相同,要求任何自动化工具都必须具备高度的灵活性和可定制性,以适应这些多样化的环境。
|
||||
|
||||
鉴于上述现状,本项目通过深度解析私有库的文档资源,精准提取关键信息,并以此为基础对大语言模型进行针对性的微调与优化。
|
||||
这一过程不仅增强了模型对私有库特定功能和用法的理解能力,还极大地提升了生成代码的准确性和实用性。
|
||||
通过本项目,我们期望能够让生成的代码片段无缝集成于企业的私有库生态中,真正实现企业级软件开发的智能化与高效化,满足企业对高质量、高安全性代码的迫切需求。
|
||||
|
||||
\subsection{国内外研究现状}
|
||||
|
||||
\subsubsection{大语言模型微调研究现状}
|
||||
大语言模型微调研究在国内外均呈现快速发展态势\cite{zhang2024}。
|
||||
在国内,智源研究院与TeleAI联合开发的"悟道·天鹰"系列模型代表了重要进展,其520亿参数版本的开源显著促进了国内AI社区发展。
|
||||
这些模型在部分中英文基准测试中表现出与Llama3-70B和GPT-4相当甚至更优的性能。
|
||||
为解决"幻觉"问题,智源研究院开发的BGE系列向量模型通过检索增强生成(Retrieval-Augmented Generation,RAG)技术有效提升了模型准确性。
|
||||
|
||||
国内外研究均呈现出对小型语言模型(Small Language Model,SLM)的高度关注。
|
||||
SLM在计算资源需求和训练成本方面具有明显优势,表明经过领域特定微调的SLM在特定任务上可超越更大的通用模型。清华大学、北京大学和中国科学院等机构在LLM研究中发挥重要作用,其应用范围从古籍数字化到医学研究等多个领域。
|
||||
|
||||
国际研究重点关注长输出大语言模型及其生成连贯长文本的能力。
|
||||
研究人员广泛探索了参数知识适应(如Domain-Adaptive Pre-Training (DAPT)、Instruction Tuning (IT)、Prompt Learning (PL)和模型编辑)和半参数知识适应(如RAG和基于Agent的系统)等技术,以在保留通用知识的同时提高特定任务性能。
|
||||
研究发现,即使少量监督微调数据也能有效激活预训练模型中的知识。
|
||||
|
||||
尽管取得进展,微调研究仍面临诸多挑战。
|
||||
国内主要挑战包括模型创新不足、高质量训练数据稀缺以及"幻觉"问题限制了模型在高精度应用中的可靠性。
|
||||
国际上,长输出LLM面临高质量长序列数据缺乏和连贯性维持困难等问题,同时大模型的高计算成本也推动了对更高效模型的需求。
|
||||
|
||||
未来研究趋势包括:基于大模型的具身智能;提示工程和认知工程的深化应用;检索增强生成技术的进一步发展;通过量化和剪枝提高LLM效率;增强模型可解释性;以及探索Transformer之外的新型模型架构和训练范式。
|
||||
|
||||
\subsubsection{微调后大语言模型在企业级应用领域的应用现状}
|
||||
微调后大语言模型在企业级应用领域的应用正在国内外快速发展\cite{vmk2024}。
|
||||
在国内,企业主要在客户服务领域探索微调LLM,创建智能客服机器人以提高客户满意度和运营效率。
|
||||
内容生成是另一重要应用,通过对行业特定数据进行微调,模型可生成符合品牌风格的营销文案和产品描述。
|
||||
北京大学在古籍数字化领域的探索和智源研究院的Emu3多模态模型也展示了在特定企业级应用领域的潜力。
|
||||
总体而言,微调LLM在中国企业级应用领域尚处早期阶段,但潜力巨大。
|
||||
|
||||
国际上,微调后大语言模型在企业级应用领域更为成熟和广泛。
|
||||
客户服务领域的智能支持系统能提供全天候多语言帮助,处理各类咨询并将复杂问题上报人工客服。
|
||||
内容生成方面,微调LLM被广泛应用于营销、广告和媒体行业,快速生成各类文本内容。
|
||||
金融机构和咨询公司利用微调LLM协助撰写专业分析报告。
|
||||
此外,LLM在数据标注和合成方面的应用对需要大量高质量标注数据的企业级应用至关重要,显著提高了数据标注效率和一致性。
|
||||
微调后大语言模型已广泛应用于国际企业级应用领域,并不断扩展到更多行业和业务流程。
|
||||
|
||||
\subsubsection{AI辅助编码研究现状}
|
||||
AI辅助编码的研究和应用在中国尚处于起步阶段。
|
||||
虽然一些大型科技公司和研究机构已开始关注这一领域并推出了内部或限量使用的工具,但像GitHub Copilot这样具有广泛影响力的AI编码助手仍然相对稀少。
|
||||
可以推断,国内研究主要集中在使用机器学习和自然语言处理技术帮助开发者提高编码效率、减少错误以及学习新的编程语言或框架。
|
||||
这可能包括代码自动补全、语法错误检查、代码片段推荐以及基于自然语言描述生成代码等功能。
|
||||
然而,由于缺乏直接相关的公开研究信息,国内AI辅助编码工具的具体功能、性能以及对软件开发流程的影响仍需进一步调查和分析。
|
||||
尽管如此,随着中国软件产业的发展和对开发效率需求的日益增长,AI辅助编码在国内具有广阔的应用前景。
|
||||
|
||||
在国际上,AI辅助编码的研究和应用已取得了显著进展。
|
||||
GitHub Copilot、Tabnine、IntelliCode等AI编码工具被开发者广泛使用。
|
||||
这些工具通常在大规模代码语料库上进行训练,能够提供智能代码补全、错误检测、代码建议和代码生成。
|
||||
研究表明,这些工具可以显著提高开发者的编码速度和效率,减少代码错误,并帮助开发者更好地理解和使用各种编程语言和框架。
|
||||
国际研究着重于进一步提升AI编码工具的智能化水平,例如使其能够理解更复杂的代码逻辑,更好地处理上下文信息,生成更符合开发者意图的代码,以及与各种开发环境和工作流程更深入地集成。
|
||||
此外,还在研究AI编码工具对软件开发流程、代码质量以及开发者学习曲线的影响。
|
||||
总的来说,AI辅助编码在国际上已成为一个成熟且持续发展的研究领域,正在深刻改变软件开发模式。
|
||||
\cite{haque2025}
|
||||
|
||||
\subsubsection{提示工程研究现状}
|
||||
提示工程是一门新兴技术,随着大语言模型的普及在中国受到越来越多的关注。
|
||||
上海交通大学的研究人员已经认识到提示工程在未来人工智能应用中的重要性。
|
||||
可以推断,国内的研究和实践主要集中在探索如何设计更有效、更精准的自然语言提示来引导大语言模型生成期望的输出。
|
||||
这可能包括研究不同的提示技巧,例如使用清晰具体的指令、提供相关的上下文信息以及利用少量样本提示。
|
||||
一些国内开发者和企业也开始在实际场景中应用提示工程,例如优化提示以提高智能客服系统的响应质量,增强内容生成的连贯性和相关性。
|
||||
然而,与国际研究相比,在中国提示工程方面的系统性研究和理论框架可能仍处于早期发展阶段。
|
||||
随着大语言模型技术的不断进步及其在中国应用范围的扩大,提示工程有望成为一个越来越重要的研究和实践领域。
|
||||
|
||||
在国际上,提示工程已成为一个热门研究领域。
|
||||
研究人员广泛探索了各种提示技巧和策略,例如零样本提示、少量样本提示和思维链提示,并研究了它们对大语言模型输出质量的影响。
|
||||
同时,出现了多个提示工程框架和工具,旨在帮助用户更有效地设计和管理提示。
|
||||
国际研究还侧重于理解为什么某些提示能产生更好的结果以及如何自动生成或优化提示。
|
||||
此外,还在进行一些关于提示压缩的研究以提高效率。
|
||||
总的来说,国际上在提示工程方面的研究已经形成一定的体系,并正在持续发展和完善,为更好地利用大语言模型提供了重要的理论基础和实践指导。
|
||||
\cite{Meskó2023}
|
||||
|
||||
\subsection{本文结构安排}
|
||||
|
||||
本文围绕基于大语言模型的自动化微调框架展开研究与实现,全文共分为六章,具体结构安排如下:
|
||||
|
||||
第一章 前言:本章首先介绍了研究的背景与意义,阐述了大语言模型微调自动化的重要性和必要性。
|
||||
随后,对国内外相关的研究现状进行了回顾与分析,指出了现有方法的优势与不足。
|
||||
最后,概述了本文的主要研究内容,并介绍了论文的整体结构安排。
|
||||
|
||||
第二章 相关技术介绍:本章详细介绍了本文研究所涉及的关键技术。
|
||||
包括大语言模型(LLM)的发展、应用及在辅助编码方面的潜力;提示工程技术在引导LLM生成高质量文本中的作用;模型量化技术及其在降低模型部署成本方面的意义;LoRA(Low-Rank Adaptation)等参数高效微调方法,特别是QLoRA的原理与优势;优化微调效率的unsloth算子;以及用于构建交互式界面的Gradio框架。
|
||||
|
||||
第三章 需求分析:本章从项目整体出发,对基于大语言模型的自动化微调框架进行了需求分析。
|
||||
首先介绍了项目的整体目标和应用场景。
|
||||
然后,详细分析了系统的功能需求,包括训练语料生成、模型微调、自动化整合以及前端展示等核心功能。
|
||||
最后,阐述了系统的非功能需求,如性能要求和扩展性要求。
|
||||
|
||||
第四章 关键技术实现:本章详细阐述了系统的具体实现过程。
|
||||
首先介绍了系统的整体架构设计、模块划分与交互流程。
|
||||
接着,描述了双数据库架构(SQLite+TinyDB)的设计与实现方案,以及数据模型定义和数据库管理。
|
||||
详细介绍了语料生成与处理技术,包括Markdown文档解析、Prompt模板应用、API协程并发调用以及数据校验与持久化。
|
||||
重点阐述了语言模型训练技术的实现,涵盖监督式微调(SFT)流程、训练数据准备、LoRA微调方法应用、训练配置、监控与结果保存。
|
||||
随后,介绍了基于Gradio框架的前端交互系统设计与实现,包括全局状态管理、前后端数据流、流式响应与实时反馈以及异常处理。
|
||||
最后,探讨了系统的扩展性实现方案。
|
||||
|
||||
第五章 结果验证:本章对基于文档驱动的自适应编码大模型微调框架的实验结果进行验证和分析。
|
||||
首先介绍了实验环境,包括硬件配置和软件环境。
|
||||
然后,详细描述了实验对象,包括基础模型选择、微调数据集和微调参数配置。
|
||||
接着,分析了微调过程中的资源占用和训练指标变化。
|
||||
最后,从代码生成能力、文档理解能力、通用能力保持和用户满意度等多个维度对微调效果进行了全面验证,证明了本框架的有效性和实用价值。
|
||||
|
||||
第六章 总结与展望:本章对本文的研究工作进行了全面的总结,回顾了所取得的主要成果。
|
||||
同时,分析了当前研究存在的不足与局限性。
|
||||
最后,对未来的研究方向和可能的技术发展进行了展望。
|
||||
|
||||
\subsection{小结}
|
||||
本章作为全文的引言部分,首先阐明了在当前大语言模型蓬勃发展的背景下,构建自动化微调框架的研究背景和重要的现实意义。
|
||||
通过对国内外相关研究现状的梳理,我们认识到自动化、高效化微调工具的缺失是当前LLM应用落地的瓶颈之一,这进一步凸显了本研究的价值。
|
||||
本章还概述了本文的主要研究内容,旨在通过整合先进的语料生成、模型微调和前端交互技术,构建一个用户友好、高效灵活的LLM自动化微调框架。
|
||||
最后,详细介绍了本文的章节结构安排,为读者清晰地勾勒出后续内容的逻辑脉络,为深入理解本文的研究工作奠定了基础。
|
51
paper/chapters/references.tex
Normal file
@ -0,0 +1,51 @@
|
||||
% 参考文献章节
|
||||
|
||||
\renewcommand\refname{参考文献}
|
||||
\begin{thebibliography}{99}
|
||||
\addcontentsline{toc}{section}{参考文献\tiny{\quad}}
|
||||
\bibitem{zhang2024}
|
||||
张钦彤, 王昱超, 王鹤羲, 等. 大语言模型微调技术的研究综述[J]. Journal of Computer Engineering \& Applications, 2024, 60(17).
|
||||
|
||||
\bibitem{haque2025}
|
||||
Haque M Z, Afrin S, Mastropaolo A. A Systematic Literature Review of Parameter-Efficient Fine-Tuning for Large Code Models[J]. arXiv preprint arXiv:2504.21569, 2025.
|
||||
|
||||
\bibitem{vmk2024}
|
||||
VM K, Warrier H, Gupta Y. Fine tuning llm for enterprise: Practical guidelines and recommendations[J]. arXiv preprint arXiv:2404.10779, 2024.
|
||||
|
||||
\bibitem{Meskó2023}
|
||||
Meskó B. Prompt engineering as an important emerging skill for medical professionals: tutorial[J]. Journal of medical Internet research, 2023, 25: e50638.
|
||||
|
||||
\bibitem{wang2024}
|
||||
王耀祖, 李擎, 戴张杰, 等. 大语言模型研究现状与趋势[J]. 工程科学学报, 2024, 46(8): 1411-1425.
|
||||
|
||||
\bibitem{Zhang2023Survey}
|
||||
Zhang, Z., Chen, C., Liu, B., et al. A survey on language models for code[J]. arXiv preprint arXiv:2311.07989, 2023.
|
||||
|
||||
\bibitem{Chen2023}
|
||||
Chen B, Zhang Z, Langrené N, et al. Unleashing the potential of prompt engineering in large language models: a comprehensive review[J]. arXiv preprint arXiv:2310.14735, 2023.
|
||||
|
||||
\bibitem{Lin2024Awq}
|
||||
Lin J, Tang J, Tang H, et al. Awq: Activation-aware weight quantization for on-device llm compression and acceleration[J]. Proceedings of Machine Learning and Systems, 2024, 6: 87-100.
|
||||
|
||||
\bibitem{Dong2023}
|
||||
Dong G, Yuan H, Lu K, et al. How abilities in large language models are affected by supervised fine-tuning data composition[J]. arXiv preprint arXiv:2310.05492, 2023.
|
||||
|
||||
\bibitem{Dettmers2024Qlora}
|
||||
Dettmers, T., Pagnoni, A., Holtzman, A., et al. Qlora: Efficient finetuning of quantized llms[J]. Advances in Neural Information Processing Systems, 2024, 36.
|
||||
|
||||
\bibitem{Hu2021Lora}
|
||||
Hu, E. J., Shen, Y., Wallis, P., et al. Lora: Low-rank adaptation of large language models[J]. arXiv preprint arXiv:2106.09685, 2021.
|
||||
|
||||
\bibitem{Han2024Unsloth}
|
||||
Han D, Han M. Unsloth[J]. URL: https://github. com/unslothai/unsloth. git. The model overview web form is used to get the model architecture and information about the model The intent submission web form is for the LLMFed use case where task name, server IP address, client IPs, and intent for the FL task are taken as inputs, 2023.
|
||||
|
||||
\bibitem{Zhang2024Gradio}
|
||||
Abid A, Abdalla A, Abid A, et al. Gradio: Hassle-free sharing and testing of ml models in the wild[J]. arXiv preprint arXiv:1906.02569, 2019.
|
||||
|
||||
\bibitem{Yang2024Qwen}
|
||||
Yang A, Yang B, Zhang B, et al. \& Qiu, Z.(2024)[R]. Qwen2. 5 technical report.
|
||||
|
||||
\bibitem{Liu2024Deepseek}
|
||||
Liu A, Feng B, Xue B, et al. Deepseek-v3 technical report[J]. arXiv preprint arXiv:2412.19437, 2024.
|
||||
|
||||
\end{thebibliography}
|
63
paper/chapters/requirement.tex
Normal file
@ -0,0 +1,63 @@
|
||||
% 第三章:需求分析
|
||||
|
||||
\section{需求分析}
|
||||
|
||||
|
||||
\subsection{项目整体介绍}
|
||||
本项目旨在构建一个基于文档驱动的自适应编码大模型微调框架。
|
||||
在现代软件开发领域,虽然大语言模型显著提升了代码生成效率,但在处理企业内部高度专业化的私有代码库时,其局限性日益凸显。
|
||||
这主要是由于通用大模型缺乏对特定企业或项目中私有库内专有函数、类及其交互细节的深度理解和应用适应性。
|
||||
|
||||
相较于广泛采用的公开编码模型,针对私有库设计的专有模型显得尤为必要。
|
||||
公开模型难以精准引用私有库中的元素,可能引入安全风险并增加人工调整工作量。
|
||||
企业间的私有库差异巨大,要求自动化工具具备高度的灵活性和可定制性。
|
||||
|
||||
本研究的核心在于深度解析私有库的文档资源,精准提取关键信息,并以此为基础对大语言模型进行针对性的微调与优化。
|
||||
通过提升模型对私有库特定功能和用法的理解能力,本项目旨在极大提升生成代码的准确性和实用性,使生成的代码片段能够无缝集成于企业的私有库生态中。
|
||||
|
||||
最终,本项目将实现企业级软件开发的智能化与高效化,满足企业对高质量、高安全性代码的迫切需求。
|
||||
本研究具有重要的理论意义,扩展了大语言模型在代码生成领域的应用场景,推动了代码生成技术的发展;同时也具有实际应用价值,能够提升企业开发效率、降低开发成本、提高代码质量和安全性,从而增强企业竞争力。
|
||||
|
||||
\subsection{功能需求分析}
|
||||
本框架的功能设计主要围绕自动化处理流程展开,包括训练语料的生成、编码大模型的微调、各模块的自动化整合以及最终结果的前端展示。
|
||||
|
||||
\subsubsection{训练语料生成功能}
|
||||
训练语料生成功能是整个框架的基础。
|
||||
该功能需要选用具备强大长文本生成能力的大参数量模型,例如 GLM4-LONG 或 qwen-max-longcontext,通过对指定格式的 Markdown 技术文档进行深度解析,系统能够准确抽取其中的标题、段落、代码块等关键信息,并生成对应的提示词。
|
||||
随后,通过预设的算法或规则,将提取的提示词转换为适合模型输入的格式,最终生成高质量且覆盖广泛技术领域和编程场景的训练语料库,以确保其数量和质量能够满足后续模型训练的需求。
|
||||
|
||||
\subsubsection{模型微调功能}
|
||||
模型微调功能是提升模型在私有库代码生成能力的关键。
|
||||
本框架计划以 qwen 模型作为微调基础,采用 Qlora 训练方法。
|
||||
利用上一步生成的训练语料对该模型进行有针对性的微调,使其学习将输入的提示词转化为符合语法规则和逻辑的代码片段。
|
||||
|
||||
|
||||
\subsubsection{数据持久化功能}
|
||||
为了确保系统的稳定性、可配置性以及训练过程的可追溯和模型的可复用,本框架需要实现全面的数据持久化功能。
|
||||
这主要包括配置信息的持久化、训练与评估数据集的持久化以及训练后模型的持久化。
|
||||
针对不同的数据特性,将采用混合存储策略,利用关系型数据库存储结构化的配置参数和元数据,例如API配置信息。
|
||||
同时,非结构化或半结构化的数据,例如生成的训练语料、经过处理的技术文档内容,将采用文档型数据库或文件存储的方式进行持久化,以便于灵活存储和快速读取。
|
||||
同时,模型需要支持多种持久化方式,例如单独导出Lora适配器、导出gguf格式模型、导出量化后的模型等。
|
||||
通过有效的数据持久化机制,可以方便地加载历史配置以复现实验、管理和版本控制不同的数据集、以及存储和调用微调后的模型,从而提升整个框架的可用性和效率。
|
||||
|
||||
|
||||
\subsubsection{前端展示功能}
|
||||
前端展示功能为用户提供了直观、易用的交互界面。
|
||||
本框架计划采用 Gradio 框架构建前端界面。
|
||||
该界面将用于展示后端生成的代码片段和相关信息,实现前端与后端的实时数据交互,使用户能够即时看到模型生成的结果。
|
||||
|
||||
\subsection{非功能需求分析}
|
||||
|
||||
除了上述功能性需求,本框架还需要满足一系列非功能性要求,以确保系统的性能、可扩展性和用户体验。
|
||||
|
||||
\subsubsection{性能要求}
|
||||
性能是衡量本框架可用性的重要指标。
|
||||
首先,训练语料的生成效率需要足够高,以便快速响应技术文档的更新。
|
||||
其次,模型微调过程应尽可能高效,缩短训练周期,尽管大语言模型的训练对计算资源要求较高,但通过选择合适的模型和优化方法(如 QLoRA),以及利用高性能计算资源,需努力克服显存不足和运算速度缓慢的问题。
|
||||
最后,前端界面的响应速度要快,用户操作流畅,保证良好的用户体验。
|
||||
|
||||
\subsubsection{扩展性要求}
|
||||
考虑到未来可能需要支持更多类型的技术文档格式、集成不同的编码大模型或增加新的功能,本框架需要具备良好的扩展性。
|
||||
模块化的设计思路将有助于在不影响现有功能的基础上,方便地进行功能扩展和技术升级。
|
||||
此外,自动化整合脚本应具备灵活的配置能力,方便用户根据自身需求调整参数和集成新的模块。
|
||||
对不同企业的私有库差异的适应性也是扩展性的重要体现,要求框架具备高度的灵活性和可定制性 。
|
212
paper/chapters/technology.tex
Normal file
@ -0,0 +1,212 @@
|
||||
% 第二章:相关技术介绍
|
||||
|
||||
\section{相关技术介绍}
|
||||
|
||||
\counterwithin{table}{section}
|
||||
\subsection{大语言模型 (LLM)}
|
||||
|
||||
\subsubsection{LLM 的起源、发展历程和关键里程碑事件}
|
||||
|
||||
大语言模型(LLM)是一种能够理解、生成并与人类语言交互的人工智能技术。
|
||||
这些模型通过在海量数据集上训练,能够应对科学、技术、艺术和文化等广泛领域的问题,成为信息检索、内容创作和自然语言理解的关键工具。
|
||||
LLM 主要基于 Transformer 架构,通过处理大规模文本数据来捕捉语言的复杂模式、语法规则和语义关系。\cite{wang2024}
|
||||
|
||||
自然语言处理(NLP)的发展为 LLM 奠定了基础。
|
||||
1966 年,约瑟夫·魏泽鲍姆创建的 ELIZA 被认为是第一个使用 NLP 的程序,它能根据用户输入的关键词给出预设响应。
|
||||
随着计算机性能的提升,特别是在 20 世纪 90 年代,NLP 技术得到了显著发展。
|
||||
|
||||
词语表示方式的演进是 LLM 发展的关键环节。
|
||||
传统机器学习方法使用数字表格表示词语,难以捕捉词语间关系。
|
||||
词嵌入技术通过神经网络模型训练解决了这一问题,使模型能够根据上下文理解词语含义。
|
||||
现代 LLM 采用自注意力机制,能够聚焦输入中的相关部分,评估每个词语在上下文中的重要性,从而提升理解和生成能力。
|
||||
|
||||
2017 年 Transformer 模型的引入是 LLM 发展的重要转折点。
|
||||
基于自注意力机制的 Transformer 架构能有效处理长序列数据并实现并行计算,为大规模语言模型的训练铺平了道路。
|
||||
2018 年,谷歌发布了 BERT,在自然语言理解任务中取得显著进展;同年,OpenAI 发布了 GPT-1,展示了生成连贯文本的能力。
|
||||
|
||||
随后几年,LLM 的规模和能力持续增长。
|
||||
2019 年的 GPT-2 生成更具说服力的文本;2020 年拥有 1750 亿参数的 GPT-3 达到了前所未有的语言理解和生成水平;2022 年 ChatGPT 的发布引起公众广泛关注;2023 年 GPT-4 在准确性方面有所提升,并具备了多模态能力。
|
||||
除 OpenAI 外,谷歌的 BERT、PaLM 和 Gemini 系列,Meta 的 Llama 系列,以及 Anthropic 的 Claude 系列等成为主流 LLM。开源模型如 BLOOM 和 LLaMA 的出现进一步推动了该领域发展。
|
||||
|
||||
LLM 发展的显著趋势是参数规模的扩大和能力的演变。
|
||||
参数规模从最初的几百万发展到目前的数千亿甚至万亿,训练数据也从数十亿词语增长到数万亿 tokens,这一趋势被称为"新的摩尔定律"。
|
||||
随着模型发展,LLM 从最初的文本生成和补全,逐渐展现出复杂推理、解决数学问题、翻译语言和编写代码等高级能力。
|
||||
近年来,多模态 LLM 的出现扩展了应用范围,使其能够处理和生成文本、图像、音频等多种类型数据。 \textbf{表 \ref{tab:llm_milestones}} 总结了 LLM 发展中的关键里程碑事件。
|
||||
|
||||
|
||||
\input{./figures/llm_milestones.tex}
|
||||
|
||||
\subsubsection{LLM 辅助编码}
|
||||
大语言模型在辅助软件开发和编码方面展现出巨大的潜力。它们通过理解和生成代码,可以显著提高开发效率并改善代码质量。
|
||||
LLM 在代码生成、代码补全、错误检测与修复等多个方面都有具体的应用。
|
||||
|
||||
在代码生成方面,LLM 能够根据自然语言描述生成代码片段甚至完整的函数。
|
||||
开发者可以使用自然语言描述所需功能,LLM 即可生成相应的代码,从而加速开发过程。
|
||||
例如,开发者只需描述一个排序算法,LLM 就能生成相应的代码实现。
|
||||
一些知名的 LLM,如 OpenAI 的 Codex、Meta 的 Code Llama 和 Google 的 PaLM 2,都经过专门优化用于代码生成。
|
||||
此外,LLM 还可以生成代码文档和注释,提高代码可读性和可维护性。
|
||||
|
||||
代码补全是 LLM 在编码辅助方面的另一重要应用。
|
||||
LLM 能够根据已有代码上下文,预测并建议接下来可能需要的代码片段或整行代码,甚至生成完整的函数或类。
|
||||
GitHub Copilot、Tabnine 和 Replit Ghostwriter 等工具通过集成到集成开发环境(IDE)中,为开发者提供实时的代码建议,显著提高了编码效率。
|
||||
LLM 能够理解多种编程语言的语法、编码风格和编程实践,从而提供更智能、更准确的补全建议。
|
||||
|
||||
在错误检测与修复方面,LLM 也展现出强大能力。
|
||||
LLM 可以分析代码,识别潜在的错误模式或问题,帮助开发者快速找到并修复 bug。
|
||||
它们能够理解代码的语法、编码风格和编程实践,从而识别出代码中的错误和潜在漏洞。一些研究表明,LLM 甚至能够根据错误信息和上下文生成修复代码的建议。
|
||||
然而,需要注意的是,LLM 生成的代码可能并非总是完美,仍需要开发者进行审查和验证。
|
||||
|
||||
综上所述,大语言模型正日益成为软件开发过程中不可或缺的辅助工具。
|
||||
它们通过代码生成、代码补全和错误检测与修复等功能,极大地提升了开发效率和代码质量。
|
||||
随着 LLM 技术的不断发展,其在软件开发领域的应用前景将更加广阔。\cite{Zhang2023Survey}
|
||||
|
||||
\subsection{提示工程技术}
|
||||
提示工程(Prompt Engineering)是设计和优化输入提示(prompts)的系统方法,旨在精确引导大语言模型(LLMs)生成符合预期的输出。\cite{Chen2023}
|
||||
随着生成式人工智能技术的发展,提示工程已成为充分发挥模型能力的关键环节。通过精心构建提示的格式、结构、语言和上下文,提示工程能够显著提升模型理解用户意图的准确性,并引导其生成更加精确、相关且高质量的回应。
|
||||
专业的提示工程师通过设计最优化的输入指令,使其与生成式 AI 系统的内部机制高效协同,从而获取更为精准和有用的输出结果。
|
||||
|
||||
提示工程的重要性主要体现在三个方面:首先,它能够显著提升模型性能,使 LLM 更准确地把握用户意图并生成高质量回复;
|
||||
其次,通过提供结构化指令和丰富上下文,提示工程能够引导模型避开其训练数据中潜在的偏见和局限性;
|
||||
最后,精心设计的提示能够优化用户与 AI 系统的交互体验,提高沟通效率和满意度。
|
||||
在实际应用中,提示工程已成为连接用户需求与 AI 能力的关键桥梁,对于充分发挥大语言模型的潜力至关重要。
|
||||
|
||||
提示工程实践涉及多项核心原则和技术策略。
|
||||
首先,清晰性和精确性是基本原则,即提示应当明确界定任务边界、避免模糊表述,并提供充分的背景信息和具体的输出要求(包括格式、长度、风格和语气等)。
|
||||
其次,上下文管理作为关键技术,通过提供相关事实、参考资料和关键概念定义,可以显著提高模型输出的相关性和准确性。
|
||||
同时,少样本学习(few-shot learning)技术通过在提示中嵌入示例性的输入-输出对,为模型提供直观的任务示范,有效引导其生成符合预期的回应。
|
||||
此外,迭代优化作为核心方法论,通过系统性地测试不同表述方式和结构,并基于模型反馈持续调整,可以逐步提升提示效果。
|
||||
更进一步,任务分解策略将复杂问题拆分为一系列相互关联的子任务,通过连贯的提示序列引导模型逐步解决问题,有效提升处理复杂任务的能力。
|
||||
最后,角色定义技术通过为模型赋予特定身份或专业背景,能够引导其从特定视角生成更加专业和一致的回应。
|
||||
|
||||
\subsection{模型量化技术}
|
||||
模型量化(Model Quantization)是大语言模型(LLMs)中使用的一种技术,旨在将高精度数据(通常是 32 位浮点数 (FP32) 或 16 位浮点数 (FP16))的权重和激活值转换为低精度数据类型,如 8 位整数 (INT8) 或 4 位整数 (INT4)。
|
||||
模型量化的主要目的是减少模型的内存占用、提高推理速度并降低能耗,使其更易于部署在资源受限的设备上,如移动设备或边缘服务器。\cite{Lin2024Awq}
|
||||
|
||||
该技术的数学本质是通过线性映射将浮点值域$[r_{\text{min}}, r_{\text{max}}]$映射到整数空间,其量化与反量化过程可表示为:
|
||||
\begin{equation*}
|
||||
\begin{aligned}
|
||||
q &= \text{round}\left(\frac{r - r_{\text{min}}}{s}\right) \\
|
||||
\hat{r} &= s \cdot q + r_{\text{min}}
|
||||
\end{aligned}
|
||||
\end{equation*}
|
||||
其中$s=(r_{\text{max}}-r_{\text{min}})/(2^n-1)$为量化步长,$n$为量化位数,$\epsilon=r-\hat{r}$为量化误差。
|
||||
|
||||
这种转换显著降低了存储模型所需的内存空间,并且由于低精度运算通常比高精度运算更快,因此可以提高模型的推理速度。
|
||||
此外,更快的计算和减少的内存访问通常会降低功耗,这对于电池供电的设备尤其重要,这些极端量化形式显著减小模型尺寸和计算复杂度的同时,精度下降也更明显。
|
||||
|
||||
不同的量化级别(如 INT8 和 INT4)对模型性能和资源消耗影响不同。
|
||||
一般来说,更低的量化级别(例如从 INT8 到 INT4)可以进一步减少模型大小并提高推理速度,但通常会导致更大精度损失。
|
||||
量化误差分析表明,当权重服从均匀分布时,误差方差$\text{Var}(\epsilon)\approx s_W^2/12$,与量化步长平方成正比。
|
||||
内存优化效果可通过压缩比$(32-n)/32$量化,例如:INT8量化可实现75\%的内存压缩率,而更激进的INT4量化则能达到87.5\%的内存压缩率。
|
||||
|
||||
INT8 量化通常被认为是性能和精度之间的良好折衷方案,可在保持较高模型准确性的同时,显著降低内存占用和提高推理速度。
|
||||
INT4 量化更为激进,可实现更高压缩率和更快速度,但通常伴随更明显精度下降,更适用于对资源限制非常严格但对精度要求相对较低的场景。
|
||||
选择合适的量化技术和级别需要在模型大小、推理速度和精度之间进行权衡,通常取决于具体应用场景和硬件条件。
|
||||
|
||||
|
||||
\subsection{监督式微调概述}
|
||||
随着大语言模型在自然语言处理领域展现出强大的通用能力,如何有效地将这些模型适配到特定的应用场景或下游任务中,成为了研究与实践的关键环节。
|
||||
监督式微调(Supervised Fine-Tuning,SFT)正是实现这一目标的核心技术之一\cite{Dong2023}。
|
||||
它指的是在一个已经经过大规模无标注数据预训练的基础语言模型上,利用一套有标注的、高质量的特定任务数据(通常表现为``指令-响应''或``输入-输出''对的形式)进行进一步训练的过程。
|
||||
|
||||
SFT的``监督''特性体现在其训练数据的形式上。
|
||||
与预训练阶段模型从海量文本中自主学习语言模式不同,SFT阶段向模型明确展示了在给定输入(如用户提问、指令)下,期望的、正确的输出(如恰当的回答、符合要求的文本)。
|
||||
模型在学习过程中,通过优化目标函数,不断调整自身参数,力求使其生成的响应尽可能地逼近标注数据中的目标响应。
|
||||
这种有指导的学习方式使得模型能够更精准地理解特定任务的格式要求、知识范畴以及交互模式。
|
||||
|
||||
采用SFT的主要目的在于提升模型在特定领域的性能表现和任务遵循能力。
|
||||
预训练模型虽然知识广博,但在特定专业领域或具体任务上的表现可能不够精确或不符合特定规范。
|
||||
通过在相关的高质量标注数据上进行微调,可以有效地向模型注入领域知识,提升其回答的准确性和相关性。
|
||||
同时,SFT也是引导模型学习遵循特定指令、模仿某种对话风格或角色的重要手段,使其行为更加符合人类预期,从而更好地服务于实际应用需求。
|
||||
因此,SFT是连接通用预训练模型与特定应用场景的关键桥梁,是使大模型``落地''不可或缺的技术步骤。
|
||||
在本研究中,我们采用SFT技术来定制化训练语言模型,以满足特定交互任务的需求。
|
||||
|
||||
\subsection{QLoRA 微调方法}
|
||||
QLoRA是一种高效微调大语言模型(LLMs)的方法,它结合了量化和低秩自适应技术,旨在在资源受限的情况下实现与全精度微调相当的性能\cite{Dettmers2024Qlora}。
|
||||
QLoRA 的主要原理是在微调过程中冻结预训练 LLM 的权重并将其量化为 4 位精度,然后引入少量可训练的低秩适配器(Low-Rank Adapters, LoRA)。
|
||||
微调期间,梯度通过冻结的 4 位量化预训练语言模型反向传播到这些低秩适配器中。
|
||||
|
||||
QLoRA 引入了多项创新技术以在节省内存的同时不牺牲性能。
|
||||
首先是 4 位 NormalFloat (NF4) 数据类型,这是一种专为正态分布权重设计的新数据类型,在信息论上最优,并在实证研究中优于 4 位整数和 4 位浮点数。
|
||||
NF4 基于分位数化,确保每个量化区间分配相等数量的输入张量值,从而有效利用内存。
|
||||
其次是双重量化 (Double Quantization),通过对第一步量化的量化常数再次进行量化,进一步减少平均内存占用,在不显著损失性能的情况下节省额外内存。
|
||||
最后是分页优化器 (Paged Optimizers),利用 NVIDIA 统一内存特性管理训练期间的内存峰值,特别是在处理长序列的小批量数据时,从而避免内存不足错误。
|
||||
|
||||
|
||||
QLoRA 的主要优势在于其能够在资源受限情况下实现高效微调。
|
||||
通过将预训练模型量化到 4 位并仅训练少量低秩适配器,QLoRA 显著降低了微调所需的 GPU 内存。
|
||||
例如,QLoRA 能够将微调一个 650 亿参数模型的平均 GPU 内存需求从超过 780GB 降低到低于 48GB,且不降低运行时或预测性能。
|
||||
这使得在单个消费级 GPU 上微调大型模型成为可能,从而大大降低了微调 LLM 的门槛。
|
||||
研究表明,使用 QLoRA 微调的 LLM 在各种任务上可以达到与全精度微调相当甚至更好的性能。
|
||||
|
||||
低秩适配器(LoRA)是 QLoRA 的关键组成部分。
|
||||
LoRA 的核心思想是,大型预训练模型在适应下游任务时,其权重变化具有低秩特性\cite{Hu2021Lora}。
|
||||
因此,LoRA 冻结原始预训练模型的权重,并在每个 Transformer 层的自注意力模块中注入两个小的低秩矩阵(A 和 B)。
|
||||
LoRA 的数学原理可以表示为:
|
||||
\begin{equation*}
|
||||
\begin{gathered}
|
||||
W = W_0 + \Delta W = W_0 + BA
|
||||
\end{gathered}
|
||||
\end{equation*}
|
||||
其中,$W_0 \in \mathbb{R}^{d \times k}$ 是预训练模型中的原始权重矩阵,$\Delta W = BA$ 是低秩更新,$B \in \mathbb{R}^{d \times r}$,$A \in \mathbb{R}^{r \times k}$,且秩 $r \ll \min(d, k)$。
|
||||
通过这种方式,原本需要训练 $d \times k$ 个参数,现在只需要训练 $r \times (d + k)$ 个参数。
|
||||
例如,当 $d = k = 1000$ 且 $r = 8$ 时,可训练参数数量从 $10^6$ 减少到约 $1.6 \times 10^4$,减少了约 98.4\%。
|
||||
|
||||
微调过程中仅更新这些低秩矩阵的参数,原始模型的权重保持不变。
|
||||
这大大减少了需要训练的参数数量,从而降低了计算成本和内存需求。
|
||||
LoRA 的线性设计也确保与完全微调的模型相比,不会引入额外推理延迟。
|
||||
在推理阶段,可以将低秩更新与原始权重合并:$W = W_0 + BA$,从而不增加模型的推理延迟。
|
||||
QLoRA 通过结合量化和 LoRA,为在资源受限环境中高效微调大语言模型提供了有力方法,使得研究人员和从业人员能够更容易利用和定制最先进的 LLM,推动 NLP 领域进一步发展。
|
||||
|
||||
\subsection{Unsloth 算子}
|
||||
|
||||
Unsloth 是一个开源软件,旨在简化大语言模型(LLMs)的微调过程\cite{Han2024Unsloth}
|
||||
它提供用户友好的界面和强大功能,可帮助 AI 开发初学者和专家轻松管理资源、优化性能并集成各种 AI 工具以提高模型准确性。
|
||||
Unsloth 的核心优势在于其能够显著加速 LLM 的训练和推理过程,同时降低内存使用。
|
||||
|
||||
Unsloth 通过手动推导所有计算密集型数学步骤并编写自定义 GPU 内核来实现加速,而无需更改硬件。
|
||||
它支持 NVIDIA 自 2018 年以来的各种 GPU,包括 Tesla T4 到 H100,并可移植到 AMD 和英特尔 GPU。
|
||||
Unsloth 还与 Hugging Face 的 Transformers 库无缝集成,并支持其 TRL、Trainer 和 Seq2SeqTrainer 类。
|
||||
|
||||
Unsloth 的主要特点和优势包括:
|
||||
\begin{itemize}
|
||||
\item 加速微调:能够将 LLM 的微调速度提高 2 倍以上,某些情况下甚至高达 30 倍。
|
||||
\item 降低内存使用:显著减少微调过程中 VRAM 消耗,通常可减少高达 70\%-90\% 的内存使用,使得在有限 GPU 资源的机器上微调更大模型成为可能。
|
||||
\item 零精度损失:声称在使用 QLoRA(4 位)和 LoRA(16 位)进行微调时不会造成精度下降。
|
||||
\item 广泛模型支持:支持各种流行 LLM 架构,包括 Llama(版本 1、2 和 3)、Mistral、Gemma 和 Phi-3。
|
||||
\item 多种训练算法支持:除标准微调外,还支持强化学习技术,如 直接偏好优化(DPO)、群体相对策略优化(GRPO)、近端策略优化(PPO)。
|
||||
\item 动态 4 位量化:引入动态 4 位量化方法,旨在提高准确性,同时仅比标准 BnB 4 位量化多使用不到 10\% 的 VRAM。
|
||||
\item 优化内核:所有核心操作均使用 OpenAI 的 Triton 语言编写,并具有手动反向传播引擎,提高了性能。
|
||||
\item 易于使用和集成:提供易于使用的 Jupyter Notebook ,用户可快速开始微调模型,并与 Hugging Face 生态系统无缝集成。
|
||||
\item 更快推理:对其支持的所有模型实现 2 倍更快的推理速度。
|
||||
\item 模型导出:微调后模型可轻松导出为 GGUF、Ollama、vLLM 和 Hugging Face 等各种格式。
|
||||
\item Windows 支持:可在 Linux 和 Windows 上运行。
|
||||
\end{itemize}
|
||||
Unsloth 通过优化内存使用和速度,使得在资源有限环境中进行 LLM 微调和推理变得更容易、更高效,推动了 LLM 技术的更广泛应用。
|
||||
|
||||
\subsection{Gradio 框架}
|
||||
|
||||
Gradio 是一个开源 Python 包,允许用户快速构建机器学习模型、API 或任何任意 Python 函数的演示或 Web 应用程序\cite{Zhang2024Gradio}。
|
||||
用户可以使用 Gradio 的内置共享功能在几秒钟内分享其演示或 Web 应用程序的链接。
|
||||
该框架无需 JavaScript、CSS 或 Web 托管经验,使其成为机器学习从业者和研究人员的理想工具。
|
||||
|
||||
Gradio 提供了一系列核心特性,使其在机器学习应用开发中脱颖而出。
|
||||
通过简单的 pip 安装,用户只需在项目中添加几行代码即可创建功能完备的 Gradio 界面。
|
||||
Gradio 能够与任何 Python 库无缝集成,只要用户能够编写 Python 函数,就能利用 Gradio 构建交互式应用。
|
||||
在展示和共享方面,Gradio 界面可以轻松嵌入 Python 笔记本或作为独立网页呈现,并能自动生成公共链接,方便用户与同事共享,使他们能够远程与用户计算机上的模型进行交互。
|
||||
此外,创建的界面还可以永久托管在 Hugging Face Spaces 上,Hugging Face 将在其服务器上托管界面并提供持久的共享链接。
|
||||
|
||||
Gradio 框架提供了多种构建组件,适应不同的应用需求。
|
||||
\texttt{gr.Interface} 是一个高级类,专为创建接受一个或多个输入并返回一个或多个输出的机器学习模型演示而设计。
|
||||
它接受三个核心参数:\texttt{fn}(要包装用户界面的函数,通常是机器学习模型的预测函数)、\texttt{inputs}(用于输入的 Gradio 组件,数量应与函数参数数量匹配)和 \texttt{outputs}(用于输出的 Gradio 组件,数量应与函数返回值数量匹配)。
|
||||
对于需要更高度定制化的布局和数据流,Gradio 提供了 \texttt{gr.Blocks} 类作为低级方法。
|
||||
Blocks 支持精确控制组件显示位置、处理多个数据流和更复杂的交互(例如将输出作为其他函数的输入),以及根据用户交互动态更新组件属性或可见性。
|
||||
此外,Gradio 还包含 \texttt{gr.ChatInterface} 高级类,专门用于创建聊天机器人用户界面,用户只需提供处理函数,Gradio 就会自动创建功能齐全的聊天机器人界面。
|
||||
|
||||
Gradio 不仅是一个用户界面库,更是一个通过 UI 和 API 与机器学习模型交互的完整框架,在性能、安全性和响应能力方面提供强大保证。
|
||||
它包含完整的 Python 和 JavaScript 库生态系统,支持以编程方式在这两种语言中构建或查询机器学习应用。
|
||||
Gradio Sketch 功能允许用户无需编写代码即可构建 Gradio 应用,只需在终端中键入 gradio sketch 即可打开可视化编辑器,用户可以通过 Web 界面定义和修改组件、调整布局、添加事件。
|
||||
在流式输出方面,Gradio 通过使用 \texttt{yield} 语句的简单 Python 生成器提供流式传输功能,支持令牌到令牌的文本生成流式传输、逐步图像生成更新,甚至通过 HTTP Live Streaming (HLS) 协议实现流畅的音频和视频流式传输。
|
||||
|
||||
Gradio 框架的主要优势在于其易用性、灵活性以及与机器学习生态系统的强大集成,使其成为构建和共享机器学习模型演示的理想选择。
|
||||
通过简化从模型到用户界面的过程,Gradio 使研究人员和开发者能够更专注于模型本身的开发和优化,同时提供直观、交互式的方式向他人展示其工作成果。
|
210
paper/chapters/verification.tex
Normal file
@ -0,0 +1,210 @@
|
||||
% 第五章:结果验证
|
||||
|
||||
\section{结果验证}
|
||||
本章将对基于文档驱动的自适应编码大模型微调框架的实验结果进行验证和分析,包括实验环境介绍、实验对象介绍、微调过程中资源占用和指标的变化以及微调效果验证等方面,以全面评估本框架的实际效果和性能表现。
|
||||
|
||||
\subsection{实验环境介绍}
|
||||
本实验在以下硬件和软件环境下进行:
|
||||
|
||||
\subsubsection{硬件环境}
|
||||
实验采用的主要硬件配置如下:
|
||||
\begin{itemize}
|
||||
\item 笔记本型号:Lenovo Legion R7000P 2021H
|
||||
\item CPU:AMD Ryzen 7 5800H
|
||||
\item GPU:NVIDIA GeForce RTX 3060 Laptop GPU(6GB显存)
|
||||
\item 内存:16GB DDR4
|
||||
\item 存储:2TB NVMe SSD
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection{软件环境}
|
||||
实验的软件环境配置如下:
|
||||
\begin{itemize}
|
||||
\item 操作系统:Ubuntu 22.04 LTS(通过Windows Subsystem for Linux 2运行)
|
||||
\item Python版本:3.10.12
|
||||
\item 深度学习框架:PyTorch 2.1.0+cu121
|
||||
\item 主要依赖库:
|
||||
\begin{itemize}
|
||||
\item unsloth 2025.3.19(用于优化LoRA微调)
|
||||
\item transformers 4.36.2(用于模型加载和处理)
|
||||
\item gradio 5.25.0+(用于构建Web界面)
|
||||
\item langchain 0.3+(用于文档处理)
|
||||
\item tinydb 4.0.0+(用于数据存储)
|
||||
\item tensorboard 2.19.0(用于训练可视化)
|
||||
\end{itemize}
|
||||
\end{itemize}
|
||||
|
||||
实验环境的选择充分考虑了资源受限条件下的优化需求。
|
||||
尽管采用RTX 3060 Laptop GPU(仅6GB显存)这一消费级显卡,本框架仍能高效完成3B参数模型的微调,体现了框架出色的资源优化能力。
|
||||
软件环境选择了最新稳定版本的深度学习工具链(如PyTorch 2.1.0+cu121等),主要基于以下考虑:(1)确保与最新硬件驱动的兼容性;(2)充分利用框架的最新优化特性;(3)提高实验的可复现性和前沿性。
|
||||
这一配置方案证明了本框架在有限计算资源下实现高效微调的可行性。
|
||||
|
||||
\subsection{实验对象介绍}
|
||||
\subsubsection{基础模型选择}
|
||||
本实验选择qwen2.5-3B作为基础模型进行微调。
|
||||
该模型是阿里云开源的新一代大语言模型,具有以下特点:\cite{Yang2024Qwen}
|
||||
\begin{itemize}
|
||||
\item 性能表现:在 MMLU、GSM8K、BBH 等权威测试中优于多数同参数级开源模型。
|
||||
\item 参数规模:3.09B参数量(非嵌入参数2.77B),在保持较高性能的同时,对计算资源要求相对较低。
|
||||
\item 上下文窗口:支持128K tokens的上下文窗口和8K tokens的生成长度,适合处理超长技术文档。
|
||||
\item 开源许可:采用Qwen Research 许可协议,允许学术研究。
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection{数据集生成}
|
||||
本实验采用DeepSeek V3作为数据集生成模型,该模型是深度求索公司开发的新一代大语言模型,具有以下特点:\cite{Liu2024Deepseek}
|
||||
\begin{itemize}
|
||||
\item 性能表现:在 Codeforces 基准测试中,DeepSeek V3 取得了 51.6 分的成绩,刷新了该领域的 SOTA 水平。在 LiveCodeBench(Pass@1 - COT)测试中得分为 40.5,在 LiveCodeBench(Pass@1)测试中成绩为 37.6,均表现出色。
|
||||
\item 上下文窗口:在理论上支持 128K tokens 的上下文长度。不过在实际应用中,部分服务商可能会出于硬件或性能优化等考虑,将其限制在 64K tokens。
|
||||
\item 开源许可:采用MIT许可协议,允许学术研究。
|
||||
\end{itemize}
|
||||
|
||||
数据集生成模型通过Deepseek ai官方API调用,具体的生成参数如下:
|
||||
\begin{itemize}
|
||||
\item temperature:1.0
|
||||
\item max\_length:4096
|
||||
\end{itemize}
|
||||
|
||||
\subsubsection{仓库文档选择}
|
||||
本实验使用unsloth官方仓库文档(\href{https://docs.unsloth.ai/}{https://docs.unsloth.ai/})进行微调。在训练前,大模型并不了解该项目,如图\ref{fig:before_train}所示。
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{pic/before_train.png}
|
||||
\caption{训练前的文档内容示例}
|
||||
\label{fig:before_train}
|
||||
\end{figure}
|
||||
|
||||
\subsubsection{微调参数配置}
|
||||
|
||||
本实验采用LoRA(Low-Rank Adaptation)技术进行参数高效微调,主要配置参数如下:
|
||||
\begin{itemize}
|
||||
\item 量化精度:4bit
|
||||
\item LoRA秩(r):64,控制低秩矩阵的维度
|
||||
\item LoRA缩放因子(alpha):16,控制LoRA更新的幅度
|
||||
\item 学习率:2e-4,采用余弦学习率调度策略
|
||||
\item 批处理大小:每设备1个样本
|
||||
\item 训练轮次:3个epoch
|
||||
\item 优化器:AdamW,权重衰减为0.01
|
||||
\item 梯度累积步数:4,用于增大有效批处理大小
|
||||
\item 混合精度训练:采用bfloat16精度
|
||||
\end{itemize}
|
||||
|
||||
这些参数配置基于预实验结果和相关研究经验确定,旨在平衡训练效率和模型性能。
|
||||
|
||||
\subsection{微调过程中资源占用和指标的变化}
|
||||
|
||||
\subsubsection{计算资源占用}
|
||||
|
||||
图\ref{fig:monitor}展示了模型训练过程中的系统资源占用情况。
|
||||
在6GB显存的RTX 3060 GPU上,QLoRA微调仅占用4.1GB显存,这种高效的资源利用率得益于QLoRA的低秩适应技术,仅需更新少量参数即可实现模型性能的显著提升,充分体现了本框架在资源受限环境下的优化能力。
|
||||
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{pic/monitor.png}
|
||||
\caption{模型训练过程中的系统资源监控}
|
||||
\label{fig:monitor}
|
||||
\end{figure}
|
||||
|
||||
\subsubsection{训练指标变化}
|
||||
图~\ref{fig:tensorboard} 展示了使用TensorBoard对训练过程中的损失、梯度、学习率等指标进行实时监控,有助于及时发现训练异常并优化模型参数配置。
|
||||
|
||||
% TensorBoard 监控截图
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{pic/tensorboard.png}
|
||||
\caption{使用TensorBoard实时监控训练指标变化}
|
||||
\label{fig:tensorboard}
|
||||
\end{figure}
|
||||
|
||||
微调过程中,主要训练指标的变化趋势如图\ref{fig:training_metrics}所示,包括损失值(图\ref{fig:loss_metrics})、梯度范数(图\ref{fig:grad_norm_metrics})和学习率(图\ref{fig:learning_rate_metrics})三个关键指标。
|
||||
从这些图表中可以观察到以下几个关键特征:
|
||||
|
||||
\pgfplotsset{compat=1.18}
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\begin{subfigure}[b]{0.8\textwidth}
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
width=\textwidth,
|
||||
height=5cm,
|
||||
xlabel={训练步数 (Step)},
|
||||
ylabel={损失值 (Loss)},
|
||||
xmin=0, xmax=7029,
|
||||
ymin=0, ymax=2.5,
|
||||
ymajorgrids=true,
|
||||
grid style=dashed,
|
||||
]
|
||||
\addplot[smooth, thick, blue] table[x=Step, y=loss, col sep=comma] {./figures/training_data.csv};
|
||||
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{损失值变化趋势}
|
||||
\label{fig:loss_metrics}
|
||||
\end{subfigure}
|
||||
|
||||
\begin{subfigure}[b]{0.8\textwidth}
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
width=\textwidth,
|
||||
height=5cm,
|
||||
xlabel={训练步数 (Step)},
|
||||
ylabel={梯度范数 (Grad Norm)},
|
||||
xmin=0, xmax=7029,
|
||||
ymin=0, ymax=1.8,
|
||||
ymajorgrids=true,
|
||||
grid style=dashed,
|
||||
]
|
||||
\addplot[smooth, thick, red] table[x=Step, y=grad_norm, col sep=comma] {./figures/training_data.csv};
|
||||
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{梯度范数变化趋势}
|
||||
\label{fig:grad_norm_metrics}
|
||||
\end{subfigure}
|
||||
|
||||
\begin{subfigure}[b]{0.8\textwidth}
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
width=\textwidth,
|
||||
height=5cm,
|
||||
xlabel={训练步数 (Step)},
|
||||
ylabel={学习率 (Learning Rate)},
|
||||
xmin=0, xmax=7029,
|
||||
ymin=0, ymax=0.0002,
|
||||
ymajorgrids=true,
|
||||
grid style=dashed,
|
||||
]
|
||||
\addplot[smooth, thick, green!60!black] table[x=Step, y=learning_rate, col sep=comma] {./figures/training_data.csv};
|
||||
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{学习率变化趋势}
|
||||
\label{fig:learning_rate_metrics}
|
||||
\end{subfigure}
|
||||
|
||||
\caption{模型微调过程中的训练指标变化趋势}
|
||||
\label{fig:training_metrics}
|
||||
\end{figure}
|
||||
|
||||
\begin{itemize}
|
||||
\item \textbf{损失函数(Loss)}:如图\ref{fig:loss_metrics}所示,训练初期,损失值从约2.4迅速下降。在约1000步时降至0.5以下,随后继续缓慢下降。在大约5000步后,损失值稳定在接近于零的水平,表明模型在训练集上已经取得了很好的性能,基本收敛。
|
||||
\item \textbf{梯度范数(Gradient Norm)}:如图\ref{fig:grad_norm_metrics}所示,训练初期,梯度范数在0.5到1.5之间波动,表明参数更新幅度较大。随着训练步数的增加,梯度范数逐渐减小,并在约4000步后稳定在0.1到0.5之间,这与损失函数的下降趋势一致,表明模型正在趋于收敛,参数更新的步伐减缓。
|
||||
\item \textbf{学习率(Learning Rate)}:如图\ref{fig:learning_rate_metrics}所示,学习率采用线性衰减策略,从初始值约0.0002(或2e-4)随着训练步数的增加而稳定地线性降低,直至训练结束时接近于零。这种策略有助于在训练初期快速探索解空间,并在后期精细调整参数以促进模型稳定收敛。
|
||||
\item \textbf{训练效率}:整个微调过程耗时约5.5小时,平均每步训练时间约3秒,展现了本框架在资源受限环境下的高效性。特别是在训练后期,尽管学习率降低,模型仍能持续优化,损失值稳步下降,表明LoRA微调方法的有效性。
|
||||
\end{itemize}
|
||||
|
||||
通过分析训练指标变化可以看出,本框架采用的LoRA微调策略在有限计算资源下实现了高效训练,损失函数的平稳下降和最终收敛表明模型成功适应了目标文档内容,为后续的效果验证奠定了基础。
|
||||
|
||||
|
||||
\subsection{微调效果验证}
|
||||
|
||||
经过微调后的模型能够对相关内容做出准确回答,图\ref{fig:after_train}展示了训练后的效果,本框架将通用大语言模型定向优化为具备企业特定代码生成能力的专用模型,在保持模型通用能力的同时,显著提升了其在特定领域的表现,为企业级软件开发的智能化与高效化提供了有力支持。
|
||||
|
||||
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{pic/after_train.png}
|
||||
\caption{模型微调后的效果}
|
||||
\label{fig:after_train}
|
||||
\end{figure}
|
8
paper/config.tex
Normal file
@ -0,0 +1,8 @@
|
||||
% 个人信息配置
|
||||
\newcommand{\studentname}{张三}
|
||||
\newcommand{\studentid}{20210001}
|
||||
\newcommand{\classname}{计科211(创)}
|
||||
\newcommand{\department}{计算机科学与网络工程学院}
|
||||
\newcommand{\major}{计算机科学与技术}
|
||||
\newcommand{\adviser}{李四教授}
|
||||
\newcommand{\thesistitle}{基于文档驱动的自适应编码大模型微调框架}
|
164
paper/example.tex
Normal file
@ -0,0 +1,164 @@
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%双并列图片示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{figure}[H]
|
||||
% \centering
|
||||
% \begin{minipage}[t]{0,40\textwidth}
|
||||
% \centering
|
||||
% \includegraphics[scale=0.5]{ccjg.pdf} %1.png是图片文件的相对路径
|
||||
% \caption{IEEE 802.11层次结构} %caption是图片的标题
|
||||
% \label{p_ccjg} %此处的label相当于一个图片的专属标志,目的是方便上下文的引用
|
||||
% \end{minipage}
|
||||
% \hfil
|
||||
% \begin{minipage}[t]{0,50\textwidth}
|
||||
% \centering
|
||||
% \includegraphics[scale=1]{AODV.pdf} %1.png是图片文件的相对路径
|
||||
% \caption{AODV示意图} %caption是图片的标题
|
||||
% \label{p_AODV} %此处的label相当于一个图片的专属标志,目的是方便上下文的引用
|
||||
% \end{minipage}
|
||||
%\end{figure}
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%表格示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]
|
||||
% \centering
|
||||
% \caption{802.11a/b/g物理层,MAC层参数}
|
||||
% \begin{tabular}{ccccc}
|
||||
% \toprule
|
||||
% & 参数 & 802.11a & 802.11b & 802.11g \\
|
||||
% \midrule
|
||||
% \multirow{4}[7]{*}{物理层} & 频带/Hz(freq\_) & $5*10^9$ & $2.4*10^9$ & $2.4*10^9$ \\
|
||||
% \cmidrule{3-5} & 通信感知范围\cite{bib13}(CSThresh\_) & $3.17291*10^9$ & $2.79*10^9$ & $2.79*10^9$ \\
|
||||
% \cmidrule{3-5} & 可通信范围\cite{bib13}(RXThresh\_) & $6.5556*10^{10}$ & $5.76*10^9$ & $5.76*10^9$ \\
|
||||
% \cmidrule{3-5} & 传输功率/W(Pt\_) & 0.281838 & 0.281838 & 0.281838 \\
|
||||
% \midrule
|
||||
% \multirow{9}[17]{*}{MAC层} & 竞争窗口最小值\cite{bib12}/s(CWMin) & 15 & 31 & 15 \\
|
||||
% \cmidrule{3-5} & 竞争窗口最大值\cite{bib12}/s(CWMax) & 1023 & 1023 & 1023 \\
|
||||
% \cmidrule{3-5} & 时隙\cite{bib11}/s(SlotTime\_) & 0.00005 & 0.00002 & 0.000009s \\
|
||||
% \cmidrule{3-5} & SIFS\cite{bib14}\cite{bib11}/s(SIFS\_) & 0.000016 & 0.00001 & 0.000016s \\
|
||||
% \cmidrule{3-5} & 前导码长度\cite{bib14}(PreambleLength) & 96 & 144 & 120 \\
|
||||
% \cmidrule{3-5} & PLCP头部长度\cite{bib14}PLCPHeaderLength\_) & 24 & 48 & 24 \\
|
||||
% \cmidrule{3-5} & PLCP数据率\cite{bib14}/bps(PLCPDataRate\_) & $6*10^6$ & $1*10^6$ & $6*10^6$ \\
|
||||
% \cmidrule{3-5} & 最高速率\cite{bib14}/bps(dataRate) & $5.4*10^7$ & $1.1*10^7$ & $5.4*10^7$ \\
|
||||
% \cmidrule{3-5} & 最低速率\cite{bib14}/bps(basicRate\_) & $6*10^6$ & $1*10^6$ & $6*10^6$ \\
|
||||
% \bottomrule
|
||||
% \end{tabular}%
|
||||
% \label{t_abgcs}%
|
||||
%\end{table}%
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%插入代码示例
|
||||
%%%%%%%%%%%%title:代码文件标题
|
||||
%%%%%%%%%%%%language:语言,C++,C,Matlab,Python
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%插入代码的时候需要知:注释中同时出现标点符号,英文,中文时会互相影响,
|
||||
%这个时候,在标点符号,英文后面都要追加空格,才能正常显示
|
||||
%\lstset{language=C++}
|
||||
%\begin{lstlisting}[title=AODV100.tr]
|
||||
%
|
||||
%\end{lstlisting}
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%对齐公式示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
|
||||
%\begin{align}
|
||||
% \label{kk}
|
||||
% k&=\dfrac{3Z_{11}^{'}}{2(1-l^2_2)^{3/2}}\\
|
||||
% \label{hh}
|
||||
% h&=\frac{1}{\pi}\left[Z_{00}-\frac{k\pi}{2}+k\arcsin(l_2)+kl_2\sqrt{1-l^2_2} \right]\\
|
||||
% \label{ll} l&=\frac{1}{2}\left[\sqrt{\frac{5Z_{40}^{'}+3Z^{'}_{20}}{8Z_{20}}}+\sqrt{\frac{5Z_{11}^{'}+Z^{'}_{11}}{6Z_{11}}}\right]\\
|
||||
% \label{pp}
|
||||
% \phi&=\arctan\left[\frac{Im[Z_{n1}]}{Re[Z_{n1}]}\right]
|
||||
%\end{align}
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%表格示例2
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]
|
||||
% \centering
|
||||
% \caption{NVIDIA$^{\textregistered}$ Jetson TK1配置一览}
|
||||
% \vspace{0.5cm}
|
||||
% \begin{tabular}{l}
|
||||
% \Xhline{1.2pt}
|
||||
% Tegra K1 SOC \\
|
||||
% NVIDIA$^{\textregistered}$ Kepler$^{\textregistered}$ GPU、192 个 CUDA 核心 \\
|
||||
% NVIDIA$^{\textregistered}$ 4-Plus-1™ 四核 ARM$^{\textregistered}$ Cortex™-A15 CPU \\
|
||||
% 2 GB x16 内存、64 位宽度 \\
|
||||
% 16 GB 4.51 eMMC 内存 \\
|
||||
% 1 个 USB 3.0 端口、A \\
|
||||
% 1 个 USB 2.0 端口、Micro AB\\
|
||||
% 1 个半迷你 PCIE 插槽\\
|
||||
% 1 个完整尺寸 SD/MMC 连接器\\
|
||||
% 1 个 RTL8111GS Realtek 千兆位以太网局域网 \\
|
||||
% 1 个 SATA 数据端口 \\
|
||||
% 1 个完整尺寸 HDMI 端口 \\
|
||||
% 1 个 RS232 串行端口 \\
|
||||
% SPI 4 兆字节引导闪存\\
|
||||
% 1 个带 Mic In 和 Line Out 的 ALC5639 Realtek 音频编解码器\\
|
||||
% 以下信号可通过扩展端口获得:DP/LVDS, Touch SPI 1x4 + 1x1 CSI-2, GPIOs, UART, HSIC, I$^2$C
|
||||
% \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
% \label{aaa}%
|
||||
%\end{table}%
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%双并列表格示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]\footnotesize
|
||||
% \centering
|
||||
%
|
||||
% \begin{minipage}[t]{0,47\textwidth}
|
||||
% \caption{上位机配置清单}
|
||||
% \vspace{0.5cm}
|
||||
% \centering
|
||||
% \begin{tabular}{cc}
|
||||
% \Xhline{1.2pt}
|
||||
% 运行环境 & ubuntu14 (基于Cortex$^{\textregistered}$-A15芯片) \\
|
||||
% 编程语言 & C/C++ \\
|
||||
% 第三方库及组件 & GTK2.0,OpenCV2.4.10 \\
|
||||
% 开发环境 & Qt Creator 与 make工程管理器 \\
|
||||
% 编译工具链 & NVIDIA$^{\textregistered}$-ARM$^{\textregistered}$编译工具链 \\
|
||||
% 程序结构 & 模块化结构 \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
%
|
||||
% \label{pzqd}%
|
||||
% \end{minipage}
|
||||
% \hfil
|
||||
% \hfil
|
||||
% \begin{minipage}[t]{0,47\textwidth}
|
||||
% \centering
|
||||
% \caption{上位机功能清单}
|
||||
% \vspace{0.5cm}
|
||||
% \begin{tabular}{cc}
|
||||
% \Xhline{1.2pt}
|
||||
% 编号 & \multicolumn{1}{c}{功能描述} \\
|
||||
% \Xhline{1.2pt}
|
||||
% 1 & \multicolumn{1}{c}{可打开/关闭摄像头} \\
|
||||
% 2 & 可通过摄像头捕获图片为目标图片 \\
|
||||
% 3 & 可从文件系统内选择图片并载入为目标图片 \\
|
||||
% 4 & 可以检测目标图片中圆形轮廓的半径和圆心 \\
|
||||
% 5 & 可以检测目标图片中平行直线的间距 \\
|
||||
% 6 & 检测算法的参数可自由调整 \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
% \label{gn}%
|
||||
% \end{minipage}
|
||||
%\end{table}%
|
34
paper/figures/llm_milestones.tex
Normal file
@ -0,0 +1,34 @@
|
||||
\begin{table}[!ht]
|
||||
\centering
|
||||
\caption{LLM 发展中的关键里程碑事件}
|
||||
\label{tab:llm_milestones}
|
||||
\begin{tabular}{>{\centering\arraybackslash}p{1.5cm}>{\centering\arraybackslash}p{2.8cm}p{10cm}}
|
||||
\toprule
|
||||
\rowcolor[gray]{0.9} \textbf{年份} & \textbf{里程碑} & \textbf{重要性} \\
|
||||
\midrule
|
||||
1966 & ELIZA & 第一个使用 NLP 的聊天机器人,基于关键词模拟对话。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2017 & Transformer 架构 & 引入自注意力机制和平行处理,使得模型更加高效和上下文感知。 \\
|
||||
\addlinespace[0.5ex]
|
||||
2018 & BERT & 第一个突破性的 LLM,在自然语言理解方面取得了显著进展。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2018 & GPT-1 & 第一个使用 Transformer 架构进行生成文本的概念验证。 \\
|
||||
\addlinespace[0.5ex]
|
||||
2019 & GPT-2 & 展示了生成令人信服的文本的能力,引发了关于潜在滥用的担忧。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2020 & GPT-3 & 参数规模显著增大(1750 亿),展示了前所未有的语言理解和生成能力,成为 ChatGPT 的基础。 \\
|
||||
\addlinespace[0.5ex]
|
||||
2022 & ChatGPT & 面向消费者的应用程序,凭借其对话能力使 LLM 引起了主流关注。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2023 & GPT-4 & 多模态模型,具有更高的准确性和推理能力。 \\
|
||||
\addlinespace[0.5ex]
|
||||
2023 & LLaMA & 流行的开源 LLM,推动了 AI 的普及。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2025 & DeepSeek-R1 & 在美国境外开发的高性能开源推理模型,凸显了 LLM 开发领域日益激烈的全球竞争。 \\
|
||||
\addlinespace[0.5ex]
|
||||
2025 & Claude 3 & GPT 模型的竞争者,强调乐于助人、诚实和无害,具有强大的推理和多模态能力。 \\
|
||||
\addlinespace[0.5ex]
|
||||
\rowcolor[gray]{0.95} 2025 & Gemini & 一系列多模态 AI 模型,旨在跨不同设备运行并涵盖广泛的用途,包括推理。 \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\end{table}
|
56
paper/figures/system_architecture.tex
Normal file
@ -0,0 +1,56 @@
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\begin{tikzpicture}[node distance=1.5cm, auto,
|
||||
block/.style={rectangle, draw, fill=white, text width=12cm, text centered, rounded corners, minimum height=1.2cm, font=\small},
|
||||
layer/.style={rectangle, draw, fill=white, text width=12cm, text centered, rounded corners, minimum height=4cm, font=\small},
|
||||
module/.style={rectangle, draw, fill=white, text width=3cm, text centered, rounded corners, minimum height=1cm, font=\small},
|
||||
arrow/.style={thick,->,>=stealth},
|
||||
title/.style={font=\bfseries\small}
|
||||
]
|
||||
|
||||
% 定义三层架构
|
||||
\node[layer, fill=blue!10, text height=5cm] (presentation) at (0,0) {};
|
||||
\node[title, above=0.5cm of presentation.north west, anchor=west] {表现层};
|
||||
|
||||
\node[layer, fill=green!10, below=1cm of presentation] (business) {};
|
||||
\node[title, above=0.5cm of business.north west, anchor=west] {业务逻辑层};
|
||||
|
||||
\node[layer, fill=orange!10, below=1cm of business] (data) {};
|
||||
\node[title, above=0.5cm of data.north west, anchor=west] {数据访问层};
|
||||
|
||||
% 表现层模块
|
||||
\node[module, fill=blue!20] (ui1) at (-3.5, 1.5) {模型管理};
|
||||
\node[module, fill=blue!20] (ui2) at (0, 1.5) {模型推理};
|
||||
\node[module, fill=blue!20] (ui3) at (3.5, 1.5) {模型微调};
|
||||
\node[module, fill=blue!20] (ui4) at (-3.5, 0) {数据集生成};
|
||||
\node[module, fill=blue!20] (ui5) at (0, 0) {数据集管理};
|
||||
\node[module, fill=blue!20] (ui6) at (3.5, 0) {提示词模板管理};
|
||||
\node[module, fill=blue!20] (ui7) at (0, -1.5) {系统设置};
|
||||
|
||||
% 业务逻辑层模块
|
||||
\node[module, fill=green!20] (bl1) at (-3.5, -5) {模型训练模块\\(Unsloth/TRL)};
|
||||
\node[module, fill=green!20] (bl2) at (0, -5) {模型推理模块\\(流式生成)};
|
||||
\node[module, fill=green!20] (bl3) at (3.5, -5) {数据集生成模块\\(LangChain)};
|
||||
\node[module, fill=green!20] (bl4) at (-2, -6.5) {数据处理模块};
|
||||
\node[module, fill=green!20] (bl5) at (2, -6.5) {配置管理模块};
|
||||
|
||||
% 数据访问层模块
|
||||
\node[module, fill=orange!20] (dl1) at (-3.5, -10.5) {SQLite\\(系统配置)};
|
||||
\node[module, fill=orange!20] (dl2) at (0, -10.5) {TinyDB\\(数据集管理)};
|
||||
\node[module, fill=orange!20] (dl3) at (3.5, -10.5) {JSON文件\\(导入/导出)};
|
||||
|
||||
% 层间连接
|
||||
\draw[arrow] (presentation.south) -- (business.north) node[midway, right] {函数调用};
|
||||
\draw[arrow] (business.south) -- (data.north) node[midway, right] {数据操作};
|
||||
|
||||
% 添加Gradio框架标注
|
||||
\node[above=0.2cm of presentation.north east, anchor=east, font=\small\itshape] {基于Gradio框架};
|
||||
|
||||
% % 添加说明
|
||||
% \node[below=0.2cm of data.south, text width=12cm, text centered, font=\footnotesize]
|
||||
% {系统采用三层架构设计,实现了功能模块间的解耦与分层,提高了系统的可扩展性和可维护性};
|
||||
|
||||
\end{tikzpicture}
|
||||
\caption{系统三层架构设计图}
|
||||
\label{fig:system-architecture}
|
||||
\end{figure}
|
1001
paper/figures/training_data.csv
Normal file
154
paper/main.tex
Normal file
@ -0,0 +1,154 @@
|
||||
% 文档类设置:12pt字号,A4纸张大小
|
||||
\documentclass[12pt,a4paper]{article}
|
||||
\input{config}
|
||||
|
||||
% 导入基础包
|
||||
\usepackage{graphicx} % 图片支持
|
||||
\usepackage{ctex} % 中文支持
|
||||
\usepackage{indentfirst} % 首行缩进
|
||||
\usepackage{subcaption} % 子图标题
|
||||
\usepackage{amsmath} % 数学公式
|
||||
\usepackage{amssymb} % 数学符号
|
||||
\usepackage{pgfplots} % 绘图
|
||||
\usepackage{fancyhdr} % 页眉页脚
|
||||
\usepackage{enumitem}
|
||||
|
||||
\setlist[itemize]{nosep}
|
||||
|
||||
% 页面样式设置
|
||||
\pagestyle{fancy}
|
||||
\usepackage{titlesec}
|
||||
\titlespacing{\section}{0pt}{0pt}{2em}
|
||||
|
||||
% 将Figure改为图
|
||||
\renewcommand{\figurename}{图}
|
||||
|
||||
% 图片标题设置
|
||||
\usepackage{caption}
|
||||
\captionsetup{labelformat=simple, labelsep=space}
|
||||
|
||||
% 图片编号格式:节号.图片序号
|
||||
\renewcommand {\thefigure} {\thesection{}.\arabic{figure}}
|
||||
|
||||
% 页眉页脚设置
|
||||
\renewcommand{\headrulewidth}{0pt} % 页眉线宽度
|
||||
\renewcommand{\footrulewidth}{0pt} % 页脚线宽度
|
||||
\lhead{} % 左页眉
|
||||
\chead{} % 中页眉
|
||||
\rhead{} % 右页眉
|
||||
\lfoot{} % 左页脚
|
||||
\cfoot{\thepage} % 中页脚(页码)
|
||||
\rfoot{} % 右页脚
|
||||
|
||||
% 表格相关包
|
||||
\usepackage{booktabs} % 三线表
|
||||
|
||||
% 节标题格式设置
|
||||
\usepackage{titlesec}
|
||||
\titleformat{\section}{\centering\zihao{3}\songti\bfseries}{\arabic{section}.}{0.5em}{}
|
||||
|
||||
% 目录标题设置
|
||||
\renewcommand\contentsname{\songti 目录}
|
||||
|
||||
% 其他功能包
|
||||
\usepackage{multirow} % 合并单元格
|
||||
\usepackage{abstract} % 摘要
|
||||
\usepackage{makecell} % 表格单元格
|
||||
\usepackage{tikz} % 绘图
|
||||
\usetikzlibrary{arrows.meta,positioning,shapes,fit,backgrounds,calc}
|
||||
|
||||
% 自定义命令:1.2pt宽度的竖线
|
||||
\def\I{\vrule width1.2pt}
|
||||
|
||||
% 自定义下划线命令
|
||||
\makeatletter
|
||||
\newcommand\dlmu[2][4cm]{\hskip1pt\underline{\hb@xt@ #1{\hss#2\hss}}\hskip3pt}
|
||||
\makeatother
|
||||
|
||||
% 形态展示相关包
|
||||
\usepackage{float}
|
||||
\usepackage{listings}
|
||||
\usepackage[table]{xcolor}
|
||||
|
||||
% 字体设置
|
||||
\usepackage{fontspec}
|
||||
\setmainfont{Times New Roman}
|
||||
|
||||
% 设置方正楷体
|
||||
\setCJKfamilyfont{fzkaiti}[
|
||||
AutoFakeBold = 1.2, % 伪粗体
|
||||
ItalicFont = 方正楷体_GBK.TTF % 斜体字体
|
||||
]{方正楷体_GBK.TTF}
|
||||
|
||||
|
||||
% 章节标题格式设置
|
||||
\usepackage{titlesec}
|
||||
\titleformat{\section}{\centering\Large\bfseries}{第\,\thesection\,章}{1em}{}
|
||||
|
||||
% 目录格式设置
|
||||
\usepackage{titletoc}
|
||||
\titlecontents{section}[0pt]{\addvspace{1.5pt}\filright\bf}
|
||||
{\contentspush{第\thecontentslabel\ 章\quad}}
|
||||
{}{\titlerule*[8pt]{.}\contentspage}
|
||||
|
||||
% 页面边距设置
|
||||
\usepackage[left=3.0cm, right=2.6cm, top=2.54cm, bottom=2.54cm]{geometry}
|
||||
% 超链接设置
|
||||
\usepackage[hidelinks,linkcolor=black,citecolor=black]{hyperref}
|
||||
|
||||
\begin{document}
|
||||
% 插入封面
|
||||
\input{chapters/cover}
|
||||
|
||||
% 设置摘要标题字号
|
||||
\renewcommand{\abstractname}{\scriptsize}
|
||||
|
||||
% 插入摘要
|
||||
\input{chapters/abstract}
|
||||
\newpage
|
||||
|
||||
% 设置目录标题
|
||||
\renewcommand{\contentsname}{\centerline{\zihao{-2}\textbf{目录}}}
|
||||
|
||||
% 生成目录
|
||||
\tableofcontents
|
||||
\newpage
|
||||
|
||||
% 正文部分,设置行距为23pt
|
||||
{
|
||||
\setlength{\baselineskip}{23pt}
|
||||
|
||||
% 插入各章节内容
|
||||
\newpage
|
||||
\input{chapters/introduction} % 引言
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/technology} % 技术章节
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/requirement} % 需求章节
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/implementation} % 实现章节
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/verification} % 验证章节
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/conclusion} % 结论
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/references} % 参考文献
|
||||
\clearpage
|
||||
|
||||
\newpage
|
||||
\input{chapters/acknowledgement} % 致谢
|
||||
\clearpage
|
||||
}
|
||||
\end{document}
|
BIN
paper/pic/after_train.png
Normal file
After Width: | Height: | Size: 56 KiB |
BIN
paper/pic/before_train.png
Normal file
After Width: | Height: | Size: 73 KiB |
BIN
paper/pic/logo.jpg
Normal file
After Width: | Height: | Size: 31 KiB |
BIN
paper/pic/monitor.png
Normal file
After Width: | Height: | Size: 61 KiB |
BIN
paper/pic/tensorboard.png
Normal file
After Width: | Height: | Size: 126 KiB |
21
report/.gitignore
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# LaTeX临时文件
|
||||
*.aux
|
||||
*.log
|
||||
*.out
|
||||
*.toc
|
||||
*.synctex.gz
|
||||
*.bbl
|
||||
*.blg
|
||||
*.dvi
|
||||
*.fdb_latexmk
|
||||
*.fls
|
||||
*.lof
|
||||
*.lot
|
||||
*.idx
|
||||
*.ilg
|
||||
*.ind
|
||||
*.nav
|
||||
*.snm
|
||||
*.vrb
|
||||
*.xdv
|
||||
*.pdf
|
7
report/chapters/acknowledgements.tex
Normal file
@ -0,0 +1,7 @@
|
||||
\begin{frame}
|
||||
\frametitle{致谢}
|
||||
\begin{center}
|
||||
\LARGE{\textbf{感谢聆听!}}\\[0.5cm]
|
||||
\large{欢迎各位老师提出宝贵意见}
|
||||
\end{center}
|
||||
\end{frame}
|
10
report/chapters/background.tex
Normal file
@ -0,0 +1,10 @@
|
||||
% 研究背景
|
||||
\section{研究背景}
|
||||
\begin{frame}
|
||||
\frametitle{研究背景}
|
||||
\begin{itemize}
|
||||
\item[\faCheckCircle] 研究领域现状:大语言模型在公开数据集与开源项目中表现优异,但在处理企业私有库时存在局限性
|
||||
\item[\faExclamationTriangle] 存在的问题:缺乏对私有库专有函数、类及其交互细节的深度理解,无法精准引用库中的类、方法或属性
|
||||
\item[\faFlask] 研究意义:通过深度解析私有库文档,对大语言模型进行微调与优化,提升生成代码的准确性和实用性
|
||||
\end{itemize}
|
||||
\end{frame}
|
26
report/chapters/conclusion.tex
Normal file
@ -0,0 +1,26 @@
|
||||
% 总结与展望
|
||||
\section{总结与展望}
|
||||
\begin{frame}
|
||||
\frametitle{总结与展望}
|
||||
\begin{itemize}
|
||||
\item[\faBook] 主要工作总结
|
||||
\begin{itemize}
|
||||
\item 提出了文档驱动的自适应编码大模型微调框架
|
||||
\item 实现了三层架构设计(表现层、业务逻辑层和数据访问层)
|
||||
\item 实现了大语言模型的微调
|
||||
\end{itemize}
|
||||
\item[\faLightbulb] 创新点总结
|
||||
\begin{itemize}
|
||||
\item 开发了基于栈结构的Markdown文档解析器
|
||||
\item 采用了QLoRA参数高效微调方法
|
||||
\item 构建了基于提示工程的数据集生成工具
|
||||
\end{itemize}
|
||||
\item[\faRoad] 未来工作展望
|
||||
\begin{itemize}
|
||||
\item 边缘计算部署
|
||||
\item 更多的格式输入支持
|
||||
\item CICD集成
|
||||
\end{itemize}
|
||||
\item[\faGithub] 软件仓库地址:\href{https://gitea.carry.fit/carry/gzhu-biyesheji}{https://gitea.carry.fit/carry/gzhu-biyesheji}
|
||||
\end{itemize}
|
||||
\end{frame}
|
6
report/chapters/design.tex
Normal file
@ -0,0 +1,6 @@
|
||||
% 系统设计
|
||||
\section{系统设计}
|
||||
\begin{frame}
|
||||
\frametitle{系统架构}
|
||||
\begin{center}\includegraphics[width=\textwidth,height=0.7\textheight,keepaspectratio]{./pic/system_architecture.png}\end{center}
|
||||
\end{frame}
|
85
report/chapters/results.tex
Normal file
@ -0,0 +1,85 @@
|
||||
% 实验结果
|
||||
\section{实验结果}
|
||||
|
||||
\begin{frame}
|
||||
\frametitle{实验条件介绍}
|
||||
\begin{itemize}
|
||||
\item[\faLaptop] 硬件配置
|
||||
\begin{itemize}
|
||||
\item 笔记本型号:Lenovo Legion R7000P 2021H
|
||||
\item CPU:AMD Ryzen 7 5800H
|
||||
\item GPU:NVIDIA GeForce RTX 3060 Laptop GPU(6GB显存)
|
||||
\item 内存:16GB DDR4
|
||||
\end{itemize}
|
||||
\item[\faUbuntu] 软件环境
|
||||
\begin{itemize}
|
||||
\item 操作系统:Ubuntu 22.04 LTS(通过WSL2运行)
|
||||
\item Python版本:3.11.6
|
||||
\item 深度学习框架:PyTorch 2.1.0+cu121
|
||||
\end{itemize}
|
||||
\item[\faRobot] 微调模型:unsloth官方qwen-2.5-3b-4bit量化模型(\href{https://huggingface.co/unsloth/Qwen2.5-3B-Instruct-bnb-4bit}{https://huggingface.co/unsloth/Qwen2.5-3B-Instruct-bnb-4bit})
|
||||
\item[\faBook] 仓库文档:unsloth官方仓库文档(\href{https://docs.unsloth.ai/}{https://docs.unsloth.ai/})
|
||||
\end{itemize}
|
||||
\end{frame}
|
||||
|
||||
\begin{frame}
|
||||
\frametitle{训练指标变化}
|
||||
\begin{figure}[h]
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
xlabel=Step,
|
||||
ylabel=Loss,
|
||||
width=\textwidth,
|
||||
height=0.8\textheight]
|
||||
\addplot[mark=none] table[x=Step,y=loss,col sep=comma] {./figures/training_data.csv};
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{训练损失变化曲线}
|
||||
\end{figure}
|
||||
\end{frame}
|
||||
|
||||
\begin{frame}
|
||||
\frametitle{训练指标变化}
|
||||
\begin{figure}[h]
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
xlabel=Step,
|
||||
ylabel=Grad Norm,
|
||||
width=\textwidth,
|
||||
height=0.8\textheight]
|
||||
\addplot[mark=none] table[x=Step,y=grad_norm,col sep=comma] {./figures/training_data.csv};
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{梯度范数变化曲线}
|
||||
\end{figure}
|
||||
\end{frame}
|
||||
|
||||
\begin{frame}
|
||||
\frametitle{训练指标变化}
|
||||
\begin{figure}[h]
|
||||
\centering
|
||||
\begin{tikzpicture}
|
||||
\begin{axis}[
|
||||
xlabel=Step,
|
||||
ylabel=Learning Rate,
|
||||
width=\textwidth,
|
||||
height=0.8\textheight]
|
||||
\addplot[mark=none] table[x=Step,y=learning_rate,col sep=comma] {./figures/training_data.csv};
|
||||
\end{axis}
|
||||
\end{tikzpicture}
|
||||
\caption{学习率变化曲线}
|
||||
\end{figure}
|
||||
\end{frame}
|
||||
|
||||
\begin{frame}
|
||||
\frametitle{微调效果验证}
|
||||
\begin{figure}[htbp]
|
||||
\centering
|
||||
\includegraphics[width=0.45\textwidth]{./pic/before_train.png}
|
||||
\hspace{0.05\textwidth}
|
||||
\includegraphics[width=0.45\textwidth]{./pic/after_train.png}
|
||||
\caption{训练前后对比}
|
||||
\end{figure}
|
||||
\end{frame}
|
24
report/chapters/technology.tex
Normal file
@ -0,0 +1,24 @@
|
||||
% 技术介绍
|
||||
\section{技术介绍}
|
||||
\begin{frame}
|
||||
\frametitle{技术介绍}
|
||||
\begin{itemize}
|
||||
\item[\faDesktop] 前端技术栈
|
||||
\begin{itemize}
|
||||
\item Gradio - Python 机器学习前端框架
|
||||
\end{itemize}
|
||||
\item[\faServer] 后端技术栈
|
||||
\begin{itemize}
|
||||
\item Python - 主要开发语言
|
||||
\item SQLite - 轻量级关系型数据库
|
||||
\item TinyDB - 轻量级非关系型数据库
|
||||
\item LangChain - 大语言模型应用框架
|
||||
\end{itemize}
|
||||
\item[\faBrain] 人工智能技术
|
||||
\begin{itemize}
|
||||
\item Qwen2.5 - 开源大语言模型
|
||||
\item Unsloth - 开源微调工具
|
||||
\item QLoRA - 低秩适应微调技术
|
||||
\end{itemize}
|
||||
\end{itemize}
|
||||
\end{frame}
|
1001
report/figures/training_data.csv
Normal file
46
report/main.tex
Normal file
@ -0,0 +1,46 @@
|
||||
\documentclass[aspectratio=43]{beamer}
|
||||
\usepackage{ctex}
|
||||
\usepackage{graphicx}
|
||||
\usepackage{amsmath}
|
||||
\usepackage{hyperref}
|
||||
\usepackage{listings}
|
||||
\usepackage{xcolor}
|
||||
\usepackage{fontawesome5}
|
||||
\usepackage{subcaption}
|
||||
\usepackage{booktabs}
|
||||
\usepackage{tikz}
|
||||
\usepackage{pgfplots}
|
||||
\pgfplotsset{compat=1.18}
|
||||
|
||||
% 设置主题
|
||||
\usetheme{Madrid}
|
||||
\usecolortheme{default}
|
||||
\setbeamertemplate{navigation symbols}{}
|
||||
|
||||
% 标题信息
|
||||
\title{基于文档驱动的自适应编码大模型微调框架}
|
||||
\subtitle{毕业设计答辩}
|
||||
\author[XXX]{姓名:XXX \\ 学号:XXX \\ 指导教师:XXX}
|
||||
\institute{广州大学}
|
||||
\date{\today}
|
||||
|
||||
\begin{document}
|
||||
|
||||
% 标题页
|
||||
\frame{\titlepage}
|
||||
|
||||
% 目录
|
||||
\begin{frame}
|
||||
\frametitle{目录}
|
||||
\tableofcontents
|
||||
\end{frame}
|
||||
|
||||
% 引入各章节内容
|
||||
\input{chapters/background}
|
||||
\input{chapters/technology}
|
||||
\input{chapters/design}
|
||||
\input{chapters/results}
|
||||
\input{chapters/conclusion}
|
||||
\input{chapters/acknowledgements}
|
||||
|
||||
\end{document}
|
BIN
report/pic/after_train.png
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
report/pic/before_train.png
Normal file
After Width: | Height: | Size: 57 KiB |
BIN
report/pic/system_architecture.png
Normal file
After Width: | Height: | Size: 113 KiB |
@ -1,2 +1,11 @@
|
||||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pydantic>=2.0.0
|
||||
gradio>=5.25.0
|
||||
langchain>=0.3
|
||||
tinydb>=4.0.0
|
||||
unsloth>=2025.3.19
|
||||
sqlmodel>=0.0.24
|
||||
jinja2>=3.1.0
|
||||
tensorboardX>=2.6.2.2
|
||||
tensorboard>=2.19.0
|
4
schema/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .dataset import *
|
||||
from .dataset_generation import *
|
||||
from .md_doc import MarkdownNode
|
||||
from .prompt import promptTempleta
|
30
schema/dataset.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class Doc(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="文档ID")
|
||||
name: str = Field(default="", description="文档名称")
|
||||
path: str = Field(default="", description="文档路径")
|
||||
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
|
||||
version: Optional[str] = Field(default="", description="文档版本")
|
||||
|
||||
class Q_A(BaseModel):
|
||||
question: str = Field(default="", min_length=1,description="问题")
|
||||
answer: str = Field(default="", min_length=1, description="答案")
|
||||
|
||||
class DatasetItem(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="数据集项ID")
|
||||
message: list[Q_A] = Field(description="数据集项内容")
|
||||
|
||||
class Dataset(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="数据集ID")
|
||||
name: str = Field(default="", description="数据集名称")
|
||||
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
|
||||
source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
|
||||
description: Optional[str] = Field(default="", description="数据集描述")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="记录创建时间"
|
||||
)
|
||||
dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
|
47
schema/dataset_generation.py
Normal file
@ -0,0 +1,47 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Relationship, Field
|
||||
|
||||
class APIProvider(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
|
||||
base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
|
||||
model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
|
||||
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
class LLMParameters(SQLModel):
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
class TokensUsage(SQLModel):
|
||||
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
|
||||
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
|
||||
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
|
||||
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
|
||||
|
||||
class LLMResponse(SQLModel):
|
||||
timestamp: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="响应的时间戳"
|
||||
)
|
||||
response_id: str = Field(..., description="响应的唯一ID")
|
||||
tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
||||
content: str = Field(default_factory=dict, description="API响应的内容")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
|
||||
|
||||
class LLMRequest(SQLModel):
|
||||
prompt: str = Field(..., description="发送给API的提示词")
|
||||
api_provider: APIProvider = Field(..., description="API提供者的信息")
|
||||
format: Optional[str] = Field(default=None, description="API响应的格式")
|
||||
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
|
||||
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
13
schema/md_doc.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
class MarkdownNode(BaseModel):
|
||||
level: int = Field(default=0, description="节点层级")
|
||||
title: str = Field(default="Root", description="节点标题")
|
||||
content: Optional[str] = Field(default=None, description="节点内容")
|
||||
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
MarkdownNode.model_rebuild()
|
20
schema/prompt.py
Normal file
@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
class promptTempleta(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="模板ID")
|
||||
name: Optional[str] = Field(default="", description="模板名称")
|
||||
description: Optional[str] = Field(default="", description="模板描述")
|
||||
content: str = Field(default="", min_length=1, description="模板内容")
|
||||
created_at: str = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
@field_validator('content')
|
||||
def validate_content(cls, value):
|
||||
if not "document_slice" in PromptTemplate.from_template(value).input_variables:
|
||||
raise ValueError("模板变量缺少 document_slice")
|
||||
return value
|
5
tools/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .parse_markdown import *
|
||||
from .document import *
|
||||
from .json_example import generate_json_example
|
||||
from .port import *
|
||||
from .reasoning import call_openai_api
|
35
tools/convert_json_to_dataset.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
from schema.dataset import Dataset, DatasetItem, Q_A
|
||||
import json
|
||||
|
||||
def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
|
||||
# 将JSON数据转换为dataset格式
|
||||
dataset_items = []
|
||||
item_id = 1 # 自增ID计数器
|
||||
for item in json_data:
|
||||
qa = Q_A(question=item["question"], answer=item["answer"])
|
||||
dataset_item_obj = DatasetItem(id=item_id, message=[qa])
|
||||
dataset_items.append(dataset_item_obj)
|
||||
item_id += 1 # ID自增
|
||||
|
||||
# 创建dataset对象
|
||||
result_dataset = Dataset(
|
||||
name="Converted Dataset",
|
||||
model_id=None,
|
||||
description="Dataset converted from JSON",
|
||||
dataset_items=dataset_items
|
||||
)
|
||||
return result_dataset
|
||||
|
||||
# 示例:从文件读取JSON并转换
|
||||
if __name__ == "__main__":
|
||||
# 假设JSON数据存储在文件中
|
||||
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
|
||||
json_data = json.load(file)
|
||||
|
||||
# 转换为dataset格式
|
||||
converted_dataset = convert_json_to_dataset(json_data)
|
||||
|
||||
# 输出结果到文件
|
||||
with open("output.json", "w", encoding="utf-8") as file:
|
||||
file.write(converted_dataset.model_dump_json(indent=4))
|
32
tools/document.py
Normal file
@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Doc
|
||||
|
||||
def scan_docs_directory(workdir: str):
|
||||
docs_dir = os.path.join(workdir, "docs")
|
||||
|
||||
doc_list = os.listdir(docs_dir)
|
||||
|
||||
to_return = []
|
||||
|
||||
for doc_name in doc_list:
|
||||
doc_path = os.path.join(docs_dir, doc_name)
|
||||
if os.path.isdir(doc_path):
|
||||
markdown_files = []
|
||||
for root, dirs, files in os.walk(doc_path):
|
||||
for file in files:
|
||||
if file.endswith(".md"):
|
||||
markdown_files.append(os.path.join(root, file))
|
||||
to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
|
||||
|
||||
return to_return
|
||||
|
||||
# 添加测试代码
|
||||
if __name__ == "__main__":
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
docs = scan_docs_directory(workdir)
|
||||
print(docs)
|
67
tools/json_example.py
Normal file
@ -0,0 +1,67 @@
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing import Any, Dict, List, Optional, Union, get_args, get_origin
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
|
||||
def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
|
||||
"""
|
||||
根据 Pydantic V2 模型生成示例 JSON 数据结构。
|
||||
"""
|
||||
def _generate_example(field_type: Any) -> Any:
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is list or origin is List:
|
||||
# 生成多个元素,这里生成 3 个
|
||||
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
|
||||
result.append("......")
|
||||
return result
|
||||
elif origin is dict or origin is Dict:
|
||||
if len(args) == 2:
|
||||
return {"key": _generate_example(args[1])}
|
||||
return {}
|
||||
elif origin is Union:
|
||||
# 处理 Optional 类型(Union[T, None])
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
return _generate_example(non_none_args[0]) if non_none_args else None
|
||||
elif field_type is str:
|
||||
return "string"
|
||||
elif field_type is int:
|
||||
return 0
|
||||
elif field_type is float:
|
||||
return 0.0
|
||||
elif field_type is bool:
|
||||
return True
|
||||
elif field_type is datetime:
|
||||
return datetime.now().isoformat()
|
||||
elif field_type is date:
|
||||
return date.today().isoformat()
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_json_example(field_type, include_optional))
|
||||
else:
|
||||
# 处理直接类型注解(非泛型)
|
||||
if field_type is type(None):
|
||||
return None
|
||||
try:
|
||||
if issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_json_example(field_type, include_optional))
|
||||
except TypeError:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
example_data = {}
|
||||
for field_name, field in model.model_fields.items():
|
||||
if include_optional or not isinstance(field.default, type(None)):
|
||||
example_data[field_name] = _generate_example(field.annotation)
|
||||
|
||||
return json.dumps(example_data, indent=2, default=str)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Dataset
|
||||
|
||||
print("示例 JSON:")
|
||||
print(generate_json_example(Dataset))
|
@ -1,28 +1,45 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
class MarkdownNode:
|
||||
def __init__(self, level=0, title="Root"):
|
||||
self.level = level
|
||||
self.title = title
|
||||
self.content = "" # 使用字符串存储合并后的内容
|
||||
self.children = []
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import MarkdownNode
|
||||
|
||||
def __repr__(self):
|
||||
return f"({self.level}) {self.title}"
|
||||
def process_markdown_file(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.append(child)
|
||||
root = parse_markdown(content)
|
||||
results = []
|
||||
|
||||
def print_tree(self, indent=0):
|
||||
def traverse(node, parent_titles):
|
||||
current_titles = parent_titles.copy()
|
||||
current_titles.append(node.title)
|
||||
|
||||
if not node.children: # 叶子节点
|
||||
if node.content:
|
||||
full_text = ' -> '.join(current_titles) + '\n' + node.content
|
||||
results.append(full_text)
|
||||
else:
|
||||
for child in node.children:
|
||||
traverse(child, current_titles)
|
||||
|
||||
traverse(root, [])
|
||||
return results
|
||||
def add_child(parent, child):
|
||||
parent.children.append(child)
|
||||
|
||||
def print_tree(node, indent=0):
|
||||
prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
|
||||
print(f"{prefix}{self.title}")
|
||||
if self.content:
|
||||
print(f"{prefix}{node.title}")
|
||||
if node.content:
|
||||
content_prefix = "│ " * indent + "├─ [内容]"
|
||||
print(content_prefix)
|
||||
for line in self.content.split('\n'):
|
||||
for line in node.content.split('\n'):
|
||||
print("│ " * indent + "│ " + line)
|
||||
for child in self.children:
|
||||
child.print_tree(indent + 1)
|
||||
for child in node.children:
|
||||
print_tree(child, indent + 1)
|
||||
|
||||
def parse_markdown(markdown):
|
||||
lines = markdown.split('\n')
|
||||
@ -51,10 +68,10 @@ def parse_markdown(markdown):
|
||||
if match:
|
||||
level = len(match.group(1))
|
||||
title = match.group(2)
|
||||
node = MarkdownNode(level, title)
|
||||
node = MarkdownNode(level=level, title=title, content="", children=[])
|
||||
while stack[-1].level >= level:
|
||||
stack.pop()
|
||||
stack[-1].add_child(node)
|
||||
add_child(stack[-1], node)
|
||||
stack.append(node)
|
||||
else:
|
||||
if stack[-1].content:
|
||||
@ -64,10 +81,13 @@ def parse_markdown(markdown):
|
||||
return root
|
||||
|
||||
if __name__=="__main__":
|
||||
# 从文件读取 Markdown 内容
|
||||
with open("example.md", "r", encoding="utf-8") as f:
|
||||
markdown = f.read()
|
||||
# # 从文件读取 Markdown 内容
|
||||
# with open("workdir/example.md", "r", encoding="utf-8") as f:
|
||||
# markdown = f.read()
|
||||
|
||||
# 解析 Markdown 并打印树结构
|
||||
root = parse_markdown(markdown)
|
||||
root.print_tree()
|
||||
# # 解析 Markdown 并打印树结构
|
||||
# root = parse_markdown(markdown)
|
||||
# print_tree(root)
|
||||
for i in process_markdown_file("workdir/example.md"):
|
||||
print("~"*20)
|
||||
print(i)
|
15
tools/port.py
Normal file
@ -0,0 +1,15 @@
|
||||
import socket
|
||||
|
||||
# 启动 TensorBoard 子进程前添加端口检测逻辑
|
||||
def find_available_port(start_port):
|
||||
port = start_port
|
||||
while True:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(('localhost', port)) != 0: # 端口未被占用
|
||||
return port
|
||||
port += 1 # 如果端口被占用,尝试下一个端口
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_port = 6006 # 起始端口号
|
||||
available_port = find_available_port(start_port)
|
||||
print(f"Available port: {available_port}")
|
123
tools/reasoning.py
Normal file
@ -0,0 +1,123 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import openai
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
|
||||
|
||||
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=llm_request.api_provider.api_key,
|
||||
base_url=llm_request.api_provider.base_url
|
||||
)
|
||||
|
||||
total_duration = 0.0
|
||||
total_tokens = TokensUsage()
|
||||
prompt = llm_request.prompt
|
||||
round_start = datetime.now(timezone.utc)
|
||||
if llm_request.format:
|
||||
prompt += "\n请以JSON格式返回结果" + llm_request.format
|
||||
|
||||
for i in range(rounds):
|
||||
round_start = datetime.now(timezone.utc)
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
create_args = {
|
||||
"model": llm_request.api_provider.model_id,
|
||||
"messages": messages,
|
||||
"temperature": llm_parameters.temperature if llm_parameters else None,
|
||||
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
|
||||
"top_p": llm_parameters.top_p if llm_parameters else None,
|
||||
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
|
||||
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
|
||||
"seed": llm_parameters.seed if llm_parameters else None
|
||||
} # 处理format参数
|
||||
|
||||
if llm_request.format:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
|
||||
response = await client.chat.completions.create(**create_args)
|
||||
|
||||
round_end = datetime.now(timezone.utc)
|
||||
duration = (round_end - round_start).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
# 处理可能不存在的缓存token字段
|
||||
usage = response.usage
|
||||
cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
|
||||
cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
|
||||
|
||||
tokens_usage = TokensUsage(
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
prompt_cache_hit_tokens=cache_hit,
|
||||
prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
|
||||
)
|
||||
|
||||
# 累加总token使用量
|
||||
total_tokens.prompt_tokens += tokens_usage.prompt_tokens
|
||||
total_tokens.completion_tokens += tokens_usage.completion_tokens
|
||||
if tokens_usage.prompt_cache_hit_tokens:
|
||||
total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
|
||||
if tokens_usage.prompt_cache_miss_tokens:
|
||||
total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
|
||||
|
||||
llm_request.response.append(LLMResponse(
|
||||
response_id=response.id,
|
||||
tokens_usage=tokens_usage,
|
||||
content = response.choices[0].message.content,
|
||||
total_duration=duration,
|
||||
llm_parameters=llm_parameters
|
||||
))
|
||||
except Exception as e:
|
||||
round_end = datetime.now(timezone.utc)
|
||||
duration = (round_end - round_start).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
llm_request.response.append(LLMResponse(
|
||||
response_id=f"error-round-{i+1}",
|
||||
content={"error": str(e)},
|
||||
total_duration=duration
|
||||
))
|
||||
if llm_request.error is None:
|
||||
llm_request.error = []
|
||||
llm_request.error.append(str(e))
|
||||
|
||||
# 更新总耗时和总token使用量
|
||||
llm_request.total_duration = total_duration
|
||||
llm_request.total_tokens_usage = total_tokens
|
||||
|
||||
return llm_request
|
||||
|
||||
if __name__ == "__main__":
|
||||
from json_example import generate_json_example
|
||||
from sqlmodel import Session, select
|
||||
from global_var import get_sql_engine, init_global_var
|
||||
from schema import DatasetItem
|
||||
|
||||
init_global_var("workdir")
|
||||
api_state = "1 deepseek-chat"
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
llm_request = LLMRequest(
|
||||
prompt="测试,随便说点什么",
|
||||
api_provider=api_provider,
|
||||
format=generate_json_example(DatasetItem)
|
||||
)
|
||||
|
||||
# # 单次调用示例
|
||||
# result = asyncio.run(call_openai_api(llm_request))
|
||||
# print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
|
||||
# for i, resp in enumerate(result.response, 1):
|
||||
# print(f"响应{i}: {resp.response_content}")
|
||||
|
||||
# 多次调用示例
|
||||
params = LLMParameters(temperature=0.7, max_tokens=100)
|
||||
result = asyncio.run(call_openai_api(llm_request, 3,params))
|
||||
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
|
||||
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
|
||||
for i, resp in enumerate(result.response, 1):
|
||||
print(f"响应{i}: {resp.content}")
|
1
train/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .model import *
|
131
train/model.py
Normal file
@ -0,0 +1,131 @@
|
||||
import os
|
||||
from datasets import Dataset as HFDataset
|
||||
from unsloth import FastLanguageModel
|
||||
from trl import SFTTrainer # 用于监督微调的训练器
|
||||
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
|
||||
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
|
||||
from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
||||
def get_model_name(model):
|
||||
return os.path.basename(model.name_or_path)
|
||||
def formatting_prompts(examples,tokenizer):
|
||||
"""格式化对话数据的函数
|
||||
Args:
|
||||
examples: 包含对话列表的字典
|
||||
Returns:
|
||||
包含格式化文本的字典
|
||||
"""
|
||||
questions = examples["question"]
|
||||
answer = examples["answer"]
|
||||
|
||||
convos = [
|
||||
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
|
||||
for q, r in zip(questions, answer)
|
||||
]
|
||||
|
||||
# 使用tokenizer.apply_chat_template格式化对话
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
|
||||
for convo in convos
|
||||
]
|
||||
|
||||
return {"text": texts}
|
||||
|
||||
|
||||
def train_model(
|
||||
model,
|
||||
tokenizer,
|
||||
dataset: list,
|
||||
train_dir: str,
|
||||
learning_rate: float,
|
||||
per_device_train_batch_size: int,
|
||||
epoch: int,
|
||||
save_steps: int,
|
||||
lora_rank: int,
|
||||
trainer_callback=None
|
||||
) -> None:
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
# 原始模型
|
||||
model,
|
||||
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
|
||||
# 建议: 8-32之间
|
||||
r=lora_rank, # 使用动态传入的LoRA秩
|
||||
# 需要应用LoRA的目标模块列表
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
|
||||
"gate_proj", "up_proj", "down_proj", # FFN相关层
|
||||
],
|
||||
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大,LoRA的更新影响越大。
|
||||
lora_alpha=16,
|
||||
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
|
||||
# 如果数据集较小,建议设置0.1左右。
|
||||
lora_dropout=0,
|
||||
# 是否对bias参数进行微调,none表示不微调bias
|
||||
# none: 不微调偏置参数;
|
||||
# all: 微调所有参数;
|
||||
# lora_only: 只微调LoRA参数。
|
||||
bias="none",
|
||||
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
|
||||
# 会略微降低训练速度,但可以显著减少显存使用
|
||||
use_gradient_checkpointing="unsloth",
|
||||
# 随机数种子,用于结果复现
|
||||
random_state=3407,
|
||||
# 是否使用rank-stabilized LoRA,这里不使用
|
||||
# 会略微降低训练速度,但可以显著减少显存使用
|
||||
use_rslora=False,
|
||||
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
|
||||
loftq_config=None,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="qwen-2.5",
|
||||
)
|
||||
|
||||
train_dataset = HFDataset.from_list(dataset)
|
||||
train_dataset = train_dataset.map(formatting_prompts,
|
||||
fn_kwargs={"tokenizer": tokenizer},
|
||||
batched=True)
|
||||
|
||||
# 初始化SFT训练器
|
||||
trainer = SFTTrainer(
|
||||
model=model, # 待训练的模型
|
||||
tokenizer=tokenizer, # 分词器
|
||||
train_dataset=train_dataset, # 训练数据集
|
||||
dataset_text_field="text", # 数据集字段的名称
|
||||
max_seq_length=model.max_seq_length, # 最大序列长度
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=1, # 数据集处理的并行进程数
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
|
||||
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
|
||||
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
|
||||
learning_rate=learning_rate, # 学习率
|
||||
lr_scheduler_type="linear", # 线性学习率调度器
|
||||
max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据)
|
||||
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
|
||||
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
|
||||
logging_steps=1, # 每1步记录一次日志
|
||||
optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果
|
||||
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
|
||||
seed=114514, # 随机数种子
|
||||
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
|
||||
save_strategy="steps", # 按步保存中间权重
|
||||
save_steps=save_steps, # 使用动态传入的保存步数
|
||||
logging_dir=train_dir + "/logs", # 日志文件存储路径
|
||||
report_to="tensorboard", # 使用TensorBoard记录日志
|
||||
),
|
||||
)
|
||||
|
||||
if trainer_callback is not None:
|
||||
trainer.add_callback(trainer_callback)
|
||||
|
||||
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|im_start|>user\n",
|
||||
response_part = "<|im_start|>assistant\n",
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|
48
train/save_model.py
Normal file
@ -0,0 +1,48 @@
|
||||
import os
|
||||
from global_var import get_model, get_tokenizer
|
||||
|
||||
def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
|
||||
"""
|
||||
保存模型到指定目录
|
||||
:param save_model_name: 要保存的模型名称
|
||||
:param models_dir: 模型保存的基础目录
|
||||
:param model: 要保存的模型
|
||||
:param tokenizer: 要保存的tokenizer
|
||||
:param save_method: 保存模式选项
|
||||
- "default": 默认保存方式
|
||||
- "merged_16bit": 合并为16位
|
||||
- "merged_4bit": 合并为4位
|
||||
- "lora": 仅LoRA适配器
|
||||
- "gguf": 保存为GGUF格式
|
||||
- "gguf_q4_k_m": 保存为q4_k_m GGUF格式
|
||||
- "gguf_f16": 保存为16位GGUF格式
|
||||
:return: 保存结果消息或错误信息
|
||||
"""
|
||||
try:
|
||||
if model is None:
|
||||
return "没有加载的模型可保存"
|
||||
|
||||
save_path = os.path.join(models_dir, save_model_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
if save_method == "default":
|
||||
model.save_pretrained(save_path)
|
||||
tokenizer.save_pretrained(save_path)
|
||||
elif save_method == "merged_16bit":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
|
||||
elif save_method == "merged_4bit":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
|
||||
elif save_method == "lora":
|
||||
model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
|
||||
elif save_method == "gguf":
|
||||
model.save_pretrained_gguf(save_path, tokenizer)
|
||||
elif save_method == "gguf_q4_k_m":
|
||||
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
|
||||
elif save_method == "gguf_f16":
|
||||
model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
|
||||
else:
|
||||
return f"不支持的保存模式: {save_method}"
|
||||
|
||||
return f"模型已保存到 {save_path} (模式: {save_method})"
|
||||
except Exception as e:
|
||||
return f"保存模型时出错: {str(e)}"
|