Compare commits
147 Commits
mvp
...
4a67c20b70
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4a67c20b70 | ||
![]() |
d210ddcca9 | ||
![]() |
97ee546bdf | ||
![]() |
2af26560b3 | ||
![]() |
0977002c06 | ||
![]() |
fa83e06346 | ||
![]() |
03221547bb | ||
![]() |
7cc26feaa9 | ||
![]() |
9494c9c913 | ||
![]() |
ff31213aa8 | ||
![]() |
aa2f75f67f | ||
![]() |
4d61a4cffd | ||
![]() |
3e520604ba | ||
![]() |
959ff2033d | ||
![]() |
d653e05a61 | ||
![]() |
6685b742ed | ||
![]() |
3718c75cee | ||
![]() |
905658073a | ||
![]() |
9806334517 | ||
![]() |
0a4efa5641 | ||
![]() |
994d600221 | ||
![]() |
d5774eee0c | ||
![]() |
87501c9353 | ||
![]() |
5fc3b4950b | ||
![]() |
c28e4819d9 | ||
![]() |
e7cf51d662 | ||
![]() |
4c9caff668 | ||
![]() |
9236f49b36 | ||
![]() |
868fcd45ba | ||
![]() |
5a21c8598a | ||
![]() |
1e829c9268 | ||
![]() |
9fc3ab904b | ||
![]() |
d827f9758f | ||
![]() |
ff1e9731bc | ||
![]() |
90fde639ff | ||
![]() |
5fc90903fb | ||
![]() |
81c2ad4a2d | ||
![]() |
314434951d | ||
![]() |
e16882953d | ||
![]() |
86bcf90c66 | ||
![]() |
961a017f19 | ||
![]() |
5a386d6401 | ||
![]() |
feaea1fb64 | ||
![]() |
7242a2ce03 | ||
![]() |
db6e2271dc | ||
![]() |
d764537143 | ||
![]() |
8c35a38c47 | ||
![]() |
7ee751c88f | ||
![]() |
b715b36a5f | ||
![]() |
8023233bb2 | ||
![]() |
2a86b3b5b0 | ||
![]() |
ca1505304e | ||
![]() |
df9260e918 | ||
![]() |
df9aba0c6e | ||
![]() |
6b87dcb58f | ||
![]() |
d0aebd17fa | ||
![]() |
d9abf08184 | ||
![]() |
a27a1ab079 | ||
![]() |
aa758e3c2a | ||
![]() |
664944f0c5 | ||
![]() |
9298438f98 | ||
![]() |
4f7926aec6 | ||
![]() |
148f4afb25 | ||
![]() |
11a3039775 | ||
![]() |
a4289815ba | ||
![]() |
088067d335 | ||
![]() |
9fb31c46c8 | ||
![]() |
4f09823123 | ||
![]() |
1a2ca3e244 | ||
![]() |
bb1d8fbd38 | ||
![]() |
4558929c52 | ||
![]() |
0722748997 | ||
![]() |
e08f0059bb | ||
![]() |
6d1fecbdac | ||
![]() |
79d3eb153c | ||
![]() |
80dae7c6e2 | ||
![]() |
2d39b91764 | ||
![]() |
5094febcb4 | ||
![]() |
eeee68dbd1 | ||
![]() |
539e14d39c | ||
![]() |
9784f2aed3 | ||
![]() |
611904cef9 | ||
![]() |
8a9a080745 | ||
![]() |
a23ad88769 | ||
![]() |
83427aaaba | ||
![]() |
61672021ef | ||
![]() |
fb6157af05 | ||
![]() |
f655936741 | ||
![]() |
ab7897351a | ||
![]() |
216bfe39ae | ||
![]() |
0fa2b51a79 | ||
![]() |
cbb3a09dd8 | ||
![]() |
2e552c186d | ||
![]() |
1b3f546669 | ||
![]() |
402bc73dce | ||
![]() |
bb5851f800 | ||
![]() |
a407fa1f76 | ||
![]() |
4b465ec917 | ||
![]() |
e7cc03297b | ||
![]() |
051d1a7535 | ||
![]() |
97172f9596 | ||
![]() |
f582820443 | ||
![]() |
8fb9f785b9 | ||
![]() |
2c8e54bb1e | ||
![]() |
932d1e2687 | ||
![]() |
202d4c44df | ||
![]() |
4d77c429bd | ||
![]() |
41447c5ed4 | ||
![]() |
84fe78243a | ||
![]() |
4d8754aad2 | ||
![]() |
541d37c674 | ||
![]() |
6a00699472 | ||
![]() |
ff8162890d | ||
![]() |
daddcd34da | ||
![]() |
5c7ced30df | ||
![]() |
9741ce6b92 | ||
![]() |
67281fe06a | ||
![]() |
2d905a0270 | ||
![]() |
374b124cf8 | ||
![]() |
74ae5e1426 | ||
![]() |
0a6ae7a4ee | ||
![]() |
faf72d1e99 | ||
![]() |
cce5e4e114 | ||
![]() |
293f63017f | ||
![]() |
2e31f4f57c | ||
![]() |
967133162e | ||
![]() |
dc28c25c65 | ||
![]() |
70b64dc3d3 | ||
![]() |
b52ca9b1af | ||
![]() |
46b4453ccd | ||
![]() |
d5b528d375 | ||
![]() |
475cd033d9 | ||
![]() |
3970a67df3 | ||
![]() |
286db405ca | ||
![]() |
d40f5b1f24 | ||
![]() |
7a77f61ee6 | ||
![]() |
841e14a093 | ||
![]() |
2ff077bb1c | ||
![]() |
513b639bce | ||
![]() |
f93f213a31 | ||
![]() |
10b4c29bda | ||
![]() |
b1e98ca913 | ||
![]() |
2d5a5277ae | ||
![]() |
519a5f3773 | ||
![]() |
1f4d491694 | ||
![]() |
8ce4f1e373 | ||
![]() |
3395b860e4 |
28
.gitignore
vendored
28
.gitignore
vendored
@@ -11,6 +11,7 @@ env/
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
.roo
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
@@ -28,3 +29,30 @@ workdir/
|
||||
|
||||
# cache
|
||||
unsloth_compiled_cache
|
||||
|
||||
# 测试和参考代码
|
||||
test.ipynb
|
||||
test.py
|
||||
refer/
|
||||
|
||||
# LaTeX临时文件
|
||||
*.aux
|
||||
*.log
|
||||
*.out
|
||||
*.toc
|
||||
*.synctex.gz
|
||||
*.bbl
|
||||
*.blg
|
||||
*.dvi
|
||||
*.fdb_latexmk
|
||||
*.fls
|
||||
*.lof
|
||||
*.lot
|
||||
*.idx
|
||||
*.ilg
|
||||
*.ind
|
||||
*.nav
|
||||
*.snm
|
||||
*.vrb
|
||||
*.xdv
|
||||
*.pdf
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 C-a-r-r-y
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
91
README.md
Normal file
91
README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# 基于文档驱动的自适应编码大模型微调框架
|
||||
|
||||
## 简介
|
||||
|
||||
### 项目概述
|
||||
本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
|
||||
|
||||
### 核心功能
|
||||
- 文档解析与语料生成
|
||||
- 大语言模型高效微调
|
||||
- 交互式训练与推理界面
|
||||
- 训练过程可视化监控
|
||||
|
||||
## 技术架构
|
||||
|
||||
### 系统架构
|
||||
```
|
||||
[前端界面] -> [模型微调] -> [数据存储]
|
||||
↑ ↑ ↑
|
||||
│ │ │
|
||||
[Gradio] [unsloth/QLoRA] [SQLite/TinyDB]
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
- **前端界面**: Gradio构建的交互式Web界面
|
||||
- **模型微调**: 基于unsloth框架的QLoRA高效微调
|
||||
- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
|
||||
- **工作流引擎**: LangChain实现文档解析与语料生成
|
||||
- **训练监控**: TensorBoard集成
|
||||
|
||||
## 功能模块
|
||||
|
||||
### 1. 模型管理
|
||||
- 支持多种格式的大语言模型加载
|
||||
- 模型信息查看与状态管理
|
||||
- 模型卸载与内存释放
|
||||
|
||||
### 2. 模型推理
|
||||
- 对话式交互界面
|
||||
- 流式响应输出
|
||||
- 历史对话管理
|
||||
|
||||
### 3. 模型微调
|
||||
- 训练参数配置(学习率、batch size等)
|
||||
- LoRA参数配置(秩、alpha等)
|
||||
- 训练过程实时监控
|
||||
- 训练中断与恢复
|
||||
|
||||
### 4. 数据集生成
|
||||
- 文档解析与清洗
|
||||
- 指令-响应对生成
|
||||
- 数据集质量评估
|
||||
- 数据集版本管理
|
||||
|
||||
## 微调技术
|
||||
|
||||
### QLoRA原理
|
||||
QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
|
||||
1. **4-bit量化**: 将预训练模型量化为4-bit表示,大幅减少显存占用
|
||||
2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
|
||||
3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
|
||||
|
||||
### 参数配置
|
||||
- **学习率**: 建议2e-5到2e-4
|
||||
- **LoRA秩**: 控制适配器复杂度(建议16-64)
|
||||
- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
|
||||
|
||||
### 训练监控
|
||||
- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
|
||||
- **日志记录**: 训练过程详细日志保存
|
||||
- **模型检查点**: 定期保存中间权重
|
||||
|
||||
## 快速开始
|
||||
|
||||
1. 安装依赖:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 启动应用:
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
3. 访问Web界面:
|
||||
```
|
||||
http://localhost:7860
|
||||
```
|
||||
|
||||
## 许可证
|
||||
MIT License
|
12
db/__init__.py
Normal file
12
db/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||
from .dataset_store import get_all_dataset, save_dataset
|
||||
|
||||
__all__ = [
|
||||
"load_sqlite_engine",
|
||||
"initialize_sqlite_db",
|
||||
"get_prompt_tinydb",
|
||||
"initialize_prompt_store",
|
||||
"get_all_dataset",
|
||||
"save_dataset"
|
||||
]
|
81
db/dataset_store.py
Normal file
81
db/dataset_store.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tinydb import TinyDB, Query
|
||||
from tinydb.storages import MemoryStorage
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset import Dataset, DatasetItem, Q_A
|
||||
|
||||
def get_all_dataset(workdir: str) -> TinyDB:
|
||||
"""
|
||||
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径
|
||||
|
||||
Returns:
|
||||
TinyDB: 包含所有数据集对象的TinyDB对象
|
||||
"""
|
||||
dataset_dir = os.path.join(workdir, "dataset")
|
||||
if not os.path.exists(dataset_dir):
|
||||
return TinyDB(storage=MemoryStorage)
|
||||
|
||||
db = TinyDB(storage=MemoryStorage)
|
||||
for filename in os.listdir(dataset_dir):
|
||||
if filename.endswith(".json"):
|
||||
filepath = os.path.join(dataset_dir, filename)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
db.insert(data)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
print(f"Error loading dataset file {filename}: {str(e)}")
|
||||
continue
|
||||
|
||||
return db
|
||||
|
||||
|
||||
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
|
||||
"""
|
||||
将TinyDB中的数据集保存为单独的json文件
|
||||
|
||||
Args:
|
||||
db (TinyDB): 包含数据集对象的TinyDB实例
|
||||
workdir (str): 工作目录路径
|
||||
name (str, optional): 要保存的数据集名称,None表示保存所有
|
||||
"""
|
||||
dataset_dir = os.path.join(workdir, "dataset")
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
datasets = db.all() if name is None else db.search(Query().name == name)
|
||||
|
||||
for dataset in datasets:
|
||||
try:
|
||||
filename = f"{dataset.get(dataset['name'])}.json"
|
||||
filepath = os.path.join(dataset_dir, filename)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(dataset, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取所有数据集
|
||||
datasets = get_all_dataset(workdir)
|
||||
# 打印结果
|
||||
print(f"Found {len(datasets)} datasets:")
|
||||
for ds in datasets.all():
|
||||
print(f"- {ds['name']} (ID: {ds['id']})")
|
||||
|
||||
# 询问要保存的数据集名称
|
||||
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
|
||||
|
||||
# 保存数据集到文件
|
||||
save_dataset(datasets, workdir, name)
|
||||
print(f"Datasets {'all' if name is None else name} saved to json files")
|
79
db/init_db.py
Normal file
79
db/init_db.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import sys
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from sqlmodel import select
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset_generation import APIProvider
|
||||
|
||||
# 全局变量,用于存储数据库引擎实例
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
def load_sqlite_engine(workdir: str) -> Engine:
|
||||
"""
|
||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
Engine: SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
global _engine
|
||||
if not _engine:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "db.sqlite")
|
||||
# 创建数据库URL
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
# 创建数据库引擎
|
||||
_engine = create_engine(db_url)
|
||||
return _engine
|
||||
|
||||
def initialize_sqlite_db(engine: Engine) -> None:
|
||||
"""
|
||||
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||||
|
||||
Args:
|
||||
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
# 创建所有表结构
|
||||
SQLModel.metadata.create_all(engine)
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# 从环境变量中获取API相关配置
|
||||
api_key = os.getenv("API_KEY")
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model_id = os.getenv("MODEL_ID")
|
||||
|
||||
# 如果所有必要的环境变量都存在,则插入初始数据
|
||||
if api_key and base_url and model_id:
|
||||
with Session(engine) as session:
|
||||
# 查询是否已存在APIProvider记录
|
||||
statement = select(APIProvider).limit(1)
|
||||
existing_provider = session.exec(statement).first()
|
||||
|
||||
# 如果不存在,则插入新的APIProvider记录
|
||||
if not existing_provider:
|
||||
provider = APIProvider(
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
api_key=api_key
|
||||
)
|
||||
session.add(provider)
|
||||
session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库引擎
|
||||
engine = load_sqlite_engine(workdir)
|
||||
# 初始化数据库
|
||||
initialize_sqlite_db(engine)
|
62
db/prompt_store.py
Normal file
62
db/prompt_store.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from tinydb import TinyDB, Query
|
||||
from tinydb.storages import JSONStorage
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.prompt import promptTempleta
|
||||
|
||||
# 全局变量,用于存储TinyDB实例
|
||||
_db_instance: Optional[TinyDB] = None
|
||||
# 自定义存储类,用于格式化JSON数据
|
||||
def get_prompt_tinydb(workdir: str) -> TinyDB:
|
||||
"""
|
||||
获取TinyDB实例。如果实例尚未创建,则创建一个新的并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
TinyDB: TinyDB数据库实例
|
||||
"""
|
||||
global _db_instance
|
||||
if not _db_instance:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "prompts.json")
|
||||
# 创建TinyDB实例
|
||||
_db_instance = TinyDB(db_path)
|
||||
return _db_instance
|
||||
|
||||
def initialize_prompt_store(db: TinyDB) -> None:
|
||||
"""
|
||||
初始化prompt模板存储
|
||||
|
||||
Args:
|
||||
db (TinyDB): TinyDB数据库实例
|
||||
"""
|
||||
# 检查数据库是否为空
|
||||
if not db.all(): # 如果数据库中没有数据
|
||||
db.insert(promptTempleta(
|
||||
id=1,
|
||||
name="default",
|
||||
description="默认提示词模板",
|
||||
content="""项目名为:{project_name}
|
||||
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
|
||||
文档节选:{document_slice}""").model_dump())
|
||||
# 如果数据库中已有数据,则跳过插入
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库实例
|
||||
db = get_prompt_tinydb(workdir)
|
||||
# 初始化prompt存储
|
||||
initialize_prompt_store(db)
|
7
frontend/__init__.py
Normal file
7
frontend/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .train_page import *
|
||||
from .chat_page import *
|
||||
from .setting_page import *
|
||||
from .model_manage_page import *
|
||||
from .dataset_manage_page import *
|
||||
from .dataset_generate_page import *
|
||||
from .prompt_manage_page import *
|
97
frontend/chat_page.py
Normal file
97
frontend/chat_page.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread # 需要导入 Thread
|
||||
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
|
||||
|
||||
# 假设 global_var.py 在父目录
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
|
||||
|
||||
def chat_page():
|
||||
with gr.Blocks() as demo:
|
||||
# 聊天框
|
||||
gr.Markdown("## 对话")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
chatbot = gr.Chatbot(type="messages", label="聊天机器人")
|
||||
msg = gr.Textbox(label="输入消息")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
# 新增超参数输入框
|
||||
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
|
||||
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
|
||||
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
|
||||
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
|
||||
clear = gr.Button("清除对话")
|
||||
|
||||
def user(user_message, history: list):
|
||||
return "", history + [{"role": "user", "content": user_message}]
|
||||
|
||||
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
|
||||
model = get_model()
|
||||
tokenizer = get_tokenizer()
|
||||
if not history:
|
||||
yield history
|
||||
return
|
||||
|
||||
if model is None or tokenizer is None:
|
||||
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
|
||||
yield history
|
||||
return
|
||||
|
||||
try:
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
history,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# 将超参数转换为数值类型
|
||||
generation_kwargs = dict(
|
||||
input_ids=inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=int(max_new_tokens),
|
||||
temperature=float(temperature),
|
||||
top_p=float(top_p),
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
do_sample=True,
|
||||
use_cache=False
|
||||
)
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
|
||||
for new_text in streamer:
|
||||
if new_text:
|
||||
history[-1]["content"] += new_text
|
||||
yield history
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
|
||||
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
|
||||
history[-1]["content"] = error_message
|
||||
else:
|
||||
history.append({"role": "assistant", "content": error_message})
|
||||
yield history
|
||||
|
||||
# 更新 .then() 调用,将超参数传递给 bot 函数
|
||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
|
||||
)
|
||||
|
||||
clear.click(lambda: [], None, chatbot, queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from model_manage_page import model_manage_page
|
||||
# 装载两个页面
|
||||
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
|
||||
demo.queue()
|
||||
demo.launch()
|
189
frontend/dataset_generate_page.py
Normal file
189
frontend/dataset_generate_page.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
import json
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from langchain.prompts import PromptTemplate
|
||||
from sqlmodel import Session, select
|
||||
from schema import Dataset, DatasetItem, Q_A
|
||||
from db.dataset_store import get_all_dataset
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
|
||||
from db import save_dataset
|
||||
from tools import call_openai_api, process_markdown_file, generate_json_example
|
||||
from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
|
||||
|
||||
def dataset_generate_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集生成")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
docs_list = [str(doc.name) for doc in get_docs()]
|
||||
initial_doc = docs_list[0] if docs_list else None
|
||||
prompts = get_prompt_store().all()
|
||||
prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
|
||||
initial_prompt = prompt_list[0] if prompt_list else None
|
||||
|
||||
# 初始化Dataframe的值
|
||||
initial_dataframe_value = []
|
||||
if initial_prompt:
|
||||
selected_prompt_id = int(initial_prompt.split(" ")[0])
|
||||
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||
# 从数据库获取API Provider列表
|
||||
with Session(get_sql_engine()) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
api_list = [f"{p.id} {p.model_id}" for p in providers]
|
||||
initial_api = api_list[0] if api_list else None
|
||||
|
||||
api_dropdown = gr.Dropdown(
|
||||
choices=api_list,
|
||||
value=initial_api,
|
||||
label="选择API",
|
||||
interactive=True
|
||||
)
|
||||
doc_dropdown = gr.Dropdown(
|
||||
choices=docs_list,
|
||||
value=initial_doc,
|
||||
label="选择文档",
|
||||
interactive=True
|
||||
)
|
||||
prompt_dropdown = gr.Dropdown(
|
||||
choices=prompt_list,
|
||||
value=initial_prompt,
|
||||
label="选择模板",
|
||||
interactive=True
|
||||
)
|
||||
rounds_input = gr.Number(
|
||||
value=1,
|
||||
label="生成轮次",
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
interactive=True
|
||||
)
|
||||
concurrency_input = gr.Number(
|
||||
value=1,
|
||||
label="并发数",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=1,
|
||||
interactive=True,
|
||||
visible=False
|
||||
)
|
||||
dataset_name_input = gr.Textbox(
|
||||
label="数据集名称",
|
||||
placeholder="输入数据集保存名称",
|
||||
interactive=True
|
||||
)
|
||||
prompt_choice = gr.State(value=initial_prompt)
|
||||
generate_button = gr.Button("生成数据集",variant="primary")
|
||||
doc_choice = gr.State(value=initial_doc)
|
||||
output_text = gr.Textbox(label="生成结果", interactive=False)
|
||||
api_choice = gr.State(value=initial_api)
|
||||
with gr.Column(scale=2):
|
||||
variables_dataframe = gr.Dataframe(
|
||||
headers=["变量名", "变量值"],
|
||||
datatype=["str", "str"],
|
||||
interactive=True,
|
||||
label="变量列表",
|
||||
value=initial_dataframe_value # 设置初始化数据
|
||||
)
|
||||
def on_doc_change(selected_doc):
|
||||
return selected_doc
|
||||
|
||||
def on_api_change(selected_api):
|
||||
return selected_api
|
||||
|
||||
def on_prompt_change(selected_prompt):
|
||||
if not selected_prompt:
|
||||
return None, []
|
||||
selected_prompt_id = int(selected_prompt.split(" ")[0])
|
||||
prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
dataframe_value = [] if input_variables is None else input_variables
|
||||
dataframe_value = [[var, ""] for var in input_variables]
|
||||
return selected_prompt, dataframe_value
|
||||
|
||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
|
||||
dataset_db = get_datasets()
|
||||
if not dataset_db.search(Query().name == dataset_name):
|
||||
raise gr.Error("数据集名称已存在")
|
||||
|
||||
doc = [i for i in get_docs() if i.name == doc_state][0]
|
||||
doc_files = doc.markdown_files
|
||||
document_slice_list = [process_markdown_file(doc) for doc in doc_files]
|
||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||
prompt = PromptTemplate.from_template(prompt["content"])
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
|
||||
variables_dict = {}
|
||||
for _, row in variables_dataframe.iterrows():
|
||||
var_name = row['变量名'].strip()
|
||||
var_value = row['变量值'].strip()
|
||||
if var_name:
|
||||
variables_dict[var_name] = var_value
|
||||
|
||||
prompt = prompt.partial(**variables_dict)
|
||||
|
||||
dataset = Dataset(
|
||||
name=dataset_name,
|
||||
model_id=[api_provider.model_id],
|
||||
source_doc=doc,
|
||||
dataset_items=[]
|
||||
)
|
||||
|
||||
total_slices = len(document_slice_list)
|
||||
for i, document_slice in enumerate(document_slice_list):
|
||||
progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
|
||||
request = LLMRequest(api_provider=api_provider,
|
||||
prompt=prompt.format(document_slice=document_slice),
|
||||
format=generate_json_example(DatasetItem))
|
||||
call_openai_api(request, rounds)
|
||||
|
||||
for resp in request.response:
|
||||
try:
|
||||
content = json.loads(resp.content)
|
||||
dataset_item = DatasetItem(
|
||||
message=[Q_A(
|
||||
question=content.get("question", ""),
|
||||
answer=content.get("answer", "")
|
||||
)]
|
||||
)
|
||||
dataset.dataset_items.append(dataset_item)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse response: {e}")
|
||||
|
||||
# 保存数据集到TinyDB
|
||||
dataset_db.insert(dataset.model_dump())
|
||||
|
||||
save_dataset(dataset_db,get_workdir(),dataset_name)
|
||||
|
||||
return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
|
||||
|
||||
doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
|
||||
prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
|
||||
api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
|
||||
|
||||
generate_button.click(
|
||||
on_generate_click,
|
||||
inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
|
||||
outputs=output_text
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from global_var import init_global_var
|
||||
init_global_var("workdir")
|
||||
demo = dataset_generate_page()
|
||||
demo.launch()
|
55
frontend/dataset_manage_page.py
Normal file
55
frontend/dataset_manage_page.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import gradio as gr
|
||||
from global_var import get_datasets
|
||||
from tinydb import Query
|
||||
|
||||
def dataset_manage_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集管理")
|
||||
with gr.Row():
|
||||
# 获取数据集列表并设置初始值
|
||||
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
|
||||
initial_dataset = datasets_list[0] if datasets_list else None
|
||||
|
||||
dataset_dropdown = gr.Dropdown(
|
||||
choices=datasets_list,
|
||||
value=initial_dataset, # 设置初始选中项
|
||||
label="选择数据集",
|
||||
allow_custom_value=True,
|
||||
interactive=True
|
||||
)
|
||||
|
||||
# 添加数据集展示组件
|
||||
qa_dataset = gr.Dataset(
|
||||
components=["text", "text"],
|
||||
label="问答数据",
|
||||
headers=["问题", "答案"],
|
||||
samples=[["示例问题", "示例答案"]],
|
||||
samples_per_page=20,
|
||||
)
|
||||
|
||||
def update_qa_display(dataset_name):
|
||||
if not dataset_name:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 从数据库获取数据集
|
||||
Dataset = Query()
|
||||
ds = get_datasets().get(Dataset.name == dataset_name)
|
||||
if not ds:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 提取所有Q_A数据
|
||||
qa_list = []
|
||||
for item in ds["dataset_items"]:
|
||||
for qa in item["message"]:
|
||||
qa_list.append([qa["question"], qa["answer"]])
|
||||
|
||||
return {"samples": qa_list, "__type__": "update"}
|
||||
|
||||
# 绑定事件,更新QA数据显示
|
||||
dataset_dropdown.change(
|
||||
update_qa_display,
|
||||
inputs=dataset_dropdown,
|
||||
outputs=qa_dataset
|
||||
)
|
||||
|
||||
return demo
|
107
frontend/model_manage_page.py
Normal file
107
frontend/model_manage_page.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import gradio as gr
|
||||
import os # 导入os模块以便扫描文件夹
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unsloth import FastLanguageModel
|
||||
import torch
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
||||
from train import get_model_name
|
||||
|
||||
def model_manage_page():
|
||||
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
|
||||
models_dir = os.path.join(workdir, "models")
|
||||
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # 扫描models文件夹下的所有子文件夹
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 模型管理")
|
||||
state_output = gr.Label(label="当前状态",value="当前未加载模型") # 将 Textbox 改为 Label
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
model_select_dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中,并设置为可选
|
||||
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
|
||||
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
|
||||
with gr.Column(scale=1):
|
||||
load_button = gr.Button("加载模型", variant="primary")
|
||||
unload_button = gr.Button("卸载模型", variant="stop")
|
||||
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
|
||||
with gr.Column(scale=1):
|
||||
save_button = gr.Button("保存模型", variant="secondary")
|
||||
|
||||
def load_model(selected_model, max_seq_length, load_in_4bit):
|
||||
try:
|
||||
# 判空操作,如果模型已加载,则先卸载
|
||||
if get_model() is not None:
|
||||
unload_model()
|
||||
|
||||
model_path = os.path.join(models_dir, selected_model)
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=model_path,
|
||||
max_seq_length=max_seq_length,
|
||||
load_in_4bit=load_in_4bit,
|
||||
)
|
||||
set_model(model)
|
||||
set_tokenizer(tokenizer)
|
||||
return f"模型 {get_model_name(model)} 已加载"
|
||||
except Exception as e:
|
||||
return f"加载模型时出错: {str(e)}"
|
||||
|
||||
load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)
|
||||
|
||||
def unload_model():
|
||||
try:
|
||||
# 将模型移动到 CPU
|
||||
model = get_model()
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
|
||||
# 清空 CUDA 缓存
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 将模型和tokenizer设置为 None
|
||||
set_model(None)
|
||||
set_tokenizer(None)
|
||||
|
||||
return "当前未加载模型"
|
||||
except Exception as e:
|
||||
return f"卸载模型时出错: {str(e)}"
|
||||
|
||||
unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
|
||||
|
||||
def save_model(save_model_name):
|
||||
try:
|
||||
global model, tokenizer
|
||||
if model is None:
|
||||
return "没有加载的模型可保存"
|
||||
|
||||
save_path = os.path.join(models_dir, save_model_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model.save_pretrained(save_path)
|
||||
tokenizer.save_pretrained(save_path)
|
||||
return f"模型已保存到 {save_path}"
|
||||
except Exception as e:
|
||||
return f"保存模型时出错: {str(e)}"
|
||||
|
||||
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
|
||||
|
||||
def refresh_model_list():
|
||||
try:
|
||||
nonlocal model_folders
|
||||
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
|
||||
return gr.Dropdown(choices=model_folders)
|
||||
except Exception as e:
|
||||
return f"刷新模型列表时出错: {str(e)}"
|
||||
|
||||
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = model_manage_page()
|
||||
demo.queue()
|
||||
demo.launch()
|
130
frontend/prompt_manage_page.py
Normal file
130
frontend/prompt_manage_page.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import gradio as gr
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_prompt_store
|
||||
from schema.prompt import promptTempleta
|
||||
def prompt_manage_page():
|
||||
def get_prompts() -> List[List[str]]:
|
||||
selected_row = None
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
prompts = db.all()
|
||||
return [
|
||||
[p["id"], p["name"], p["description"], p["content"]]
|
||||
for p in prompts
|
||||
]
|
||||
except Exception as e:
|
||||
raise gr.Error(f"获取提示词失败: {str(e)}")
|
||||
|
||||
def add_prompt(name, description, content):
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
new_prompt = promptTempleta(
|
||||
name=name if name else "",
|
||||
description=description if description else "",
|
||||
content=content if content else ""
|
||||
)
|
||||
prompt_id = db.insert(new_prompt.model_dump())
|
||||
# 更新ID
|
||||
db.update({"id": prompt_id}, doc_ids=[prompt_id])
|
||||
return get_prompts(), "", "", "" # 返回清空后的输入框值
|
||||
except Exception as e:
|
||||
raise gr.Error(f"添加失败: {str(e)}")
|
||||
|
||||
def edit_prompt():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要编辑的行")
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
db.update({
|
||||
"name": selected_row[1] if selected_row[1] else "",
|
||||
"description": selected_row[2] if selected_row[2] else "",
|
||||
"content": selected_row[3] if selected_row[3] else ""
|
||||
}, doc_ids=[selected_row[0]])
|
||||
return get_prompts()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"编辑失败: {str(e)}")
|
||||
|
||||
def delete_prompt():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要删除的行")
|
||||
try:
|
||||
db = get_prompt_store()
|
||||
db.remove(doc_ids=[selected_row[0]])
|
||||
return get_prompts()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"删除失败: {str(e)}")
|
||||
|
||||
selected_row = None # 保存当前选中行的全局变量
|
||||
|
||||
def select_record(dataFrame ,evt: gr.SelectData):
|
||||
global selected_row
|
||||
selected_row = dataFrame.iloc[evt.index[0]].tolist()
|
||||
selected_row[0] = int(selected_row[0])
|
||||
print(selected_row)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 提示词模板管理")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
name_input = gr.Textbox(label="模板名称")
|
||||
description_input = gr.Textbox(label="模板描述")
|
||||
content_input = gr.Textbox(label="模板内容", lines=10)
|
||||
add_button = gr.Button("添加新模板", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
prompt_table = gr.DataFrame(
|
||||
headers=["id", "名称", "描述", "内容"],
|
||||
datatype=["number", "str", "str", "str"],
|
||||
interactive=True,
|
||||
value=get_prompts(),
|
||||
wrap=True,
|
||||
col_count=(4, "auto")
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
refresh_button = gr.Button("刷新数据", variant="secondary")
|
||||
edit_button = gr.Button("编辑选中行", variant="primary")
|
||||
delete_button = gr.Button("删除选中行", variant="stop")
|
||||
|
||||
refresh_button.click(
|
||||
fn=get_prompts,
|
||||
outputs=[prompt_table],
|
||||
queue=False
|
||||
)
|
||||
|
||||
add_button.click(
|
||||
fn=add_prompt,
|
||||
inputs=[name_input, description_input, content_input],
|
||||
outputs=[prompt_table, name_input, description_input, content_input]
|
||||
)
|
||||
|
||||
prompt_table.select(fn=select_record,
|
||||
inputs=[prompt_table],
|
||||
outputs=[],
|
||||
show_progress="hidden")
|
||||
|
||||
edit_button.click(
|
||||
fn=edit_prompt,
|
||||
inputs=[],
|
||||
outputs=[prompt_table]
|
||||
)
|
||||
|
||||
delete_button.click(
|
||||
fn=delete_prompt,
|
||||
inputs=[],
|
||||
outputs=[prompt_table]
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = prompt_manage_page()
|
||||
demo.queue()
|
||||
demo.launch()
|
131
frontend/setting_page.py
Normal file
131
frontend/setting_page.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import gradio as gr
|
||||
from typing import List
|
||||
from sqlmodel import Session, select
|
||||
from schema import APIProvider
|
||||
from global_var import get_sql_engine
|
||||
|
||||
def setting_page():
|
||||
def get_providers() -> List[List[str]]:
|
||||
selected_row = None
|
||||
try: # 添加异常处理
|
||||
with Session(get_sql_engine()) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
return [
|
||||
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
||||
for p in providers
|
||||
]
|
||||
except Exception as e:
|
||||
raise gr.Error(f"获取数据失败: {str(e)}")
|
||||
|
||||
def add_provider(model_id, base_url, api_key):
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
new_provider = APIProvider(
|
||||
model_id=model_id if model_id else None,
|
||||
base_url=base_url if base_url else None,
|
||||
api_key=api_key if api_key else None
|
||||
)
|
||||
session.add(new_provider)
|
||||
session.commit()
|
||||
session.refresh(new_provider)
|
||||
return get_providers(), "", "", "" # 返回清空后的输入框值
|
||||
except Exception as e:
|
||||
raise gr.Error(f"添加失败: {str(e)}")
|
||||
|
||||
def edit_provider():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要编辑的行")
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
provider = session.get(APIProvider, selected_row[0])
|
||||
if not provider:
|
||||
raise gr.Error("找不到选中的记录")
|
||||
provider.model_id = selected_row[1] if selected_row[1] else None
|
||||
provider.base_url = selected_row[2] if selected_row[2] else None
|
||||
provider.api_key = selected_row[3] if selected_row[3] else None
|
||||
session.add(provider)
|
||||
session.commit()
|
||||
session.refresh(provider)
|
||||
return get_providers()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"编辑失败: {str(e)}")
|
||||
|
||||
def delete_provider():
|
||||
global selected_row
|
||||
if not selected_row:
|
||||
raise gr.Error("请先选择要删除的行")
|
||||
try:
|
||||
with Session(get_sql_engine()) as session:
|
||||
provider = session.get(APIProvider, selected_row[0])
|
||||
if not provider:
|
||||
raise gr.Error("找不到选中的记录")
|
||||
session.delete(provider)
|
||||
session.commit()
|
||||
return get_providers()
|
||||
except Exception as e:
|
||||
raise gr.Error(f"删除失败: {str(e)}")
|
||||
|
||||
selected_row = None # 保存当前选中行的全局变量
|
||||
|
||||
def select_record(dataFrame ,evt: gr.SelectData):
|
||||
global selected_row
|
||||
selected_row = dataFrame.iloc[evt.index[0]].tolist()
|
||||
selected_row[0] = int(selected_row[0])
|
||||
print(selected_row)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## API Provider 管理")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
model_id_input = gr.Textbox(label="Model ID")
|
||||
base_url_input = gr.Textbox(label="Base URL")
|
||||
api_key_input = gr.Textbox(label="API Key")
|
||||
add_button = gr.Button("添加新API", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
provider_table = gr.DataFrame(
|
||||
headers=["id", "model id", "base URL", "API Key"],
|
||||
datatype=["number", "str", "str", "str"],
|
||||
interactive=True,
|
||||
value=get_providers(),
|
||||
wrap=True,
|
||||
col_count=(4, "auto")
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
refresh_button = gr.Button("刷新数据", variant="secondary")
|
||||
edit_button = gr.Button("编辑选中行", variant="primary")
|
||||
delete_button = gr.Button("删除选中行", variant="stop")
|
||||
|
||||
refresh_button.click(
|
||||
fn=get_providers,
|
||||
outputs=[provider_table],
|
||||
queue=False # 立即刷新不需要排队
|
||||
)
|
||||
|
||||
add_button.click(
|
||||
fn=add_provider,
|
||||
inputs=[model_id_input, base_url_input, api_key_input],
|
||||
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
|
||||
)
|
||||
|
||||
provider_table.select(fn=select_record,
|
||||
inputs=[provider_table],
|
||||
outputs=[],
|
||||
show_progress="hidden")
|
||||
|
||||
edit_button.click(
|
||||
fn=edit_provider,
|
||||
inputs=[],
|
||||
outputs=[provider_table]
|
||||
)
|
||||
|
||||
delete_button.click(
|
||||
fn=delete_provider,
|
||||
inputs=[],
|
||||
outputs=[provider_table]
|
||||
)
|
||||
|
||||
return demo
|
113
frontend/train_page.py
Normal file
113
frontend/train_page.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import subprocess
|
||||
import os
|
||||
import gradio as gr
|
||||
import sys
|
||||
from tinydb import Query
|
||||
from pathlib import Path
|
||||
from transformers import TrainerCallback
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||||
from tools import find_available_port
|
||||
from train import train_model
|
||||
|
||||
def train_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 微调")
|
||||
# 获取数据集列表并设置初始值
|
||||
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
|
||||
initial_dataset = datasets_list[0] if datasets_list else None
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
dataset_dropdown = gr.Dropdown(
|
||||
choices=datasets_list,
|
||||
value=initial_dataset, # 设置初始选中项
|
||||
label="选择数据集",
|
||||
allow_custom_value=True,
|
||||
interactive=True
|
||||
)
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
epoch_input = gr.Number(value=1, label="epoch", precision=0)
|
||||
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
|
||||
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
|
||||
|
||||
train_button = gr.Button("开始微调")
|
||||
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练状态", interactive=False)
|
||||
with gr.Column(scale=3):
|
||||
# 新增 TensorBoard iframe 显示框
|
||||
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
|
||||
|
||||
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
learning_rate = float(learning_rate)
|
||||
per_device_train_batch_size = int(per_device_train_batch_size)
|
||||
epoch = int(epoch)
|
||||
save_steps = int(save_steps) # 新增保存步数参数
|
||||
lora_rank = int(lora_rank) # 新增LoRA秩参数
|
||||
|
||||
# 加载数据集
|
||||
dataset = get_datasets().get(Query().name == dataset_name)
|
||||
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
|
||||
|
||||
# 扫描 training 文件夹并生成递增目录
|
||||
training_dir = get_workdir() + "/training"
|
||||
os.makedirs(training_dir, exist_ok=True) # 确保 training 文件夹存在
|
||||
existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
|
||||
next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
|
||||
new_training_dir = os.path.join(training_dir, str(next_dir_number))
|
||||
|
||||
tensorboard_port = find_available_port(6006) # 从默认端口 6006 开始检测
|
||||
print(f"TensorBoard 将使用端口: {tensorboard_port}")
|
||||
|
||||
tensorboard_logdir = os.path.join(new_training_dir, "logs")
|
||||
os.makedirs(tensorboard_logdir, exist_ok=True) # 确保日志目录存在
|
||||
tensorboard_process = subprocess.Popen(
|
||||
["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE
|
||||
)
|
||||
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
|
||||
|
||||
# 动态生成 TensorBoard iframe
|
||||
tensorboard_url = f"http://localhost:{tensorboard_port}"
|
||||
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
|
||||
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
|
||||
|
||||
try:
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, new_training_dir,
|
||||
learning_rate, per_device_train_batch_size, epoch,
|
||||
save_steps, lora_rank)
|
||||
except Exception as e:
|
||||
raise gr.Error(str(e))
|
||||
finally:
|
||||
# 确保训练结束后终止 TensorBoard 子进程
|
||||
tensorboard_process.terminate()
|
||||
print("TensorBoard 子进程已终止")
|
||||
|
||||
train_button.click(
|
||||
fn=start_training,
|
||||
inputs=[
|
||||
dataset_dropdown,
|
||||
learning_rate_input,
|
||||
per_device_train_batch_size_input,
|
||||
epoch_input,
|
||||
save_steps_input,
|
||||
lora_rank_input
|
||||
],
|
||||
outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
if __name__ == "__main__":
|
||||
from global_var import init_global_var
|
||||
from model_manage_page import model_manage_page
|
||||
init_global_var("workdir")
|
||||
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
|
||||
demo.queue()
|
||||
demo.launch()
|
52
global_var.py
Normal file
52
global_var.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||
from tools import scan_docs_directory
|
||||
|
||||
_prompt_store = None
|
||||
_sql_engine = None
|
||||
_datasets = None
|
||||
_model = None
|
||||
_tokenizer = None
|
||||
_workdir = None
|
||||
def init_global_var(workdir="workdir"):
|
||||
global _prompt_store, _sql_engine, _datasets, _workdir
|
||||
_prompt_store = get_prompt_tinydb(workdir)
|
||||
_sql_engine = load_sqlite_engine(workdir)
|
||||
_datasets = get_all_dataset(workdir)
|
||||
_workdir = workdir
|
||||
|
||||
def get_workdir():
|
||||
return _workdir
|
||||
|
||||
def get_prompt_store():
|
||||
return _prompt_store
|
||||
|
||||
def set_prompt_store(new_prompt_store):
|
||||
global _prompt_store
|
||||
_prompt_store = new_prompt_store
|
||||
|
||||
def get_sql_engine():
|
||||
return _sql_engine
|
||||
|
||||
def set_sql_engine(new_sql_engine):
|
||||
global _sql_engine
|
||||
_sql_engine = new_sql_engine
|
||||
|
||||
def get_docs():
|
||||
global _workdir
|
||||
return scan_docs_directory(_workdir)
|
||||
def get_datasets():
|
||||
return _datasets
|
||||
|
||||
def get_model():
|
||||
return _model
|
||||
|
||||
def set_model(new_model):
|
||||
global _model
|
||||
_model = new_model
|
||||
|
||||
def get_tokenizer():
|
||||
return _tokenizer
|
||||
|
||||
def set_tokenizer(new_tokenizer):
|
||||
global _tokenizer
|
||||
_tokenizer = new_tokenizer
|
29
main.py
Normal file
29
main.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import gradio as gr
|
||||
import train
|
||||
from frontend import *
|
||||
from db import initialize_sqlite_db, initialize_prompt_store
|
||||
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_global_var()
|
||||
initialize_sqlite_db(get_sql_engine())
|
||||
initialize_prompt_store(get_prompt_store())
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("模型管理"):
|
||||
model_manage_page()
|
||||
with gr.TabItem("模型推理"):
|
||||
chat_page()
|
||||
with gr.TabItem("模型微调"):
|
||||
train_page()
|
||||
with gr.TabItem("数据集生成"):
|
||||
dataset_generate_page()
|
||||
with gr.TabItem("数据集管理"):
|
||||
dataset_manage_page()
|
||||
with gr.TabItem("提示词模板管理"):
|
||||
prompt_manage_page()
|
||||
with gr.TabItem("设置"):
|
||||
setting_page()
|
||||
|
||||
app.launch()
|
18
paper/latex/chapters/abstract.tex
Normal file
18
paper/latex/chapters/abstract.tex
Normal file
@@ -0,0 +1,18 @@
|
||||
% 摘要
|
||||
\begin{center}
|
||||
{\zihao{3}\textbf{毕业论文系统设计}}\par
|
||||
{\zihao{-4}\songti 计算机科学与技术 \quad 专业 \quad 计科211(创) \quad 张三 \par
|
||||
指导教师:李四教授}
|
||||
\end{center}
|
||||
|
||||
% 中文摘要
|
||||
\begin{onecolabstract}
|
||||
\noindent{}\makebox[5em][l]{{\zihao{4}\textbf{摘要}}}{\songti \zihao{-4}本文研究了一种基于文档驱动的自适应编码大模型微调框架。该框架通过分析文档结构自动生成训练样本,实现了模型参数的高效优化。实验结果表明,该方法在多个NLP任务上取得了显著的性能提升,同时减少了人工标注的工作量。}\par
|
||||
\noindent{}\makebox[5em][l]{{\zihao{4}\textbf{关键词}}}{\zihao{-4}\songti 关键词1;关键词2}\par
|
||||
\end{onecolabstract}
|
||||
|
||||
% 英文摘要
|
||||
\begin{onecolabstract}
|
||||
\noindent{}\makebox[10em][l]{{\zihao{4} \textbf{ABSTRACT}}}{\zihao{-4}This paper proposes a document-driven adaptive fine-tuning framework for large coding models. By analyzing document structures to automatically generate training samples, the framework achieves efficient optimization of model parameters. Experimental results demonstrate significant performance improvements on multiple NLP tasks while reducing manual annotation workload.}\par
|
||||
\noindent{}\makebox[10em][l]{{\zihao{4}\textbf{KEYWORDS}}}{\zihao{-4}Document-driven; Adaptive fine-tuning; Large language models; NLP tasks; Automatic annotation}\par
|
||||
\end{onecolabstract}
|
12
paper/latex/chapters/acknowledgement.tex
Normal file
12
paper/latex/chapters/acknowledgement.tex
Normal file
@@ -0,0 +1,12 @@
|
||||
% 致谢章节
|
||||
|
||||
\section*{致谢}
|
||||
\addcontentsline{toc}{section}{致谢\tiny{\quad.}}
|
||||
|
||||
在此,我衷心感谢我的导师XXX教授在论文写作过程中给予的悉心指导和宝贵建议。同时,感谢计算机科学与网络工程学院的各位老师四年来对我的培养和教育。
|
||||
|
||||
感谢我的同学和朋友在学习和生活中给予的帮助和支持。最后,特别感谢我的家人一直以来的鼓励和理解。
|
||||
\par
|
||||
\vspace{5ex}
|
||||
\rightline{\zihao{3}{苏伟强\quad\qquad}}
|
||||
\rightline{二O一九年五月十九日于广州}
|
6
paper/latex/chapters/conclusion.tex
Normal file
6
paper/latex/chapters/conclusion.tex
Normal file
@@ -0,0 +1,6 @@
|
||||
% 第五章 总结与展望
|
||||
|
||||
\section{总结与展望}
|
||||
|
||||
|
||||
本文提出的文档驱动自适应编码框架有效解决了大模型微调中的样本利用率问题,实验结果表明该方法在多个NLP任务上性能提升显著。
|
36
paper/latex/chapters/cover.tex
Normal file
36
paper/latex/chapters/cover.tex
Normal file
@@ -0,0 +1,36 @@
|
||||
% 封面部分
|
||||
\begin{titlepage}
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\includegraphics[scale=0.5]{pic//logo//logo.jpg}
|
||||
\end{figure}
|
||||
\vspace{0.2cm}
|
||||
\centering
|
||||
|
||||
{\zihao{1}\songti{本科毕业论文(设计)}}
|
||||
|
||||
\vspace{2.5cm}
|
||||
|
||||
\begin{flushleft}
|
||||
{{\songti \zihao{-3} \qquad\qquad\qquad 课题名称}\quad{\zihao{-4}\dlmu[7.5cm]{基于文档驱动的自适应编码大模型微调框架}}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 学\qquad 院}\quad\dlmu[7.5cm]{计算机科学与网络工程学院}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 专\qquad 业}\quad\dlmu[7.5cm]{计算机科学与技术}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 班级名称}\quad\dlmu[7.5cm]{计科211(创)}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 学生姓名}\quad\dlmu[7.5cm]{张三}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 学\qquad 号}\quad\dlmu[7.5cm]{20210001}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 指导老师}\quad\dlmu[7.5cm]{李四教授}\par}
|
||||
\vspace{0.5cm}
|
||||
{{\songti\zihao{-3} \qquad\qquad\qquad 完成日期}\quad\dlmu[7.5cm]{\number\year 年\number\month 月\number\day 日}\par}
|
||||
\end{flushleft}
|
||||
|
||||
\vspace{4cm}
|
||||
|
||||
{\songti \zihao{3} 教务处制}
|
||||
|
||||
\end{titlepage}
|
3
paper/latex/chapters/implementation.tex
Normal file
3
paper/latex/chapters/implementation.tex
Normal file
@@ -0,0 +1,3 @@
|
||||
% 第四章:关键技术实现
|
||||
|
||||
\section{关键技术实现}
|
76
paper/latex/chapters/introduction.tex
Normal file
76
paper/latex/chapters/introduction.tex
Normal file
@@ -0,0 +1,76 @@
|
||||
% 第一章:绪论
|
||||
|
||||
\section{绪论}
|
||||
|
||||
\subsection{研究背景与意义}
|
||||
|
||||
在现代软件开发领域,程序员的编码工作日益依赖于先进的大语言模型支持,这些模型凭借其强大的能力,显著自动化了代码生成流程,有效减轻了开发者的工作负担,并大幅度提升了开发效率。然而,尽管这些模型在公开数据集与广泛使用的开源项目中展现出非凡的性能,但在处理企业内部高度专业化的私有库时,其局限性便显露无遗。核心原因在于,大语言模型往往基于广泛的通用数据集进行训练,缺乏对特定企业或项目中私有库内专有函数、类及其交互细节的深度理解和应用适应性。
|
||||
相较于广泛采用的公开编码模型,针对私有库设计的专有模型显得尤为必要。公开模型虽强大,但在面对包含企业核心业务逻辑、技术秘密及高度定制化功能的私有库时,往往捉襟见肘。由于缺乏对私有库具体实现细节的认知,生成的代码往往无法精准引用库中的类、方法或属性,这不仅增加了后续人工调整的工作量,还可能引入潜在的安全风险。此外,企业间的私有库差异巨大,从架构设计到API接口各不相同,要求任何自动化工具都必须具备高度的灵活性和可定制性,以适应这些多样化的环境。
|
||||
鉴于上述现状,本项目通过深度解析私有库的文档资源,精准提取关键信息,并以此为基础对大语言模型进行针对性的微调与优化。这一过程不仅增强了模型对私有库特定功能和用法的理解能力,还极大地提升了生成代码的准确性和实用性。通过本项目,我们期望能够让生成的代码片段无缝集成于企业的私有库生态中,真正实现企业级软件开发的智能化与高效化,满足企业对高质量、高安全性代码的迫切需求。
|
||||
|
||||
\subsection{国内外研究现状}
|
||||
|
||||
人工智能辅助编码,作为软件开发领域的一项前沿技术,近年来取得了显著的进展,并展现出巨大的潜力。随着计算能力的提升和机器学习技术的进步,特别是大型语言模型(LLMs)的兴起,AI正在深刻地改变着代码的编写、测试和维护方式。本文将对人工智能辅助编码的最新进展、应用、潜力、挑战以及未来方向进行全面的概述。
|
||||
|
||||
\subsubsection{人工智能辅助编码的最新进展和趋势}
|
||||
|
||||
人工智能辅助编码的最新进展和趋势表明,AI正逐渐成为开发者的重要伙伴,甚至在某些方面能够独立完成复杂的编程任务。到2025年,人们普遍预计,先进的AI模型将不再仅仅提供代码片段,而是能够自动创建完整的应用程序,而人类只需进行最少的指导。GitHub Copilot 和 OpenAI Codex 等工具正处于这一趋势的前沿,使得编程变得更加快速和便捷。
|
||||
|
||||
主流模型和技术方面,大型语言模型是当前研究和应用的核心。这些模型,例如 OpenAI 的 GPT 系列、Google 的 Gemini、Anthropic 的 Claude 以及国内的文心一言、通义千问等,通过对海量文本和代码数据的学习,展现出强大的代码生成和理解能力。它们能够理解开发者的自然语言指令,并将其转化为可执行的代码。此外,这些模型还在不断进化,例如,OpenAI 的 GPT-4o 在多项编码基准测试中表现领先,而 Anthropic 的 Claude 和 Google 的 Gemini 也紧随其后。一些开源模型,如 Code Llama,也为研究和应用提供了更多的选择。
|
||||
|
||||
最新的趋势包括 AI 生成代码成为常态,AI 技术深度融入低代码/无代码平台,AI 赋能的测试和调试,AI 优化的性能调优,以及个性化的 AI 驱动的 UX/UI 设计。到2025年,AI 工具将尝试编写完整的模块,主动优化软件,并自动处理调试。此外,AI 还将被集成到 SaaS 和云原生开发中,提升应用程序的自主性、自愈能力和性能优化。预测分析、高度个性化和自学习功能将成为 AI 驱动开发的主导,从而最大限度地减少编码工作并优化性能和用户体验。
|
||||
|
||||
值得关注的是,国内的 AI 大模型也在编程辅助领域取得了显著进展。例如,科大讯飞的讯飞星火认知大模型具备代码生成、注释和错误检测能力。百度推出了基于文心大模型的智能代码助手 Comate,提供智能推荐、生成和问答等功能。阿里巴巴的通义千问在代码理解方面也表现出色。这些进展表明,国内 AI 技术正在迅速追赶国际领先水平。
|
||||
|
||||
\subsubsection{AI在软件开发的不同阶段的研究与应用}
|
||||
|
||||
AI 在软件开发的不同阶段都展现出强大的应用潜力。
|
||||
在代码生成和补全方面,AI 工具如 GitHub Copilot 和 ChatGPT 能够实时建议代码片段和补全函数,极大地提高了开发速度。Google 的 Gemini Code Assist 甚至可以分析项目上下文和编程语言,生成相关的代码段,从而自动化重复性任务。百度 Comate 和商汤代码小浣熊等国内工具也提供了类似的功能,支持多种编程语言和集成开发环境。
|
||||
在错误检测方面,AI 驱动的工具能够分析代码以识别错误、漏洞和低效之处,帮助开发者提高软件质量。例如,科大讯飞的讯飞星火可以精准定位代码语法和逻辑错误。IBM watsonx Code Assistant 等工具也具备错误检测和修复建议功能。
|
||||
在代码优化方面,AI 可以提供代码改进建议,以优化性能并降低维护难度。AI 能够检测不良的编码实践,并根据最佳实践提出改进建议,还可以分析并提高代码的效率。
|
||||
在自动化测试方面,AI 可以增强测试自动化,通过更快地检测潜在问题来确保软件的可靠性。例如,AI 可以自动生成并执行测试用例,优化测试覆盖范围,并及早发现错误。商汤代码小浣熊也具备辅助进行单元测试的能力。
|
||||
此外,AI 还在应用开发方面发挥作用。低代码/无代码 AI 平台使得即使没有广泛编程知识的人也能够创建应用程序。这些平台通过 AI 驱动的易用性,降低了软件开发的门槛。
|
||||
值得一提的是,AI 在代码审查自动化方面也取得了进展。例如,CodeAgent 是一种新型的多智能体大型语言模型系统,用于自动化代码审查。它能够检测代码变更与提交消息之间的不一致性,识别潜在的漏洞,验证代码风格的一致性,并提出代码修改建议。
|
||||
|
||||
\subsubsection{AI辅助编码在不同编程语言和环境中的实际应用和用户反馈}
|
||||
AI 辅助编码工具的实际应用非常广泛,支持多种编程语言和开发环境。这些工具通常以插件的形式集成到流行的集成开发环境(IDEs)中,如 VS Code、IntelliJ IDEA 和 Android Studio。它们支持包括 Python、Java、JavaScript、C++、Go 和 SQL 等在内的 100 多种主流编程语言。
|
||||
用户反馈普遍积极,许多开发者报告了显著的生产力提升。例如,有开发者报告使用 AI 工具后,项目时间缩短了一半,生产力提高了 50\%。RunSignup 团队通过使用基于自身代码库训练的 AI 工具,预计在 2025 年开发效率将提高约 20\%。Salesforce 的 CEO Marc Benioff 也表示,他们的工程生产力提高了 30\%。
|
||||
用户赞赏 AI 工具能够自动化重复性任务,例如生成样板代码、修复语法错误和起草文档,从而使开发者能够专注于更复杂的挑战,如系统设计和算法创新。GitHub Copilot 的用户 Lars Gyrup Brink Nielsen 甚至表示,没有它,他再也不会开发软件了。
|
||||
然而,用户反馈也并非完全没有担忧。一些开发者担心过度依赖 AI 可能会导致对代码库的理解不够深入,从而在未来需要进行修改时更容易引入缺陷。此外,对于 AI 生成代码的质量和安全性也存在一定的疑虑,这将在后续章节中进一步讨论。
|
||||
|
||||
\subsubsection{AI辅助编码在提升开发效率、代码质量和降低成本方面的潜力}
|
||||
AI 辅助编码在提升开发效率、代码质量和降低成本方面展现出巨大的潜力。
|
||||
提升开发效率方面,AI 能够自动化重复性任务,加速开发时间,并提供实时的支持和建议,从而显著提高开发速度。例如,通过自动生成代码片段和脚本,AI 可以大幅减少编码所需的时间,这对于需要在短时间内完成项目的场景尤为重要。
|
||||
提高代码质量方面,AI 可以帮助减少错误,确保代码风格和最佳实践的一致性,并为新手开发者提供学习机会,从而提高整体代码质量。AI 驱动的工具能够检测代码中的错误和低效之处,并提供修复建议。此外,AI 还可以增强测试自动化,确保软件的可靠性。
|
||||
降低成本方面,AI 有潜力通过减少开发团队规模、缩短上市时间以及降低软件开发的门槛来实现成本节约。对于初创企业和中小型企业而言,AI 编码助手可以 democratize 高质量的编码专业知识,从而减少对大型开发团队的需求。
|
||||
然而,值得注意的是,尽管 AI 可以提高生产力,但研究表明,AI 辅助的代码中代码克隆的现象显著增加。这可能对代码的可维护性产生负面影响,并可能导致技术债务的累积。因此,在追求效率的同时,仍需关注代码的长期质量和可维护性。
|
||||
|
||||
\subsubsection{研究AI辅助编码面临的挑战和未来方向}
|
||||
尽管人工智能辅助编码带来了诸多益处,但也面临着一些重要的挑战,并且未来的发展方向也需要仔细考量。
|
||||
安全性是首要的挑战之一。研究表明,AI 生成的代码可能包含与手动编写的代码相似甚至更高的安全漏洞。这些漏洞包括 SQL 注入、跨站脚本攻击(XSS)等常见安全问题。AI 模型通过复制其训练数据中的模式来工作,如果训练数据包含不安全的编码模式,AI 可能会在生成的代码中重现这些模式。此外,AI 工具可能缺乏对特定应用程序上下文和安全要求的全面理解,从而无法生成完全安全的代码。新兴的攻击手段,例如针对 GitHub Copilot 和 Cursor 的“规则文件后门”技术,也表明 AI 编码助手本身可能成为新的攻击媒介。因此,对 AI 生成的代码进行严格的安全扫描和人工审查至关重要。
|
||||
可解释性是另一个重要的挑战。许多 AI 模型,特别是基于深度学习的模型,本质上是“黑箱”,难以理解其代码生成的过程。这种不透明性使得调试、信任以及识别潜在的偏见或错误变得困难。为了解决这个问题,未来的研究需要侧重于开发更具可解释性的 AI 技术,使用户能够理解 AI 生成代码的原因和逻辑。
|
||||
伦理问题也日益受到关注。AI 辅助编码的广泛应用可能会导致软件开发人员的就业岗位流失。虽然 AI 也可能创造新的工作岗位,但如何确保工人能够顺利过渡到新的角色是一个重要的社会问题。此外,训练数据中的偏见可能会导致 AI 模型生成带有歧视性或不公平的代码。例如,如果训练数据主要来自特定人群编写的代码,AI 可能会偏向于这些编码风格或实践,而忽略其他更优的方案。因此,负责任的 AI 开发实践至关重要,包括确保数据的多样性和公正性,以及在开发过程中考虑到伦理因素。
|
||||
未来的研究方向包括提高 AI 生成代码的准确性和可靠性,增强其对复杂架构上下文的理解,以及开发更好的代码质量评估指标。此外,还需要深入研究 AI 对软件开发的长期影响以及人类开发者角色的演变。人机协作被认为是未来的重要发展方向,即 AI 系统与人类程序员协同工作,共同提高编码效率和软件质量。
|
||||
为了应对这些挑战,需要从多个层面进行努力,包括制定伦理指导原则、加强外部监管、推动国际合作以及保护用户权益。只有这样,才能确保 AI 辅助编码技术朝着负责任和有益的方向发展。
|
||||
|
||||
\subsubsection{总结当前研究现状}
|
||||
当前,人工智能辅助编码正处于快速发展阶段。大型语言模型作为核心技术,在代码生成、补全、错误检测、优化和自动化测试等方面展现出强大的能力。AI 工具已经广泛应用于各种编程语言和开发环境中,并获得了用户的积极反馈,普遍认为能够显著提升开发效率、代码质量并降低成本。
|
||||
然而,研究也揭示了 AI 辅助编码面临的严峻挑战,主要集中在安全性、可解释性和伦理问题上。AI 生成的代码可能存在安全漏洞,模型的决策过程往往难以解释,并且 AI 的应用也引发了关于就业、偏见和责任的伦理担忧。
|
||||
未来的研究方向将侧重于克服这些挑战,例如开发更安全的 AI 模型,提高模型的可解释性,以及制定负责任的 AI 开发和部署框架。人机协作模式被认为是未来的趋势,AI 将成为开发者更强大、更智能的助手。持续的研究和跨领域的合作对于确保 AI 辅助编码技术的健康发展和广泛应用至关重要。
|
||||
|
||||
\subsection{本文结构安排}
|
||||
|
||||
本文围绕基于大型语言模型的自动化微调框架展开研究与实现,全文共分为五章,具体结构安排如下:
|
||||
|
||||
第一章 前言:本章首先介绍了研究的背景与意义,阐述了大型语言模型微调自动化的重要性和必要性。随后,对国内外相关的研究现状进行了回顾与分析,指出了现有方法的优势与不足。最后,概述了本文的主要研究内容,并介绍了论文的整体结构安排。
|
||||
|
||||
第二章 相关技术介绍:本章详细介绍了本文研究所涉及的关键技术。包括大型语言模型(LLM)的发展、应用及在辅助编码方面的潜力;提示工程技术在引导LLM生成高质量文本中的作用;模型量化技术及其在降低模型部署成本方面的意义;LoRA(Low-Rank Adaptation)等参数高效微调方法,特别是QLoRA的原理与优势;优化微调效率的unsloth算子;以及用于构建交互式界面的Gradio框架。
|
||||
|
||||
第三章 需求分析:本章从项目整体出发,对基于大型语言模型的自动化微调框架进行了需求分析。首先介绍了项目的整体目标和应用场景。然后,详细分析了系统的功能需求,包括训练语料生成、模型微调、自动化整合以及前端展示等核心功能。最后,阐述了系统的非功能需求,如性能要求和扩展性要求。
|
||||
|
||||
第四章 关键技术实现:本章详细阐述了系统的具体实现过程。首先介绍了系统的整体架构设计、模块划分与交互流程。接着,描述了双数据库架构(SQLite+TinyDB)的设计与实现方案,以及数据模型定义和数据库管理。详细介绍了语料生成与处理技术,包括Markdown文档解析、Prompt模板应用、API协程并发调用以及数据校验与持久化。重点阐述了语言模型训练技术的实现,涵盖监督式微调(SFT)流程、训练数据准备、LoRA微调方法应用、训练配置、监控与结果保存。随后,介绍了基于Gradio框架的前端交互系统设计与实现,包括全局状态管理、前后端数据流、流式响应与实时反馈以及异常处理。最后,探讨了系统的扩展性实现方案。
|
||||
|
||||
第五章 总结与展望:本章对本文的研究工作进行了全面的总结,回顾了所取得的主要成果。同时,分析了当前研究存在的不足与局限性。最后,对未来的研究方向和可能的技术发展进行了展望。
|
||||
|
||||
\subsection{小结}
|
||||
本章作为全文的引言部分,首先阐明了在当前大型语言模型蓬勃发展的背景下,构建自动化微调框架的研究背景和重要的现实意义。通过对国内外相关研究现状的梳理,我们认识到自动化、高效化微调工具的缺失是当前LLM应用落地的瓶颈之一,这进一步凸显了本研究的价值。本章还概述了本文的主要研究内容,旨在通过整合先进的语料生成、模型微调和前端交互技术,构建一个用户友好、高效灵活的LLM自动化微调框架。最后,详细介绍了本文的章节结构安排,为读者清晰地勾勒出后续内容的逻辑脉络,为深入理解本文的研究工作奠定了基础。
|
12
paper/latex/chapters/references.tex
Normal file
12
paper/latex/chapters/references.tex
Normal file
@@ -0,0 +1,12 @@
|
||||
% 参考文献章节
|
||||
|
||||
\renewcommand\refname{参考文献}
|
||||
\begin{thebibliography}{0}
|
||||
\addcontentsline{toc}{section}{参考文献\tiny{\quad}}
|
||||
|
||||
\bibitem{数字信号处理教材}程佩青. 数字信号处理教程[M]. 清华大学出版社有限公司, 2001.
|
||||
%
|
||||
%\bibitem{信号与系统}陈后金. 信号与系统[M]. 清华大学出版社有限公司, 2003.
|
||||
%
|
||||
%\bibitem{数字信号处理教材陈}陈后金. 数字信号处理.2版[M]. 高等教育出版社, 2008.
|
||||
\end{thebibliography}
|
3
paper/latex/chapters/requirement.tex
Normal file
3
paper/latex/chapters/requirement.tex
Normal file
@@ -0,0 +1,3 @@
|
||||
% 第三章:需求分析
|
||||
|
||||
\section{需求分析}
|
3
paper/latex/chapters/technology.tex
Normal file
3
paper/latex/chapters/technology.tex
Normal file
@@ -0,0 +1,3 @@
|
||||
% 第二章:相关技术介绍
|
||||
|
||||
\section{相关技术介绍}
|
148
paper/latex/main.tex
Normal file
148
paper/latex/main.tex
Normal file
@@ -0,0 +1,148 @@
|
||||
\documentclass[12pt,a4paper]{article}
|
||||
\usepackage{graphicx}
|
||||
\usepackage{ctex}
|
||||
\usepackage{indentfirst}
|
||||
%\graphicspath{{chapter/}{figures/}}
|
||||
\usepackage{CJK}
|
||||
\usepackage{amsmath}%数学
|
||||
|
||||
%\usepackage[colorlinks,linkcolor=red]{hyperref}%超链接
|
||||
|
||||
\usepackage{fancyhdr} %使用fancyhdr包自定义页眉页脚
|
||||
%\pagestyle{empty}
|
||||
\pagestyle{fancy}
|
||||
%\pagestyle{plain}%没有页眉,页脚放页数
|
||||
|
||||
\usepackage{titlesec}%设置章节标题与正文间距为2行
|
||||
\titlespacing{\section}{0pt}{0pt}{2em}
|
||||
|
||||
\usepackage{enumerate}%项目编号
|
||||
|
||||
\renewcommand{\figurename}{图}%将figure改为图
|
||||
|
||||
\usepackage[]{caption2}%去掉图片编号后的":"
|
||||
\renewcommand{\captionlabeldelim}{}
|
||||
|
||||
\renewcommand {\thefigure} {\thesection{}.\arabic{figure}}%图片索引该为按照章节
|
||||
|
||||
\renewcommand{\headrulewidth}{0pt}
|
||||
\renewcommand{\footrulewidth}{0pt}
|
||||
\lhead{}
|
||||
\chead{}
|
||||
\rhead{}
|
||||
\lfoot{}
|
||||
\cfoot{\thepage}
|
||||
\rfoot{}
|
||||
|
||||
\usepackage{booktabs}%表格用
|
||||
|
||||
\usepackage{titlesec}%修改标题格式宏包
|
||||
\titleformat{\section}{\centering\zihao{3}\songti\bfseries}{\arabic{section}.}{0.5em}{}%修改section标题格式
|
||||
|
||||
% 设置目录格式为宋体
|
||||
\renewcommand\contentsname{\songti 目录}
|
||||
|
||||
\usepackage{multirow}%跨行表格
|
||||
\usepackage{abstract}%摘要
|
||||
\usepackage{setspace} %行间距的宏包
|
||||
|
||||
\usepackage{makecell}%表格竖线连续
|
||||
|
||||
\def\I{\vrule width1.2pt}
|
||||
%!\I 就可以代替| 来画表格了
|
||||
|
||||
%可固定下划线长度
|
||||
\makeatletter
|
||||
\newcommand\dlmu[2][4cm]{\hskip1pt\underline{\hb@xt@ #1{\hss#2\hss}}\hskip3pt}
|
||||
\makeatother
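%用法示例(假设): \dlmu[6cm]{内容} 将生成宽度固定为 6cm 的下划线, 常用于封面信息栏的填写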
|
||||
|
||||
\usepackage{float}%可以用于禁止浮动体浮动
|
||||
|
||||
|
||||
|
||||
%目录超链接
|
||||
\usepackage[colorlinks,linkcolor=black,anchorcolor=blue,citecolor=black]{hyperref}
|
||||
|
||||
\usepackage{listings}%可以插入代码
|
||||
\usepackage{xcolor}%语法高亮支持
|
||||
|
||||
%代码格式
|
||||
\definecolor{dkgreen}{rgb}{0,0.6,0}
|
||||
\definecolor{gray}{rgb}{0.5,0.5,0.5}
|
||||
\definecolor{mauve}{rgb}{0.58,0,0.82}
|
||||
\usepackage{fontspec}
|
||||
\setmonofont{Consolas}
|
||||
\lstset{ %
|
||||
numbers=left,
|
||||
basicstyle=\tiny\ttfamily,
|
||||
numberstyle=\tiny,
|
||||
tabsize=4,
|
||||
numbersep=5pt,
|
||||
keywordstyle= \color{blue!70}, %关键词为蓝色
|
||||
commentstyle=\color{gray}, %注释为灰色
|
||||
frame=shadowbox, % 框架阴影效果
|
||||
rulesepcolor= \color{ red!20!green!20!blue!20} ,
|
||||
escapeinside={\%*}{*)},
|
||||
xleftmargin=2em, % 边界选项
|
||||
xrightmargin=2em, % 边界选项
|
||||
aboveskip=1em, % 边界选项
|
||||
framexleftmargin=2em, % 边界选项
|
||||
breaklines,%过长代码自动换行
|
||||
}
|
||||
|
||||
|
||||
\usepackage{titlesec}
|
||||
\titleformat{\section}{\centering\Large\bfseries}{第\,\thesection\,章}{1em}{}
|
||||
|
||||
|
||||
\usepackage{titletoc}
|
||||
\titlecontents{section}[0pt]{\addvspace{1.5pt}\filright\bf}
|
||||
{\contentspush{第\thecontentslabel\ 章\quad}}
|
||||
{}{\titlerule*[8pt]{.}\contentspage}
|
||||
|
||||
|
||||
% 设置页面格式
|
||||
\usepackage[left=3.0cm, right=2.6cm, top=2.54cm, bottom=2.54cm]{geometry}
|
||||
\begin{document}
|
||||
|
||||
|
||||
% 封面部分
|
||||
\input{chapters/cover}
|
||||
|
||||
\renewcommand{\abstractname}{\scriptsize}
|
||||
|
||||
%% 摘要和关键词部分
|
||||
\input{chapters/abstract}
|
||||
\newpage
|
||||
|
||||
% 目录部分
|
||||
\renewcommand{\contentsname}{\centerline{\zihao{-2}\textbf{目录}}}
|
||||
|
||||
\tableofcontents
|
||||
\newpage
|
||||
|
||||
% 正文部分
|
||||
{
|
||||
\setlength{\baselineskip}{23pt}
|
||||
|
||||
% 引入各章节文件
|
||||
\newpage
|
||||
\input{chapters/introduction}
|
||||
\newpage
|
||||
\input{chapters/technology}
|
||||
\newpage
|
||||
\input{chapters/requirement}
|
||||
\newpage
|
||||
\input{chapters/implementation}
|
||||
\newpage
|
||||
\input{chapters/conclusion}
|
||||
|
||||
|
||||
% 参考文献
|
||||
\newpage
|
||||
\input{chapters/references}
|
||||
%致谢
|
||||
\newpage
|
||||
\input{chapters/acknowledgement}
|
||||
}
|
||||
\end{document}
|
BIN
paper/latex/pic/logo/logo.jpg
Normal file
BIN
paper/latex/pic/logo/logo.jpg
Normal file
Binary file not shown.
164
paper/latex/reference.tex
Normal file
164
paper/latex/reference.tex
Normal file
@@ -0,0 +1,164 @@
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%双并列图片示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{figure}[H]
|
||||
% \centering
|
||||
% \begin{minipage}[t]{0,40\textwidth}
|
||||
% \centering
|
||||
% \includegraphics[scale=0.5]{ccjg.pdf} %1.png是图片文件的相对路径
|
||||
% \caption{IEEE 802.11层次结构} %caption是图片的标题
|
||||
% \label{p_ccjg} %此处的label相当于一个图片的专属标志,目的是方便上下文的引用
|
||||
% \end{minipage}
|
||||
% \hfil
|
||||
% \begin{minipage}[t]{0,50\textwidth}
|
||||
% \centering
|
||||
% \includegraphics[scale=1]{AODV.pdf} %1.png是图片文件的相对路径
|
||||
% \caption{AODV示意图} %caption是图片的标题
|
||||
% \label{p_AODV} %此处的label相当于一个图片的专属标志,目的是方便上下文的引用
|
||||
% \end{minipage}
|
||||
%\end{figure}
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%表格示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]
|
||||
% \centering
|
||||
% \caption{802.11a/b/g物理层,MAC层参数}
|
||||
% \begin{tabular}{ccccc}
|
||||
% \toprule
|
||||
% & 参数 & 802.11a & 802.11b & 802.11g \\
|
||||
% \midrule
|
||||
% \multirow{4}[7]{*}{物理层} & 频带/Hz(freq\_) & $5*10^9$ & $2.4*10^9$ & $2.4*10^9$ \\
|
||||
% \cmidrule{3-5} & 通信感知范围\cite{bib13}(CSThresh\_) & $3.17291*10^9$ & $2.79*10^9$ & $2.79*10^9$ \\
|
||||
% \cmidrule{3-5} & 可通信范围\cite{bib13}(RXThresh\_) & $6.5556*10^{10}$ & $5.76*10^9$ & $5.76*10^9$ \\
|
||||
% \cmidrule{3-5} & 传输功率/W(Pt\_) & 0.281838 & 0.281838 & 0.281838 \\
|
||||
% \midrule
|
||||
% \multirow{9}[17]{*}{MAC层} & 竞争窗口最小值\cite{bib12}/s(CWMin) & 15 & 31 & 15 \\
|
||||
% \cmidrule{3-5} & 竞争窗口最大值\cite{bib12}/s(CWMax) & 1023 & 1023 & 1023 \\
|
||||
% \cmidrule{3-5} & 时隙\cite{bib11}/s(SlotTime\_) & 0.00005 & 0.00002 & 0.000009s \\
|
||||
% \cmidrule{3-5} & SIFS\cite{bib14}\cite{bib11}/s(SIFS\_) & 0.000016 & 0.00001 & 0.000016s \\
|
||||
% \cmidrule{3-5} & 前导码长度\cite{bib14}(PreambleLength) & 96 & 144 & 120 \\
|
||||
% \cmidrule{3-5} & PLCP头部长度\cite{bib14}PLCPHeaderLength\_) & 24 & 48 & 24 \\
|
||||
% \cmidrule{3-5} & PLCP数据率\cite{bib14}/bps(PLCPDataRate\_) & $6*10^6$ & $1*10^6$ & $6*10^6$ \\
|
||||
% \cmidrule{3-5} & 最高速率\cite{bib14}/bps(dataRate) & $5.4*10^7$ & $1.1*10^7$ & $5.4*10^7$ \\
|
||||
% \cmidrule{3-5} & 最低速率\cite{bib14}/bps(basicRate\_) & $6*10^6$ & $1*10^6$ & $6*10^6$ \\
|
||||
% \bottomrule
|
||||
% \end{tabular}%
|
||||
% \label{t_abgcs}%
|
||||
%\end{table}%
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%插入代码示例
|
||||
%%%%%%%%%%%%title:代码文件标题
|
||||
%%%%%%%%%%%%language:语言,C++,C,Matlab,Python
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%插入代码的时候需要知:注释中同时出现标点符号,英文,中文时会互相影响,
|
||||
%这个时候,在标点符号,英文后面都要追加空格,才能正常显示
|
||||
%\lstset{language=C++}
|
||||
%\begin{lstlisting}[title=AODV100.tr]
|
||||
%
|
||||
%\end{lstlisting}
|
||||
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%对齐公式示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
|
||||
%\begin{align}
|
||||
% \label{kk}
|
||||
% k&=\dfrac{3Z_{11}^{'}}{2(1-l^2_2)^{3/2}}\\
|
||||
% \label{hh}
|
||||
% h&=\frac{1}{\pi}\left[Z_{00}-\frac{k\pi}{2}+k\arcsin(l_2)+kl_2\sqrt{1-l^2_2} \right]\\
|
||||
% \label{ll} l&=\frac{1}{2}\left[\sqrt{\frac{5Z_{40}^{'}+3Z^{'}_{20}}{8Z_{20}}}+\sqrt{\frac{5Z_{11}^{'}+Z^{'}_{11}}{6Z_{11}}}\right]\\
|
||||
% \label{pp}
|
||||
% \phi&=\arctan\left[\frac{Im[Z_{n1}]}{Re[Z_{n1}]}\right]
|
||||
%\end{align}
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%表格示例2
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]
|
||||
% \centering
|
||||
% \caption{NVIDIA$^{\textregistered}$ Jetson TK1配置一览}
|
||||
% \vspace{0.5cm}
|
||||
% \begin{tabular}{l}
|
||||
% \Xhline{1.2pt}
|
||||
% Tegra K1 SOC \\
|
||||
% NVIDIA$^{\textregistered}$ Kepler$^{\textregistered}$ GPU、192 个 CUDA 核心 \\
|
||||
% NVIDIA$^{\textregistered}$ 4-Plus-1™ 四核 ARM$^{\textregistered}$ Cortex™-A15 CPU \\
|
||||
% 2 GB x16 内存、64 位宽度 \\
|
||||
% 16 GB 4.51 eMMC 内存 \\
|
||||
% 1 个 USB 3.0 端口、A \\
|
||||
% 1 个 USB 2.0 端口、Micro AB\\
|
||||
% 1 个半迷你 PCIE 插槽\\
|
||||
% 1 个完整尺寸 SD/MMC 连接器\\
|
||||
% 1 个 RTL8111GS Realtek 千兆位以太网局域网 \\
|
||||
% 1 个 SATA 数据端口 \\
|
||||
% 1 个完整尺寸 HDMI 端口 \\
|
||||
% 1 个 RS232 串行端口 \\
|
||||
% SPI 4 兆字节引导闪存\\
|
||||
% 1 个带 Mic In 和 Line Out 的 ALC5639 Realtek 音频编解码器\\
|
||||
% 以下信号可通过扩展端口获得:DP/LVDS, Touch SPI 1x4 + 1x1 CSI-2, GPIOs, UART, HSIC, I$^2$C
|
||||
% \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
% \label{aaa}%
|
||||
%\end{table}%
|
||||
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%双并列表格示例
|
||||
%%%%%%%%%%%%
|
||||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
%\begin{table}[H]\footnotesize
|
||||
% \centering
|
||||
%
|
||||
% \begin{minipage}[t]{0,47\textwidth}
|
||||
% \caption{上位机配置清单}
|
||||
% \vspace{0.5cm}
|
||||
% \centering
|
||||
% \begin{tabular}{cc}
|
||||
% \Xhline{1.2pt}
|
||||
% 运行环境 & ubuntu14 (基于Cortex$^{\textregistered}$-A15芯片) \\
|
||||
% 编程语言 & C/C++ \\
|
||||
% 第三方库及组件 & GTK2.0,OpenCV2.4.10 \\
|
||||
% 开发环境 & Qt Creator 与 make工程管理器 \\
|
||||
% 编译工具链 & NVIDIA$^{\textregistered}$-ARM$^{\textregistered}$编译工具链 \\
|
||||
% 程序结构 & 模块化结构 \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
%
|
||||
% \label{pzqd}%
|
||||
% \end{minipage}
|
||||
% \hfil
|
||||
% \hfil
|
||||
% \begin{minipage}[t]{0,47\textwidth}
|
||||
% \centering
|
||||
% \caption{上位机功能清单}
|
||||
% \vspace{0.5cm}
|
||||
% \begin{tabular}{cc}
|
||||
% \Xhline{1.2pt}
|
||||
% 编号 & \multicolumn{1}{c}{功能描述} \\
|
||||
% \Xhline{1.2pt}
|
||||
% 1 & \multicolumn{1}{c}{可打开/关闭摄像头} \\
|
||||
% 2 & 可通过摄像头捕获图片为目标图片 \\
|
||||
% 3 & 可从文件系统内选择图片并载入为目标图片 \\
|
||||
% 4 & 可以检测目标图片中圆形轮廓的半径和圆心 \\
|
||||
% 5 & 可以检测目标图片中平行直线的间距 \\
|
||||
% 6 & 检测算法的参数可自由调整 \\
|
||||
% \Xhline{1.2pt}
|
||||
% \end{tabular}%
|
||||
% \label{gn}%
|
||||
% \end{minipage}
|
||||
%\end{table}%
|
@@ -1,2 +1,11 @@
|
||||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pydantic>=2.0.0
|
||||
gradio>=5.25.0
|
||||
langchain>=0.3
|
||||
tinydb>=4.0.0
|
||||
unsloth>=2025.3.19
|
||||
sqlmodel>=0.0.24
|
||||
jinja2>=3.1.0
|
||||
tensorboardX>=2.6.2.2
|
||||
tensorboard>=2.19.0
|
4
schema/__init__.py
Normal file
4
schema/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .dataset import *
|
||||
from .dataset_generation import *
|
||||
from .md_doc import MarkdownNode
|
||||
from .prompt import promptTempleta
|
30
schema/dataset.py
Normal file
30
schema/dataset.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class Doc(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="文档ID")
|
||||
name: str = Field(default="", description="文档名称")
|
||||
path: str = Field(default="", description="文档路径")
|
||||
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
|
||||
version: Optional[str] = Field(default="", description="文档版本")
|
||||
|
||||
class Q_A(BaseModel):
|
||||
question: str = Field(default="", min_length=1,description="问题")
|
||||
answer: str = Field(default="", min_length=1, description="答案")
|
||||
|
||||
class DatasetItem(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="数据集项ID")
|
||||
message: list[Q_A] = Field(description="数据集项内容")
|
||||
|
||||
class Dataset(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="数据集ID")
|
||||
name: str = Field(default="", description="数据集名称")
|
||||
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
|
||||
source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
|
||||
description: Optional[str] = Field(default="", description="数据集描述")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="记录创建时间"
|
||||
)
|
||||
dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
|
47
schema/dataset_generation.py
Normal file
47
schema/dataset_generation.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Relationship, Field
|
||||
|
||||
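# 注: table=True 表示该模型会同时映射为数据库中的一张表; 本文件中其余模型仅作为数据结构使用, 不直接落库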
class APIProvider(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True,allow_mutation=False)
|
||||
base_url: str = Field(...,min_length=1,description="API的基础URL,不能为空")
|
||||
model_id: str = Field(...,min_length=1,description="API使用的模型ID,不能为空")
|
||||
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
class LLMParameters(SQLModel):
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
class TokensUsage(SQLModel):
|
||||
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
|
||||
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
|
||||
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
|
||||
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
|
||||
|
||||
class LLMResponse(SQLModel):
|
||||
timestamp: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="响应的时间戳"
|
||||
)
|
||||
response_id: str = Field(..., description="响应的唯一ID")
|
||||
tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
||||
content: str = Field(default="", description="API响应的内容")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
|
||||
|
||||
class LLMRequest(SQLModel):
|
||||
prompt: str = Field(..., description="发送给API的提示词")
|
||||
api_provider: APIProvider = Field(..., description="API提供者的信息")
|
||||
format: Optional[str] = Field(default=None, description="API响应的格式")
|
||||
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
|
||||
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
13
schema/md_doc.py
Normal file
13
schema/md_doc.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
class MarkdownNode(BaseModel):
|
||||
level: int = Field(default=0, description="节点层级")
|
||||
title: str = Field(default="Root", description="节点标题")
|
||||
content: Optional[str] = Field(default=None, description="节点内容")
|
||||
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
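# children 字段引用了自身类型, 需在类定义结束后重建模型以解析该前向引用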
MarkdownNode.model_rebuild()
|
20
schema/prompt.py
Normal file
20
schema/prompt.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
class promptTempleta(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="模板ID")
|
||||
name: Optional[str] = Field(default="", description="模板名称")
|
||||
description: Optional[str] = Field(default="", description="模板描述")
|
||||
content: str = Field(default="", min_length=1, description="模板内容")
|
||||
created_at: str = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
@field_validator('content')
|
||||
def validate_content(cls, value):
|
||||
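# 校验模板内容必须包含 document_slice 占位符, 生成语料时将用文档切片填充该变量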
if not "document_slice" in PromptTemplate.from_template(value).input_variables:
|
||||
raise ValueError("模板变量缺少 document_slice")
|
||||
return value
|
5
tools/__init__.py
Normal file
5
tools/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .parse_markdown import *
|
||||
from .document import *
|
||||
from .json_example import generate_json_example
|
||||
from .port import *
|
||||
from .reasoning import call_openai_api
|
35
tools/convert_json_to_dataset.py
Normal file
35
tools/convert_json_to_dataset.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
from schema.dataset import Dataset, DatasetItem, Q_A
|
||||
import json
|
||||
|
||||
def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
|
||||
# 将JSON数据转换为dataset格式
|
||||
dataset_items = []
|
||||
item_id = 1 # 自增ID计数器
|
||||
for item in json_data:
|
||||
qa = Q_A(question=item["question"], answer=item["answer"])
|
||||
dataset_item_obj = DatasetItem(id=item_id, message=[qa])
|
||||
dataset_items.append(dataset_item_obj)
|
||||
item_id += 1 # ID自增
|
||||
|
||||
# 创建dataset对象
|
||||
result_dataset = Dataset(
|
||||
name="Converted Dataset",
|
||||
model_id=None,
|
||||
description="Dataset converted from JSON",
|
||||
dataset_items=dataset_items
|
||||
)
|
||||
return result_dataset
|
||||
|
||||
# 示例:从文件读取JSON并转换
|
||||
if __name__ == "__main__":
|
||||
# 假设JSON数据存储在文件中
|
||||
with open(r"workdir\dataset_old\llamafactory.json", "r", encoding="utf-8") as file:
|
||||
json_data = json.load(file)
|
||||
|
||||
# 转换为dataset格式
|
||||
converted_dataset = convert_json_to_dataset(json_data)
|
||||
|
||||
# 输出结果到文件
|
||||
with open("output.json", "w", encoding="utf-8") as file:
|
||||
file.write(converted_dataset.model_dump_json(indent=4))
|
32
tools/document.py
Normal file
32
tools/document.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Doc
|
||||
|
||||
def scan_docs_directory(workdir: str):
|
||||
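# 扫描 workdir/docs 下的每个子目录, 递归收集其中的 .md 文件并封装为 Doc 对象返回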
docs_dir = os.path.join(workdir, "docs")
|
||||
|
||||
doc_list = os.listdir(docs_dir)
|
||||
|
||||
to_return = []
|
||||
|
||||
for doc_name in doc_list:
|
||||
doc_path = os.path.join(docs_dir, doc_name)
|
||||
if os.path.isdir(doc_path):
|
||||
markdown_files = []
|
||||
for root, dirs, files in os.walk(doc_path):
|
||||
for file in files:
|
||||
if file.endswith(".md"):
|
||||
markdown_files.append(os.path.join(root, file))
|
||||
to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
|
||||
|
||||
return to_return
|
||||
|
||||
# 添加测试代码
|
||||
if __name__ == "__main__":
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
docs = scan_docs_directory(workdir)
|
||||
print(docs)
|
67
tools/json_example.py
Normal file
67
tools/json_example.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing import Any, Dict, List, Optional, Union, get_args, get_origin
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
|
||||
def generate_json_example(model: type[BaseModel], include_optional: bool = False, list_length: int = 2) -> str:
|
||||
"""
|
||||
根据 Pydantic V2 模型递归生成示例 JSON 数据结构; include_optional 控制是否包含默认值为 None 的字段, list_length 控制列表字段生成的示例元素个数。
|
||||
"""
|
||||
def _generate_example(field_type: Any) -> Any:
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is list or origin is List:
|
||||
# 生成 list_length 个示例元素, 并在末尾追加省略号表示列表可继续扩展
|
||||
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
|
||||
result.append("......")
|
||||
return result
|
||||
elif origin is dict or origin is Dict:
|
||||
if len(args) == 2:
|
||||
return {"key": _generate_example(args[1])}
|
||||
return {}
|
||||
elif origin is Union:
|
||||
# 处理 Optional 类型(Union[T, None])
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
return _generate_example(non_none_args[0]) if non_none_args else None
|
||||
elif field_type is str:
|
||||
return "string"
|
||||
elif field_type is int:
|
||||
return 0
|
||||
elif field_type is float:
|
||||
return 0.0
|
||||
elif field_type is bool:
|
||||
return True
|
||||
elif field_type is datetime:
|
||||
return datetime.now().isoformat()
|
||||
elif field_type is date:
|
||||
return date.today().isoformat()
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_json_example(field_type, include_optional))
|
||||
else:
|
||||
# 处理直接类型注解(非泛型)
|
||||
if field_type is type(None):
|
||||
return None
|
||||
try:
|
||||
if issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_json_example(field_type, include_optional))
|
||||
except TypeError:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
example_data = {}
|
||||
for field_name, field in model.model_fields.items():
|
||||
if include_optional or field.default is not None:
|
||||
example_data[field_name] = _generate_example(field.annotation)
|
||||
|
||||
return json.dumps(example_data, indent=2, default=str)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Dataset
|
||||
|
||||
print("示例 JSON:")
|
||||
print(generate_json_example(Dataset))
|
@@ -1,28 +1,45 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
class MarkdownNode:
|
||||
def __init__(self, level=0, title="Root"):
|
||||
self.level = level
|
||||
self.title = title
|
||||
self.content = "" # 使用字符串存储合并后的内容
|
||||
self.children = []
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import MarkdownNode
|
||||
|
||||
def __repr__(self):
|
||||
return f"({self.level}) {self.title}"
|
||||
def process_markdown_file(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.append(child)
|
||||
root = parse_markdown(content)
|
||||
results = []
|
||||
|
||||
def print_tree(self, indent=0):
|
||||
prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
|
||||
print(f"{prefix}{self.title}")
|
||||
if self.content:
|
||||
content_prefix = "│ " * indent + "├─ [内容]"
|
||||
print(content_prefix)
|
||||
for line in self.content.split('\n'):
|
||||
print("│ " * indent + "│ " + line)
|
||||
for child in self.children:
|
||||
child.print_tree(indent + 1)
|
||||
def traverse(node, parent_titles):
|
||||
current_titles = parent_titles.copy()
|
||||
current_titles.append(node.title)
|
||||
|
||||
if not node.children: # 叶子节点
|
||||
if node.content:
|
||||
full_text = ' -> '.join(current_titles) + '\n' + node.content
|
||||
results.append(full_text)
|
||||
else:
|
||||
for child in node.children:
|
||||
traverse(child, current_titles)
|
||||
|
||||
traverse(root, [])
|
||||
return results
|
||||
def add_child(parent, child):
|
||||
parent.children.append(child)
|
||||
|
||||
def print_tree(node, indent=0):
|
||||
prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
|
||||
print(f"{prefix}{node.title}")
|
||||
if node.content:
|
||||
content_prefix = "│ " * indent + "├─ [内容]"
|
||||
print(content_prefix)
|
||||
for line in node.content.split('\n'):
|
||||
print("│ " * indent + "│ " + line)
|
||||
for child in node.children:
|
||||
print_tree(child, indent + 1)
|
||||
|
||||
def parse_markdown(markdown):
|
||||
lines = markdown.split('\n')
|
||||
@@ -51,10 +68,10 @@ def parse_markdown(markdown):
|
||||
if match:
|
||||
level = len(match.group(1))
|
||||
title = match.group(2)
|
||||
node = MarkdownNode(level, title)
|
||||
node = MarkdownNode(level=level, title=title, content="", children=[])
|
||||
while stack[-1].level >= level:
|
||||
stack.pop()
|
||||
stack[-1].add_child(node)
|
||||
add_child(stack[-1], node)
|
||||
stack.append(node)
|
||||
else:
|
||||
if stack[-1].content:
|
||||
@@ -64,10 +81,13 @@ def parse_markdown(markdown):
|
||||
return root
|
||||
|
||||
if __name__=="__main__":
|
||||
# 从文件读取 Markdown 内容
|
||||
with open("example.md", "r", encoding="utf-8") as f:
|
||||
markdown = f.read()
|
||||
# # 从文件读取 Markdown 内容
|
||||
# with open("workdir/example.md", "r", encoding="utf-8") as f:
|
||||
# markdown = f.read()
|
||||
|
||||
# 解析 Markdown 并打印树结构
|
||||
root = parse_markdown(markdown)
|
||||
root.print_tree()
|
||||
# # 解析 Markdown 并打印树结构
|
||||
# root = parse_markdown(markdown)
|
||||
# print_tree(root)
|
||||
for i in process_markdown_file("workdir/example.md"):
|
||||
print("~"*20)
|
||||
print(i)
|
15
tools/port.py
Normal file
15
tools/port.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import socket
|
||||
|
||||
# 启动 TensorBoard 子进程前添加端口检测逻辑
|
||||
def find_available_port(start_port):
|
||||
port = start_port
|
||||
while True:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(('localhost', port)) != 0: # 端口未被占用
|
||||
return port
|
||||
port += 1 # 如果端口被占用,尝试下一个端口
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_port = 6006 # 起始端口号
|
||||
available_port = find_available_port(start_port)
|
||||
print(f"Available port: {available_port}")
|
123
tools/reasoning.py
Normal file
123
tools/reasoning.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import openai
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
|
||||
|
||||
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=llm_request.api_provider.api_key,
|
||||
base_url=llm_request.api_provider.base_url
|
||||
)
|
||||
|
||||
total_duration = 0.0
|
||||
total_tokens = TokensUsage()
|
||||
prompt = llm_request.prompt
|
||||
round_start = datetime.now(timezone.utc)
|
||||
if llm_request.format:
|
||||
prompt += "\n请以JSON格式返回结果" + llm_request.format
|
||||
|
||||
for i in range(rounds):
|
||||
round_start = datetime.now(timezone.utc)
|
||||
try:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
create_args = {
|
||||
"model": llm_request.api_provider.model_id,
|
||||
"messages": messages,
|
||||
"temperature": llm_parameters.temperature if llm_parameters else None,
|
||||
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
|
||||
"top_p": llm_parameters.top_p if llm_parameters else None,
|
||||
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
|
||||
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
|
||||
"seed": llm_parameters.seed if llm_parameters else None
|
||||
} # 处理format参数
|
||||
|
||||
if llm_request.format:
|
||||
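# 启用 JSON 模式; 多数 OpenAI 兼容接口要求提示词中明确提及 JSON, 因此前文已向 prompt 追加相应说明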
create_args["response_format"] = {"type": "json_object"}
|
||||
|
||||
response = await client.chat.completions.create(**create_args)
|
||||
|
||||
round_end = datetime.now(timezone.utc)
|
||||
duration = (round_end - round_start).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
# 处理可能不存在的缓存token字段
|
||||
usage = response.usage
|
||||
cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
|
||||
cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
|
||||
|
||||
tokens_usage = TokensUsage(
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
prompt_cache_hit_tokens=cache_hit,
|
||||
prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
|
||||
)
|
||||
|
||||
# 累加总token使用量
|
||||
total_tokens.prompt_tokens += tokens_usage.prompt_tokens
|
||||
total_tokens.completion_tokens += tokens_usage.completion_tokens
|
||||
if tokens_usage.prompt_cache_hit_tokens:
|
||||
total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
|
||||
if tokens_usage.prompt_cache_miss_tokens:
|
||||
total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
|
||||
|
||||
llm_request.response.append(LLMResponse(
|
||||
response_id=response.id,
|
||||
tokens_usage=tokens_usage,
|
||||
content = response.choices[0].message.content,
|
||||
total_duration=duration,
|
||||
llm_parameters=llm_parameters
|
||||
))
|
||||
except Exception as e:
|
||||
round_end = datetime.now(timezone.utc)
|
||||
duration = (round_end - round_start).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
llm_request.response.append(LLMResponse(
|
||||
response_id=f"error-round-{i+1}",
|
||||
content={"error": str(e)},
|
||||
total_duration=duration
|
||||
))
|
||||
if llm_request.error is None:
|
||||
llm_request.error = []
|
||||
llm_request.error.append(str(e))
|
||||
|
||||
# 更新总耗时和总token使用量
|
||||
llm_request.total_duration = total_duration
|
||||
llm_request.total_tokens_usage = total_tokens
|
||||
|
||||
return llm_request
|
||||
|
||||
if __name__ == "__main__":
|
||||
from json_example import generate_json_example
|
||||
from sqlmodel import Session, select
|
||||
from global_var import get_sql_engine, init_global_var
|
||||
from schema import DatasetItem
|
||||
|
||||
init_global_var("workdir")
|
||||
api_state = "1 deepseek-chat"
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
llm_request = LLMRequest(
|
||||
prompt="测试,随便说点什么",
|
||||
api_provider=api_provider,
|
||||
format=generate_json_example(DatasetItem)
|
||||
)
|
||||
|
||||
# # 单次调用示例
|
||||
# result = asyncio.run(call_openai_api(llm_request))
|
||||
# print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
|
||||
# for i, resp in enumerate(result.response, 1):
|
||||
# print(f"响应{i}: {resp.response_content}")
|
||||
|
||||
# 多次调用示例
|
||||
params = LLMParameters(temperature=0.7, max_tokens=100)
|
||||
result = asyncio.run(call_openai_api(llm_request, 3,params))
|
||||
print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
|
||||
print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
|
||||
for i, resp in enumerate(result.response, 1):
|
||||
print(f"响应{i}: {resp.content}")
|
1
train/__init__.py
Normal file
1
train/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .model import *
|
134
train/model.py
Normal file
134
train/model.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
from datasets import Dataset as HFDataset
|
||||
from unsloth import FastLanguageModel
|
||||
from trl import SFTTrainer # 用于监督微调的训练器
|
||||
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
|
||||
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
|
||||
from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
||||
def get_model_name(model):
|
||||
return os.path.basename(model.name_or_path)
|
||||
def formatting_prompts(examples, tokenizer):
|
||||
"""格式化对话数据的函数
|
||||
Args:
|
||||
examples: 包含 question 与 answer 列表的字典
tokenizer: 分词器, 用于通过 apply_chat_template 拼接对话文本
|
||||
Returns:
|
||||
包含格式化文本的字典
|
||||
"""
|
||||
questions = examples["question"]
|
||||
answer = examples["answer"]
|
||||
|
||||
convos = [
|
||||
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
|
||||
for q, r in zip(questions, answer)
|
||||
]
|
||||
|
||||
# 使用tokenizer.apply_chat_template格式化对话
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
|
||||
for convo in convos
|
||||
]
|
||||
|
||||
return {"text": texts}
|
||||
|
||||
|
||||
def train_model(
|
||||
model,
|
||||
tokenizer,
|
||||
dataset: list,
|
||||
train_dir: str,
|
||||
learning_rate: float,
|
||||
per_device_train_batch_size: int,
|
||||
epoch: int,
|
||||
save_steps: int,
|
||||
lora_rank: int,
|
||||
trainer_callback=None
|
||||
) -> None:
|
||||
# 模型配置参数
|
||||
dtype = None # 数据类型,None表示自动选择
|
||||
load_in_4bit = False # 使用4bit量化加载模型以节省显存
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
# 原始模型
|
||||
model,
|
||||
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
|
||||
# 建议: 8-32之间
|
||||
r=lora_rank, # 使用动态传入的LoRA秩
|
||||
# 需要应用LoRA的目标模块列表
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
|
||||
"gate_proj", "up_proj", "down_proj", # FFN相关层
|
||||
],
|
||||
# LoRA缩放因子,用于控制LoRA更新的幅度。值越大,LoRA的更新影响越大。
|
||||
lora_alpha=16,
|
||||
# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。
|
||||
# 如果数据集较小,建议设置0.1左右。
|
||||
lora_dropout=0,
|
||||
# 是否对bias参数进行微调,none表示不微调bias
|
||||
# none: 不微调偏置参数;
|
||||
# all: 微调所有参数;
|
||||
# lora_only: 只微调LoRA参数。
|
||||
bias="none",
|
||||
# 是否使用梯度检查点技术节省显存,使用unsloth优化版本
|
||||
# 会略微降低训练速度,但可以显著减少显存使用
|
||||
use_gradient_checkpointing="unsloth",
|
||||
# 随机数种子,用于结果复现
|
||||
random_state=3407,
|
||||
# 是否使用rank-stabilized LoRA,这里不使用
|
||||
# 会略微降低训练速度,但可以显著减少显存使用
|
||||
use_rslora=False,
|
||||
# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小
|
||||
loftq_config=None,
|
||||
)
|
||||
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="qwen-2.5",
|
||||
)
|
||||
|
||||
train_dataset = HFDataset.from_list(dataset)
|
||||
train_dataset = train_dataset.map(formatting_prompts,
|
||||
fn_kwargs={"tokenizer": tokenizer},
|
||||
batched=True)
|
||||
|
||||
# 初始化SFT训练器
|
||||
trainer = SFTTrainer(
|
||||
model=model, # 待训练的模型
|
||||
tokenizer=tokenizer, # 分词器
|
||||
train_dataset=train_dataset, # 训练数据集
|
||||
dataset_text_field="text", # 数据集字段的名称
|
||||
max_seq_length=model.max_seq_length, # 最大序列长度
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
dataset_num_proc=1, # 数据集处理的并行进程数
|
||||
packing=False,
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=per_device_train_batch_size, # 每个GPU的训练批次大小
|
||||
gradient_accumulation_steps=4, # 梯度累积步数,用于模拟更大的batch size
|
||||
warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
|
||||
learning_rate=learning_rate, # 学习率
|
||||
lr_scheduler_type="linear", # 线性学习率调度器
|
||||
max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据)
|
||||
fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
|
||||
bf16=is_bfloat16_supported(), # 如果支持则使用bf16
|
||||
logging_steps=1, # 每1步记录一次日志
|
||||
optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果
|
||||
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
|
||||
seed=114514, # 随机数种子
|
||||
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
|
||||
save_strategy="steps", # 按步保存中间权重
|
||||
save_steps=save_steps, # 使用动态传入的保存步数
|
||||
logging_dir=train_dir + "/logs", # 日志文件存储路径
|
||||
report_to="tensorboard", # 使用TensorBoard记录日志
|
||||
),
|
||||
)
|
||||
|
||||
if trainer_callback is not None:
|
||||
trainer.add_callback(trainer_callback)
|
||||
|
||||
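# 仅对 assistant 回复部分计算损失, 屏蔽 user 输入对应的标签, 避免模型学习复述问题本身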
trainer = train_on_responses_only(
|
||||
trainer,
|
||||
instruction_part = "<|im_start|>user\n",
|
||||
response_part = "<|im_start|>assistant\n",
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
trainer_stats = trainer.train(resume_from_checkpoint=False)
|