Compare commits

54 Commits (1a2ca3e244 ... improve)

Commit SHA1s:

7a4388c928, 6338706967, 3718c75cee, 905658073a, 9806334517, 0a4efa5641,
994d600221, d5774eee0c, 87501c9353, 5fc3b4950b, c28e4819d9, e7cf51d662,
4c9caff668, 9236f49b36, 868fcd45ba, 5a21c8598a, 1e829c9268, 9fc3ab904b,
d827f9758f, ff1e9731bc, 90fde639ff, 5fc90903fb, 81c2ad4a2d, 314434951d,
e16882953d, 86bcf90c66, 961a017f19, 5a386d6401, feaea1fb64, 7242a2ce03,
db6e2271dc, d764537143, 8c35a38c47, 7ee751c88f, b715b36a5f, 8023233bb2,
2a86b3b5b0, ca1505304e, df9260e918, df9aba0c6e, 6b87dcb58f, d0aebd17fa,
d9abf08184, a27a1ab079, aa758e3c2a, 664944f0c5, 9298438f98, 4f7926aec6,
148f4afb25, 11a3039775, a4289815ba, 088067d335, 9fb31c46c8, 4f09823123
README.md

````diff
@@ -1,15 +1,91 @@
 # 基于文档驱动的自适应编码大模型微调框架
 
 ## 简介
-本人的毕业设计
 ### 项目概述
+本项目是一个基于文档驱动的自适应编码大模型微调框架,通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
 
-* 通过深度解析私有库的文档以及其他资源,生成指令型语料,据此对大语言模型进行针对私有库的微调。
+### 核心功能
+- 文档解析与语料生成
+- 大语言模型高效微调
+- 交互式训练与推理界面
+- 训练过程可视化监控
 
-### 项目技术
+## 技术架构
 
-* 使用unsloth框架在GPU上实现大语言模型的qlora微调
-* 使用langchain框架编写工作流实现批量生成微调语料
-* 使用tinydb和sqlite实现数据的持久化
-* 使用gradio框架实现前端展示
+### 系统架构
+```
+[前端界面] -> [模型微调] -> [数据存储]
+     ↑             ↑            ↑
+     │             │            │
+ [Gradio]   [unsloth/QLoRA]  [SQLite/TinyDB]
+```
 
-**施工中......**
+### 技术栈
+- **前端界面**: Gradio构建的交互式Web界面
+- **模型微调**: 基于unsloth框架的QLoRA高效微调
+- **数据存储**: SQLite(结构化数据) + TinyDB(非结构化数据)
+- **工作流引擎**: LangChain实现文档解析与语料生成
+- **训练监控**: TensorBoard集成
+
+## 功能模块
+
+### 1. 模型管理
+- 支持多种格式的大语言模型加载
+- 模型信息查看与状态管理
+- 模型卸载与内存释放
+
+### 2. 模型推理
+- 对话式交互界面
+- 流式响应输出
+- 历史对话管理
+
+### 3. 模型微调
+- 训练参数配置(学习率、batch size等)
+- LoRA参数配置(秩、alpha等)
+- 训练过程实时监控
+- 训练中断与恢复
+
+### 4. 数据集生成
+- 文档解析与清洗
+- 指令-响应对生成
+- 数据集质量评估
+- 数据集版本管理
+
+## 微调技术
+
+### QLoRA原理
+QLoRA(Quantized Low-Rank Adaptation)是一种高效的大模型微调技术,核心特点包括:
+1. **4-bit量化**: 将预训练模型量化为4-bit表示,大幅减少显存占用
+2. **低秩适配**: 通过低秩矩阵分解(LoRA)实现参数高效更新
+3. **内存优化**: 使用梯度检查点等技术进一步降低显存需求
+
+### 参数配置
+- **学习率**: 建议2e-5到2e-4
+- **LoRA秩**: 控制适配器复杂度(建议16-64)
+- **LoRA Alpha**: 控制适配器更新幅度(通常设为秩的1-2倍)
+
+### 训练监控
+- **TensorBoard集成**: 实时查看损失曲线、学习率等指标
+- **日志记录**: 训练过程详细日志保存
+- **模型检查点**: 定期保存中间权重
+
+## 快速开始
+
+1. 安装依赖:
+```bash
+pip install -r requirements.txt
+```
+
+2. 启动应用:
+```bash
+python main.py
+```
+
+3. 访问Web界面:
+```
+http://localhost:7860
+```
+
+## 许可证
+MIT License
````
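The 参数配置 guidance above (learning rate 2e-5 to 2e-4, LoRA rank 16-64, alpha at 1-2x the rank) maps onto an unsloth QLoRA setup roughly as follows. This is a sketch for orientation only; the model name and the concrete numbers are illustrative assumptions, not values taken from the repository:

```python
from unsloth import FastLanguageModel

# Illustrative base model; the framework itself loads whatever model is selected in the web UI.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-7B-Instruct",
    max_seq_length=2048,
    load_in_4bit=True,                        # QLoRA's 4-bit quantization
)

model = FastLanguageModel.get_peft_model(
    model,
    r=32,                                     # LoRA rank, inside the suggested 16-64 range
    lora_alpha=64,                            # 1-2x the rank, per the guidance above
    lora_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",     # the memory-optimization point
)
# A learning rate near the upper end of the suggested range (around 2e-4) is then passed to the trainer.
```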
```diff
@@ -1,11 +1,12 @@
-from .init_db import get_sqlite_engine, initialize_sqlite_db
+from .init_db import load_sqlite_engine, initialize_sqlite_db
 from .prompt_store import get_prompt_tinydb, initialize_prompt_store
-from .dataset_store import get_all_dataset
+from .dataset_store import get_all_dataset, save_dataset
 
 __all__ = [
-    "get_sqlite_engine",
+    "load_sqlite_engine",
     "initialize_sqlite_db",
     "get_prompt_tinydb",
    "initialize_prompt_store",
-    "get_all_dataset"
+    "get_all_dataset",
+    "save_dataset"
 ]
```
```diff
@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
 
 # 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 
 def get_all_dataset(workdir: str) -> TinyDB:
     """
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
     return db
 
 
+def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
+    """
+    将TinyDB中的数据集保存为单独的json文件
+
+    Args:
+        db (TinyDB): 包含数据集对象的TinyDB实例
+        workdir (str): 工作目录路径
+        name (str, optional): 要保存的数据集名称,None表示保存所有
+    """
+    dataset_dir = os.path.join(workdir, "dataset")
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    datasets = db.all() if name is None else db.search(Query().name == name)
+
+    for dataset in datasets:
+        try:
+            filename = f"{dataset.get(dataset['name'])}.json"
+            filepath = os.path.join(dataset_dir, filename)
+
+            with open(filepath, "w", encoding="utf-8") as f:
+                json.dump(dataset, f, ensure_ascii=False, indent=2)
+        except Exception as e:
+            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
+
 if __name__ == "__main__":
     # 定义工作目录路径
     workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -48,3 +72,10 @@ if __name__ == "__main__":
     print(f"Found {len(datasets)} datasets:")
     for ds in datasets.all():
         print(f"- {ds['name']} (ID: {ds['id']})")
+
+    # 询问要保存的数据集名称
+    name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
+
+    # 保存数据集到文件
+    save_dataset(datasets, workdir, name)
+    print(f"Datasets {'all' if name is None else name} saved to json files")
```
```diff
@@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
 # 全局变量,用于存储数据库引擎实例
 _engine: Optional[Engine] = None
 
-def get_sqlite_engine(workdir: str) -> Engine:
+def load_sqlite_engine(workdir: str) -> Engine:
     """
     获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
 
@@ -74,6 +74,6 @@ if __name__ == "__main__":
     # 定义工作目录路径
     workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
     # 获取数据库引擎
-    engine = get_sqlite_engine(workdir)
+    engine = load_sqlite_engine(workdir)
     # 初始化数据库
     initialize_sqlite_db(engine)
```
```diff
@@ -44,12 +44,12 @@ def initialize_prompt_store(db: TinyDB) -> None:
     # 检查数据库是否为空
     if not db.all():  # 如果数据库中没有数据
         db.insert(promptTempleta(
-            id=0,
+            id=1,
             name="default",
             description="默认提示词模板",
             content="""项目名为:{project_name}
 请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
-文档节选:{ content }""").model_dump())
+文档节选:{document_slice}""").model_dump())
     # 如果数据库中已有数据,则跳过插入
 
 
```
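The placeholder rename from { content } to {document_slice} matters because LangChain's PromptTemplate derives its input variables from the braces in the template text, and both the new promptTempleta validator and the dataset-generation page look specifically for document_slice. A minimal sketch of that behaviour, using an abbreviated version of the default template above:

```python
from langchain.prompts import PromptTemplate

# Abbreviated form of the default template inserted above.
template_text = "项目名为:{project_name}\n文档节选:{document_slice}"

prompt = PromptTemplate.from_template(template_text)
print(sorted(prompt.input_variables))   # ['document_slice', 'project_name']

# Everything except document_slice is filled in first (cf. prompt.partial(**variables_dict)
# in the dataset-generation page); document_slice is then formatted once per document chunk.
partial = prompt.partial(project_name="my-private-lib")   # illustrative value
print(partial.format(document_slice="这里是一段文档节选"))
```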
```diff
@@ -1,42 +1,189 @@
 import gradio as gr
-from tools import scan_docs_directory
-from global_var import get_docs, scan_docs_directory, get_prompt_store
+import sys
+import json
+from tinydb import Query
+from pathlib import Path
+from langchain.prompts import PromptTemplate
+from sqlmodel import Session, select
+from schema import Dataset, DatasetItem, Q_A
+from db.dataset_store import get_all_dataset
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
+from db import save_dataset
+from tools import call_openai_api, process_markdown_file, generate_json_example
+from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
 
 def dataset_generate_page():
     with gr.Blocks() as demo:
         gr.Markdown("## 数据集生成")
         with gr.Row():
-            with gr.Column():
-                # 获取文档列表并设置初始值
+            with gr.Column(scale=1):
                 docs_list = [str(doc.name) for doc in get_docs()]
                 initial_doc = docs_list[0] if docs_list else None
+                prompts = get_prompt_store().all()
+                prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
+                initial_prompt = prompt_list[0] if prompt_list else None
+
+                # 初始化Dataframe的值
+                initial_dataframe_value = []
+                if initial_prompt:
+                    selected_prompt_id = int(initial_prompt.split(" ")[0])
+                    prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
+                    prompt_content = prompt_data["content"]
+                    prompt_template = PromptTemplate.from_template(prompt_content)
+                    input_variables = prompt_template.input_variables
+                    input_variables.remove("document_slice")
+                    initial_dataframe_value = [[var, ""] for var in input_variables]
+                # 从数据库获取API Provider列表
+                with Session(get_sql_engine()) as session:
+                    providers = session.exec(select(APIProvider)).all()
+                api_list = [f"{p.id} {p.model_id}" for p in providers]
+                initial_api = api_list[0] if api_list else None
+
+                api_dropdown = gr.Dropdown(
+                    choices=api_list,
+                    value=initial_api,
+                    label="选择API",
+                    interactive=True
+                )
                 doc_dropdown = gr.Dropdown(
                     choices=docs_list,
-                    value=initial_doc, # 设置初始选中项
+                    value=initial_doc,
                     label="选择文档",
-                    allow_custom_value=True,
                     interactive=True
                 )
-                doc_state = gr.State(value=initial_doc) # 用文档初始值初始化状态
 
-            with gr.Column():
-                # 获取模板列表并设置初始值
-                prompts = get_prompt_store().all()
-                prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
-                initial_prompt = prompt_choices[0] if prompt_choices else None
-
                 prompt_dropdown = gr.Dropdown(
-                    choices=prompt_choices,
-                    value=initial_prompt, # 设置初始选中项
+                    choices=prompt_list,
+                    value=initial_prompt,
                     label="选择模板",
-                    allow_custom_value=True,
                     interactive=True
                 )
-                prompt_state = gr.State(value=initial_prompt) # 用模板初始值初始化状态
+                rounds_input = gr.Number(
+                    value=1,
+                    label="生成轮次",
+                    minimum=1,
+                    maximum=100,
+                    step=1,
+                    interactive=True
+                )
+                concurrency_input = gr.Number(
+                    value=1,
+                    label="并发数",
+                    minimum=1,
+                    maximum=10,
+                    step=1,
+                    interactive=True,
+                    visible=False
+                )
+                dataset_name_input = gr.Textbox(
+                    label="数据集名称",
+                    placeholder="输入数据集保存名称",
+                    interactive=True
+                )
+                prompt_choice = gr.State(value=initial_prompt)
+                generate_button = gr.Button("生成数据集",variant="primary")
+                doc_choice = gr.State(value=initial_doc)
+                output_text = gr.Textbox(label="生成结果", interactive=False)
+                api_choice = gr.State(value=initial_api)
+            with gr.Column(scale=2):
+                variables_dataframe = gr.Dataframe(
+                    headers=["变量名", "变量值"],
+                    datatype=["str", "str"],
+                    interactive=True,
+                    label="变量列表",
+                    value=initial_dataframe_value # 设置初始化数据
+                )
+        def on_doc_change(selected_doc):
+            return selected_doc
 
-        # 绑定事件(保留原有逻辑,确保交互时更新)
-        doc_dropdown.change(lambda x: x, inputs=doc_dropdown, outputs=doc_state)
-        prompt_dropdown.change(lambda x: x, inputs=prompt_dropdown, outputs=prompt_state)
+        def on_api_change(selected_api):
+            return selected_api
+
+        def on_prompt_change(selected_prompt):
+            if not selected_prompt:
+                return None, []
+            selected_prompt_id = int(selected_prompt.split(" ")[0])
+            prompt_data = get_prompt_store().get(doc_id=selected_prompt_id)
+            prompt_content = prompt_data["content"]
+            prompt_template = PromptTemplate.from_template(prompt_content)
+            input_variables = prompt_template.input_variables
+            input_variables.remove("document_slice")
+            dataframe_value = [] if input_variables is None else input_variables
+            dataframe_value = [[var, ""] for var in input_variables]
+            return selected_prompt, dataframe_value
+
+        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
+            dataset_db = get_datasets()
+            if not dataset_db.search(Query().name == dataset_name):
+                raise gr.Error("数据集名称已存在")
+
+            doc = [i for i in get_docs() if i.name == doc_state][0]
+            doc_files = doc.markdown_files
+            document_slice_list = [process_markdown_file(doc) for doc in doc_files]
+            prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
+            prompt = PromptTemplate.from_template(prompt["content"])
+            with Session(get_sql_engine()) as session:
+                api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
+
+            variables_dict = {}
+            for _, row in variables_dataframe.iterrows():
+                var_name = row['变量名'].strip()
+                var_value = row['变量值'].strip()
+                if var_name:
+                    variables_dict[var_name] = var_value
+
+            prompt = prompt.partial(**variables_dict)
+
+            dataset = Dataset(
+                name=dataset_name,
+                model_id=[api_provider.model_id],
+                source_doc=doc,
+                dataset_items=[]
+            )
+
+            total_slices = len(document_slice_list)
+            for i, document_slice in enumerate(document_slice_list):
+                progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
+                request = LLMRequest(api_provider=api_provider,
+                                     prompt=prompt.format(document_slice=document_slice),
+                                     format=generate_json_example(DatasetItem))
+                call_openai_api(request, rounds)
+
+                for resp in request.response:
+                    try:
+                        content = json.loads(resp.content)
+                        dataset_item = DatasetItem(
+                            message=[Q_A(
+                                question=content.get("question", ""),
+                                answer=content.get("answer", "")
+                            )]
+                        )
+                        dataset.dataset_items.append(dataset_item)
+                    except json.JSONDecodeError as e:
+                        print(f"Failed to parse response: {e}")
+
+            # 保存数据集到TinyDB
+            dataset_db.insert(dataset.model_dump())
+
+            save_dataset(dataset_db,get_workdir(),dataset_name)
+
+            return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
+
+        doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
+        prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
+        api_dropdown.change(on_api_change, inputs=api_dropdown, outputs=api_choice)
+
+        generate_button.click(
+            on_generate_click,
+            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
+            outputs=output_text
+        )
+
     return demo
+
+if __name__ == "__main__":
+    from global_var import init_global_var
+    init_global_var("workdir")
+    demo = dataset_generate_page()
+    demo.launch()
```
```diff
@@ -7,7 +7,7 @@ import torch
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from global_var import get_model, get_tokenizer, set_model, set_tokenizer
-from tools.model import get_model_name
+from train import get_model_name
 
 def model_manage_page():
     workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
@@ -30,6 +30,11 @@ def model_manage_page():
         with gr.Row():
             with gr.Column(scale=3):
                 save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
+                save_method_dropdown = gr.Dropdown(
+                    choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
+                    label="保存模式",
+                    value="default"
+                )
             with gr.Column(scale=1):
                 save_button = gr.Button("保存模型", variant="secondary")
 
@@ -73,21 +78,12 @@ def model_manage_page():
 
         unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
 
-        def save_model(save_model_name):
-            try:
-                global model, tokenizer
-                if model is None:
-                    return "没有加载的模型可保存"
-
-                save_path = os.path.join(models_dir, save_model_name)
-                os.makedirs(save_path, exist_ok=True)
-                model.save_pretrained(save_path)
-                tokenizer.save_pretrained(save_path)
-                return f"模型已保存到 {save_path}"
-            except Exception as e:
-                return f"保存模型时出错: {str(e)}"
+        from train.save_model import save_model_to_dir
+
+        def save_model(save_model_name, save_method):
+            return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
 
-        save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
+        save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
 
         def refresh_model_list():
             try:
```
```diff
@@ -61,9 +61,11 @@ def prompt_manage_page():
 
     selected_row = None  # 保存当前选中行的全局变量
 
-    def select_record(evt: gr.SelectData):
+    def select_record(dataFrame ,evt: gr.SelectData):
         global selected_row
-        selected_row = evt.row_value
+        selected_row = dataFrame.iloc[evt.index[0]].tolist()
+        selected_row[0] = int(selected_row[0])
+        print(selected_row)
 
     with gr.Blocks() as demo:
         gr.Markdown("## 提示词模板管理")
@@ -102,7 +104,10 @@ def prompt_manage_page():
             outputs=[prompt_table, name_input, description_input, content_input]
         )
 
-        prompt_table.select(select_record, [], [], show_progress="hidden")
+        prompt_table.select(fn=select_record,
+                            inputs=[prompt_table],
+                            outputs=[],
+                            show_progress="hidden")
 
         edit_button.click(
             fn=edit_prompt,
@@ -68,9 +68,11 @@ def setting_page():
 
     selected_row = None  # 保存当前选中行的全局变量
 
-    def select_record(evt: gr.SelectData):
+    def select_record(dataFrame ,evt: gr.SelectData):
         global selected_row
-        selected_row = evt.row_value
+        selected_row = dataFrame.iloc[evt.index[0]].tolist()
+        selected_row[0] = int(selected_row[0])
+        print(selected_row)
 
     with gr.Blocks() as demo:
         gr.Markdown("## API Provider 管理")
@@ -109,7 +111,10 @@ def setting_page():
             outputs=[provider_table, model_id_input, base_url_input, api_key_input]  # 添加清空输入框的输出
         )
 
-        provider_table.select(select_record, [], [], show_progress="hidden")
+        provider_table.select(fn=select_record,
+                              inputs=[provider_table],
+                              outputs=[],
+                              show_progress="hidden")
 
         edit_button.click(
             fn=edit_provider,
```
```diff
@@ -1,12 +1,15 @@
+import subprocess
+import os
 import gradio as gr
 import sys
-import torch
 from tinydb import Query
 from pathlib import Path
+from transformers import TrainerCallback
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from global_var import get_model, get_tokenizer, get_datasets, get_workdir
-from tools import train_model
+from tools import find_available_port
+from train import train_model
 
 def train_page():
     with gr.Blocks() as demo:
@@ -14,7 +17,8 @@ def train_page():
         # 获取数据集列表并设置初始值
         datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
         initial_dataset = datasets_list[0] if datasets_list else None
+        with gr.Row():
+            with gr.Column(scale=1):
                 dataset_dropdown = gr.Dropdown(
                     choices=datasets_list,
                     value=initial_dataset, # 设置初始选中项
@@ -22,7 +26,6 @@ def train_page():
                     allow_custom_value=True,
                     interactive=True
                 )
-
                 # 新增超参数输入组件
                 learning_rate_input = gr.Number(value=2e-4, label="学习率")
                 per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
@@ -33,7 +36,10 @@ def train_page():
                 train_button = gr.Button("开始微调")
 
                 # 训练状态输出
-                output = gr.Textbox(label="训练日志", interactive=False)
+                output = gr.Textbox(label="训练状态", interactive=False)
+            with gr.Column(scale=3):
+                # 新增 TensorBoard iframe 显示框
+                tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
 
         def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            # 使用动态传入的超参数
@@ -42,11 +48,46 @@ def train_page():
             epoch = int(epoch)
             save_steps = int(save_steps)  # 新增保存步数参数
             lora_rank = int(lora_rank)  # 新增LoRA秩参数
 
             # 加载数据集
             dataset = get_datasets().get(Query().name == dataset_name)
             dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
-            train_model(get_model(), get_tokenizer(), dataset, get_workdir(), learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank)
+
+            # 扫描 training 文件夹并生成递增目录
+            training_dir = get_workdir() + "/training"
+            os.makedirs(training_dir, exist_ok=True)  # 确保 training 文件夹存在
+            existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
+            next_dir_number = max([int(d) for d in existing_dirs], default=0) + 1
+            new_training_dir = os.path.join(training_dir, str(next_dir_number))
+
+            tensorboard_port = find_available_port(6006)  # 从默认端口 6006 开始检测
+            print(f"TensorBoard 将使用端口: {tensorboard_port}")
+
+            tensorboard_logdir = os.path.join(new_training_dir, "logs")
+            os.makedirs(tensorboard_logdir, exist_ok=True)  # 确保日志目录存在
+            tensorboard_process = subprocess.Popen(
+                ["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            )
+            print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
+
+            # 动态生成 TensorBoard iframe
+            tensorboard_url = f"http://localhost:{tensorboard_port}"
+            tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
+            yield "训练开始...", tensorboard_iframe_value  # 返回两个值,分别对应 textbox 和 html
+
+            try:
+                train_model(get_model(), get_tokenizer(),
+                            dataset, new_training_dir,
+                            learning_rate, per_device_train_batch_size, epoch,
+                            save_steps, lora_rank)
+            except Exception as e:
+                raise gr.Error(str(e))
+            finally:
+                # 确保训练结束后终止 TensorBoard 子进程
+                tensorboard_process.terminate()
+                print("TensorBoard 子进程已终止")
 
         train_button.click(
             fn=start_training,
@@ -56,9 +97,9 @@ def train_page():
                 per_device_train_batch_size_input,
                 epoch_input,
                 save_steps_input,
-                lora_rank_input  # 新增lora_rank_input
+                lora_rank_input
             ],
-            outputs=output
+            outputs=[output, tensorboard_iframe]  # 更新输出以包含 iframe
         )
 
     return demo
```
```diff
@@ -1,18 +1,16 @@
-from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
+from db import load_sqlite_engine, get_prompt_tinydb, get_all_dataset
 from tools import scan_docs_directory
 
 _prompt_store = None
 _sql_engine = None
-_docs = None
 _datasets = None
 _model = None
 _tokenizer = None
 _workdir = None
 def init_global_var(workdir="workdir"):
-    global _prompt_store, _sql_engine, _docs, _datasets, _workdir
+    global _prompt_store, _sql_engine, _datasets, _workdir
     _prompt_store = get_prompt_tinydb(workdir)
-    _sql_engine = get_sqlite_engine(workdir)
-    _docs = scan_docs_directory(workdir)
+    _sql_engine = load_sqlite_engine(workdir)
     _datasets = get_all_dataset(workdir)
     _workdir = workdir
 
@@ -34,19 +32,11 @@ def set_sql_engine(new_sql_engine):
     _sql_engine = new_sql_engine
 
 def get_docs():
-    return _docs
+    global _workdir
+    return scan_docs_directory(_workdir)
 
-def set_docs(new_docs):
-    global _docs
-    _docs = new_docs
-
 def get_datasets():
     return _datasets
 
-def set_datasets(new_datasets):
-    global _datasets
-    _datasets = new_datasets
-
 def get_model():
     return _model
 
```
main.py

```diff
@@ -1,5 +1,5 @@
 import gradio as gr
-from frontend.setting_page import setting_page
+import train
 from frontend import *
 from db import initialize_sqlite_db, initialize_prompt_store
 from global_var import init_global_var, get_sql_engine, get_prompt_store
@@ -26,4 +26,4 @@ if __name__ == "__main__":
         with gr.TabItem("设置"):
             setting_page()
 
-    app.launch()
+    app.launch(server_name="0.0.0.0")
```
```diff
@@ -1,9 +1,11 @@
 openai>=1.0.0
 python-dotenv>=1.0.0
 pydantic>=2.0.0
-gradio>=5.0.0
+gradio>=5.25.0
 langchain>=0.3
 tinydb>=4.0.0
 unsloth>=2025.3.19
 sqlmodel>=0.0.24
 jinja2>=3.1.0
+tensorboardX>=2.6.2.2
+tensorboard>=2.19.0
```
```diff
@@ -1,4 +1,4 @@
 from .dataset import *
-from .dataset_generation import APIProvider, LLMResponse, LLMRequest
+from .dataset_generation import *
 from .md_doc import MarkdownNode
 from .prompt import promptTempleta
@@ -2,7 +2,7 @@ from typing import Optional
 from pydantic import BaseModel, Field
 from datetime import datetime, timezone
 
-class doc(BaseModel):
+class Doc(BaseModel):
     id: Optional[int] = Field(default=None, description="文档ID")
     name: str = Field(default="", description="文档名称")
     path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
     question: str = Field(default="", min_length=1,description="问题")
     answer: str = Field(default="", min_length=1, description="答案")
 
-class dataset_item(BaseModel):
+class DatasetItem(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集项ID")
     message: list[Q_A] = Field(description="数据集项内容")
 
-class dataset(BaseModel):
+class Dataset(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集ID")
     name: str = Field(default="", description="数据集名称")
     model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
-    source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
+    source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
     description: Optional[str] = Field(default="", description="数据集描述")
     created_at: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="记录创建时间"
     )
-    dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")
+    dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
```
```diff
@@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True):
         description="记录创建时间"
     )
 
+class LLMParameters(SQLModel):
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    seed: Optional[int] = None
+
+class TokensUsage(SQLModel):
+    prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
+    completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
+    prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
+    prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
+
 class LLMResponse(SQLModel):
     timestamp: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="响应的时间戳"
     )
     response_id: str = Field(..., description="响应的唯一ID")
-    tokens_usage: dict = Field(default_factory=lambda: {
-        "prompt_tokens": 0,
-        "completion_tokens": 0,
-        "prompt_cache_hit_tokens": None,
-        "prompt_cache_miss_tokens": None
-    }, description="token使用信息")
-    response_content: dict = Field(default_factory=dict, description="API响应的内容")
+    tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
+    content: str = Field(default_factory=dict, description="API响应的内容")
     total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
-    llm_parameters: dict = Field(default_factory=lambda: {
-        "temperature": None,
-        "max_tokens": None,
-        "top_p": None,
-        "frequency_penalty": None,
-        "presence_penalty": None,
-        "seed": None
-    }, description="API的生成参数")
+    llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
 
 class LLMRequest(SQLModel):
     prompt: str = Field(..., description="发送给API的提示词")
-    provider_id: int = Field(foreign_key="apiprovider.id")
-    provider: APIProvider = Relationship()
+    api_provider: APIProvider = Field(..., description="API提供者的信息")
     format: Optional[str] = Field(default=None, description="API响应的格式")
     response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
     error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
     total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
-    total_tokens_usage: dict = Field(default_factory=lambda: {
-        "prompt_tokens": 0,
-        "completion_tokens": 0,
-        "prompt_cache_hit_tokens": None,
-        "prompt_cache_miss_tokens": None
-    }, description="token使用信息")
+    total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
```
```diff
@@ -1,6 +1,7 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing import Optional
 from datetime import datetime, timezone
+from langchain.prompts import PromptTemplate
 
 class promptTempleta(BaseModel):
     id: Optional[int] = Field(default=None, description="模板ID")
@@ -11,3 +12,9 @@ class promptTempleta(BaseModel):
         default_factory=lambda: datetime.now(timezone.utc).isoformat(),
         description="记录创建时间"
     )
+
+    @field_validator('content')
+    def validate_content(cls, value):
+        if not "document_slice" in PromptTemplate.from_template(value).input_variables:
+            raise ValueError("模板变量缺少 document_slice")
+        return value
```
```diff
@@ -1,4 +1,5 @@
-from .parse_markdown import parse_markdown
-from .scan_doc_dir import *
-from .json_example import generate_example_json
-from .model import *
+from .parse_markdown import *
+from .document import *
+from .json_example import generate_json_example
+from .port import *
+from .reasoning import call_openai_api
@@ -1,19 +1,19 @@
 from typing import List
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 import json
 
-def convert_json_to_dataset(json_data: List[dict]) -> dataset:
+def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
     # 将JSON数据转换为dataset格式
     dataset_items = []
     item_id = 1  # 自增ID计数器
     for item in json_data:
         qa = Q_A(question=item["question"], answer=item["answer"])
-        dataset_item_obj = dataset_item(id=item_id, message=[qa])
+        dataset_item_obj = DatasetItem(id=item_id, message=[qa])
         dataset_items.append(dataset_item_obj)
         item_id += 1  # ID自增
 
     # 创建dataset对象
-    result_dataset = dataset(
+    result_dataset = Dataset(
         name="Converted Dataset",
         model_id=None,
         description="Dataset converted from JSON",
@@ -4,7 +4,7 @@ from pathlib import Path
 
 # 添加项目根目录到sys.path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import doc
+from schema import Doc
 
 def scan_docs_directory(workdir: str):
     docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
             for file in files:
                 if file.endswith(".md"):
                     markdown_files.append(os.path.join(root, file))
-        to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
+        to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
 
     return to_return
```
```diff
@@ -1,32 +1,29 @@
 from pydantic import BaseModel, create_model
-from typing import Any, Dict, List, Optional, get_args, get_origin
+from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 import json
 from datetime import datetime, date
 
-def generate_example_json(model: type[BaseModel]) -> str:
+def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
     """
     根据 Pydantic V2 模型生成示例 JSON 数据结构。
     """
 
     def _generate_example(field_type: Any) -> Any:
         origin = get_origin(field_type)
         args = get_args(field_type)
 
         if origin is list or origin is List:
-            if args:
-                return [_generate_example(args[0])]
-            else:
-                return []
+            # 生成多个元素,这里生成 3 个
+            result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
+            result.append("......")
+            return result
         elif origin is dict or origin is Dict:
-            if len(args) == 2 and args[0] is str:
+            if len(args) == 2:
                 return {"key": _generate_example(args[1])}
-            else:
             return {}
-        elif origin is Optional or origin is type(None):
-            if args:
-                return _generate_example(args[0])
-            else:
-                return None
+        elif origin is Union:
+            # 处理 Optional 类型(Union[T, None])
+            non_none_args = [arg for arg in args if arg is not type(None)]
+            return _generate_example(non_none_args[0]) if non_none_args else None
         elif field_type is str:
             return "string"
         elif field_type is int:
@@ -39,13 +36,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
             return datetime.now().isoformat()
         elif field_type is date:
             return date.today().isoformat()
-        elif issubclass(field_type, BaseModel):
-            return generate_example_json(field_type)
+        elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
+            return json.loads(generate_json_example(field_type, include_optional))
         else:
-            return "unknown" # 对于未知类型返回 "unknown"
+            # 处理直接类型注解(非泛型)
+            if field_type is type(None):
+                return None
+            try:
+                if issubclass(field_type, BaseModel):
+                    return json.loads(generate_json_example(field_type, include_optional))
+            except TypeError:
+                pass
+            return "unknown"
 
     example_data = {}
     for field_name, field in model.model_fields.items():
+        if include_optional or not isinstance(field.default, type(None)):
             example_data[field_name] = _generate_example(field.annotation)
 
     return json.dumps(example_data, indent=2, default=str)
@@ -55,9 +61,7 @@ if __name__ == "__main__":
     from pathlib import Path
     # 添加项目根目录到sys.path
     sys.path.append(str(Path(__file__).resolve().parent.parent))
-    from schema import Q_A
-    class Q_A_list(BaseModel):
-        Q_As: List[Q_A]
+    from schema import Dataset
 
     print("示例 JSON:")
-    print(generate_example_json(Q_A_list))
+    print(generate_json_example(Dataset))
```
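For orientation, this is roughly what the renamed generate_json_example returns when the dataset-generation page passes DatasetItem as the expected response format. The output below is reconstructed from the schema and the list_length/"......" logic above, not captured from an actual run:

```python
from schema import DatasetItem
from tools import generate_json_example

print(generate_json_example(DatasetItem))
# Roughly (id is skipped because its default is None and include_optional stays False):
# {
#   "message": [
#     {"question": "string", "answer": "string"},
#     {"question": "string", "answer": "string"},
#     "......"
#   ]
# }
```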
tools/port.py (new file)

```diff
@@ -0,0 +1,15 @@
+import socket
+
+# 启动 TensorBoard 子进程前添加端口检测逻辑
+def find_available_port(start_port):
+    port = start_port
+    while True:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            if s.connect_ex(('localhost', port)) != 0:  # 端口未被占用
+                return port
+            port += 1  # 如果端口被占用,尝试下一个端口
+
+if __name__ == "__main__":
+    start_port = 6006  # 起始端口号
+    available_port = find_available_port(start_port)
+    print(f"Available port: {available_port}")
```
tools/reasoning.py (new file)

```diff
@@ -0,0 +1,123 @@
+import sys
+import asyncio
+import openai
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Optional
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
+
+async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
+    start_time = datetime.now(timezone.utc)
+    client = openai.AsyncOpenAI(
+        api_key=llm_request.api_provider.api_key,
+        base_url=llm_request.api_provider.base_url
+    )
+
+    total_duration = 0.0
+    total_tokens = TokensUsage()
+    prompt = llm_request.prompt
+    round_start = datetime.now(timezone.utc)
+    if llm_request.format:
+        prompt += "\n请以JSON格式返回结果" + llm_request.format
+
+    for i in range(rounds):
+        round_start = datetime.now(timezone.utc)
+        try:
+            messages = [{"role": "user", "content": prompt}]
+            create_args = {
+                "model": llm_request.api_provider.model_id,
+                "messages": messages,
+                "temperature": llm_parameters.temperature if llm_parameters else None,
+                "max_tokens": llm_parameters.max_tokens if llm_parameters else None,
+                "top_p": llm_parameters.top_p if llm_parameters else None,
+                "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
+                "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
+                "seed": llm_parameters.seed if llm_parameters else None
+            }  # 处理format参数
+
+            if llm_request.format:
+                create_args["response_format"] = {"type": "json_object"}
+
+            response = await client.chat.completions.create(**create_args)
+
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            # 处理可能不存在的缓存token字段
+            usage = response.usage
+            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
+            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
+
+            tokens_usage = TokensUsage(
+                prompt_tokens=usage.prompt_tokens,
+                completion_tokens=usage.completion_tokens,
+                prompt_cache_hit_tokens=cache_hit,
+                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
+            )
+
+            # 累加总token使用量
+            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
+            total_tokens.completion_tokens += tokens_usage.completion_tokens
+            if tokens_usage.prompt_cache_hit_tokens:
+                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
+            if tokens_usage.prompt_cache_miss_tokens:
+                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
+
+            llm_request.response.append(LLMResponse(
+                response_id=response.id,
+                tokens_usage=tokens_usage,
+                content = response.choices[0].message.content,
+                total_duration=duration,
+                llm_parameters=llm_parameters
+            ))
+        except Exception as e:
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            llm_request.response.append(LLMResponse(
+                response_id=f"error-round-{i+1}",
+                content={"error": str(e)},
+                total_duration=duration
+            ))
+            if llm_request.error is None:
+                llm_request.error = []
+            llm_request.error.append(str(e))
+
+    # 更新总耗时和总token使用量
+    llm_request.total_duration = total_duration
+    llm_request.total_tokens_usage = total_tokens
+
+    return llm_request
+
+if __name__ == "__main__":
+    from json_example import generate_json_example
+    from sqlmodel import Session, select
+    from global_var import get_sql_engine, init_global_var
+    from schema import DatasetItem
+
+    init_global_var("workdir")
+    api_state = "1 deepseek-chat"
+    with Session(get_sql_engine()) as session:
+        api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
+    llm_request = LLMRequest(
+        prompt="测试,随便说点什么",
+        api_provider=api_provider,
+        format=generate_json_example(DatasetItem)
+    )
+
+    # # 单次调用示例
+    # result = asyncio.run(call_openai_api(llm_request))
+    # print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
+    # for i, resp in enumerate(result.response, 1):
+    #     print(f"响应{i}: {resp.response_content}")
+
+    # 多次调用示例
+    params = LLMParameters(temperature=0.7, max_tokens=100)
+    result = asyncio.run(call_openai_api(llm_request, 3,params))
+    print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
+    print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
+    for i, resp in enumerate(result.response, 1):
+        print(f"响应{i}: {resp.content}")
```
train/__init__.py (new file)

```diff
@@ -0,0 +1 @@
+from .model import *
```
```diff
@@ -31,8 +31,18 @@ def formatting_prompts(examples,tokenizer):
     return {"text": texts}
 
 
-def train_model(model, tokenizer, dataset, output_dir, learning_rate,
-                per_device_train_batch_size, epoch, save_steps, lora_rank):
+def train_model(
+    model,
+    tokenizer,
+    dataset: list,
+    train_dir: str,
+    learning_rate: float,
+    per_device_train_batch_size: int,
+    epoch: int,
+    save_steps: int,
+    lora_rank: int,
+    trainer_callback=None
+) -> None:
     # 模型配置参数
     dtype = None  # 数据类型,None表示自动选择
     load_in_4bit = False  # 使用4bit量化加载模型以节省显存
@@ -75,8 +85,8 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
         chat_template="qwen-2.5",
     )
 
-    dataset = HFDataset.from_list(dataset)
-    dataset = dataset.map(formatting_prompts,
+    train_dataset = HFDataset.from_list(dataset)
+    train_dataset = train_dataset.map(formatting_prompts,
                           fn_kwargs={"tokenizer": tokenizer},
                           batched=True)
 
@@ -84,7 +94,7 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
     trainer = SFTTrainer(
         model=model, # 待训练的模型
         tokenizer=tokenizer, # 分词器
-        train_dataset=dataset, # 训练数据集
+        train_dataset=train_dataset, # 训练数据集
         dataset_text_field="text", # 数据集字段的名称
         max_seq_length=model.max_seq_length, # 最大序列长度
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
@@ -96,20 +106,24 @@ def train_model(model, tokenizer, dataset, output_dir, learning_rate,
             warmup_steps=int(epoch * 0.1), # 预热步数,逐步增加学习率
             learning_rate=learning_rate, # 学习率
             lr_scheduler_type="linear", # 线性学习率调度器
-            max_steps=int(epoch * len(dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据)
+            max_steps=int(epoch * len(train_dataset)/per_device_train_batch_size), # 最大训练步数(一步 = 处理一个batch的数据)
             fp16=not is_bfloat16_supported(), # 如果不支持bf16则使用fp16
             bf16=is_bfloat16_supported(), # 如果支持则使用bf16
             logging_steps=1, # 每1步记录一次日志
             optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果
             weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
             seed=114514, # 随机数种子
-            output_dir=output_dir, # 保存模型检查点和训练日志
+            output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
             save_strategy="steps", # 按步保存中间权重
             save_steps=save_steps, # 使用动态传入的保存步数
-            # report_to="tensorboard", # 将信息输出到tensorboard
+            logging_dir=train_dir + "/logs", # 日志文件存储路径
+            report_to="tensorboard", # 使用TensorBoard记录日志
         ),
     )
 
+    if trainer_callback is not None:
+        trainer.add_callback(trainer_callback)
+
     trainer = train_on_responses_only(
         trainer,
         instruction_part = "<|im_start|>user\n",
```
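train_model now exposes a trainer_callback hook and registers it via trainer.add_callback; the train page imports transformers.TrainerCallback but does not pass one yet. A minimal sketch of a callback that could be plugged in; the class name and the printing behaviour are illustrative, not part of the repository:

```python
from transformers import TrainerCallback

class LossPrinterCallback(TrainerCallback):
    """Illustrative callback: print the loss reported at every logging step."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # train_model sets logging_steps=1, so this fires on every training step.
        if logs and "loss" in logs:
            print(f"step {state.global_step}: loss={logs['loss']:.4f}")

# Hypothetical call with the new signature:
# train_model(model, tokenizer, dataset, train_dir,
#             learning_rate, per_device_train_batch_size, epoch,
#             save_steps, lora_rank, trainer_callback=LossPrinterCallback())
```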
train/save_model.py (new file)

```diff
@@ -0,0 +1,48 @@
+import os
+from global_var import get_model, get_tokenizer
+
+def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
+    """
+    保存模型到指定目录
+    :param save_model_name: 要保存的模型名称
+    :param models_dir: 模型保存的基础目录
+    :param model: 要保存的模型
+    :param tokenizer: 要保存的tokenizer
+    :param save_method: 保存模式选项
+        - "default": 默认保存方式
+        - "merged_16bit": 合并为16位
+        - "merged_4bit": 合并为4位
+        - "lora": 仅LoRA适配器
+        - "gguf": 保存为GGUF格式
+        - "gguf_q4_k_m": 保存为q4_k_m GGUF格式
+        - "gguf_f16": 保存为16位GGUF格式
+    :return: 保存结果消息或错误信息
+    """
+    try:
+        if model is None:
+            return "没有加载的模型可保存"
+
+        save_path = os.path.join(models_dir, save_model_name)
+        os.makedirs(save_path, exist_ok=True)
+
+        if save_method == "default":
+            model.save_pretrained(save_path)
+            tokenizer.save_pretrained(save_path)
+        elif save_method == "merged_16bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
+        elif save_method == "merged_4bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
+        elif save_method == "lora":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
+        elif save_method == "gguf":
+            model.save_pretrained_gguf(save_path, tokenizer)
+        elif save_method == "gguf_q4_k_m":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
+        elif save_method == "gguf_f16":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
+        else:
+            return f"不支持的保存模式: {save_method}"
+
+        return f"模型已保存到 {save_path} (模式: {save_method})"
+    except Exception as e:
+        return f"保存模型时出错: {str(e)}"
```
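The model-manage page reaches this function through its save_model wrapper; a direct call would look roughly like the following, where the save name and the models_dir path are illustrative assumptions rather than paths confirmed by the repository:

```python
from global_var import init_global_var, get_model, get_tokenizer
from train.save_model import save_model_to_dir

init_global_var("workdir")
# Assumes a model has already been loaded into the global state via the model-manage page.
message = save_model_to_dir(
    "my-finetuned-model",       # illustrative save name
    "workdir/models",           # illustrative models_dir
    get_model(),
    get_tokenizer(),
    save_method="gguf_q4_k_m",  # one of the modes documented in the docstring above
)
print(message)
```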