Compare commits

14 Commits

Author SHA1 Message Date
carry
7a4388c928 featmodel): 添加保存模式选择功能
在模型管理页面中新增保存模式选择功能,用户可以通过下拉菜单选择不同的保存模式(如默认、合并16位、合并4位等)。同时,将保存模型的逻辑抽离到独立的`save_model.py`文件中,以提高代码的可维护性和复用性。
2025-04-23 14:09:02 +08:00
carry
6338706967 feat: change how the app is launched
- Change app.launch() to app.launch(server_name="0.0.0.0")
- This makes the app listen on all network interfaces, improving accessibility
2025-04-22 19:08:23 +08:00
carry
3718c75cee feat(frontend): show processing progress on the dataset generation page
- Add a progress bar while processing document slices to improve the user experience
- Tidy code formatting, adjusting indentation and blank lines
2025-04-22 00:14:16 +08:00
carry
905658073a docs(README): update project documentation
- Add a detailed project overview, core features, and technical architecture
- Insert a system architecture diagram and tech-stack notes
- Flesh out the feature-module descriptions: model management, inference, fine-tuning, etc.
- Add notes on how QLoRA works and on parameter configuration
- Add a quick-start guide and license information
- Restructure the document for readability and completeness
2025-04-21 14:28:20 +08:00
carry
9806334517 fix(train_page): catch exceptions during training and terminate the TensorBoard process
- Catch exceptions raised during training and re-raise them as gr.Error
- Ensure the TensorBoard subprocess is terminated even when an exception occurs
2025-04-20 21:40:46 +08:00
carry
0a4efa5641 feat(dataset): add dataset generation
- Add the dataset generation page and its logic
- Check for duplicate dataset names
- Create and save dataset objects
- Improve document processing and prompt-template application
- Add error handling and response parsing
2025-04-20 21:25:51 +08:00
carry
994d600221 refactor(frontend): adjust the TensorBoard iframe height
- Change the TensorBoard iframe height from 500px to 1000px
- Provides a roomier display area and improves the user experience
2025-04-20 21:25:37 +08:00
carry
d5774eee0c feat(db): add dataset export
- Add a save_dataset function that saves datasets from TinyDB as individual JSON files
- Update db/__init__.py to add a reference to the get_dataset_tinydb function
- Implement save_dataset and its supporting logic in db/dataset_store.py
2025-04-20 19:44:11 +08:00
carry
87501c9353 fix(global_var): remove the set_datasets global setter
- Delete the set_datasets function from global_var.py
- It set the global _datasets variable but appears to be unused
2025-04-20 19:14:00 +08:00
carry
5fc3b4950b refactor(schema): rename the API response content field in LLMResponse
- Rename the response_content field of LLMResponse to content
- Change its type from dict to str to represent the response content more accurately
- Update the LLMResponse call sites in reasoning.py accordingly
2025-04-20 18:40:51 +08:00
carry
c28e4819d9 refactor(frontend/tools): rename the example-JSON generator
- Rename generate_example_json to generate_json_example
- Update call sites and references in the affected files
- Makes the function name more descriptive and consistent
2025-04-20 16:11:36 +08:00
carry
e7cf51d662 refactor(frontend): rework the dataset generation page
- Adjust the page layout and streamline the interaction flow
- Add a dataset-name textbox
- Use the LLMRequest and LLMResponse models for requests and responses
- Add generate_example_json for formatting generated data
- Improve the generation logic to support multiple rounds
2025-04-20 16:10:08 +08:00
carry
4c9caff668 refactor(schema): rename the dataset and document classes
- Capitalize the dataset, dataset_item, and doc classes to follow Python class-naming conventions
- Update imports and references in the affected modules
- No functional change; improves consistency and readability
2025-04-20 01:46:15 +08:00
carry
9236f49b36 feat(frontend): add document slicing and a concurrency setting
- Add a concurrency number input
- Implement document slice processing
- Update the dataset generation logic to support concurrent processing
2025-04-20 01:40:48 +08:00
16 changed files with 284 additions and 96 deletions

View File

@@ -1,15 +1,91 @@
 # Document-Driven Adaptive Fine-Tuning Framework for Coding LLMs
 ## Introduction
-My graduation project
 ### Project Overview
-* Deeply parses a private library's documentation and other resources, generates instruction-style corpora, and fine-tunes a large language model on that private library.
+This project is a document-driven adaptive fine-tuning framework for coding LLMs: it deeply parses a private library's documentation and other resources, generates instruction-style corpora, and uses them to fine-tune a large language model for that private library.
+### Core Features
+- Document parsing and corpus generation
+- Efficient LLM fine-tuning
+- Interactive training and inference UI
+- Visual monitoring of the training process
-### Project Technology
-* Uses the unsloth framework for QLoRA fine-tuning of LLMs on GPU
-* Uses langchain workflows to batch-generate fine-tuning corpora
-* Uses tinydb and sqlite for data persistence
-* Uses the gradio framework for the frontend
-**Under construction......**
+## Technical Architecture
+### System Architecture
+```
+[Frontend UI] -> [Model Fine-Tuning] -> [Data Storage]
+      ↑                 ↑                    ↑
+      │                 │                    │
+  [Gradio]      [unsloth/QLoRA]      [SQLite/TinyDB]
+```
+### Tech Stack
+- **Frontend UI**: interactive web interface built with Gradio
+- **Model fine-tuning**: efficient QLoRA fine-tuning on the unsloth framework
+- **Data storage**: SQLite (structured data) + TinyDB (unstructured data)
+- **Workflow engine**: LangChain for document parsing and corpus generation
+- **Training monitoring**: TensorBoard integration
+## Feature Modules
+### 1. Model Management
+- Load large language models in multiple formats
+- View model information and manage model state
+- Unload models and free memory
+### 2. Model Inference
+- Conversational chat interface
+- Streaming responses
+- Chat history management
+### 3. Model Fine-Tuning
+- Training parameter configuration (learning rate, batch size, etc.)
+- LoRA parameter configuration (rank, alpha, etc.)
+- Real-time training monitoring
+- Training interruption and resumption
+### 4. Dataset Generation
+- Document parsing and cleaning
+- Instruction-response pair generation
+- Dataset quality assessment
+- Dataset version management
+## Fine-Tuning Technique
+### How QLoRA Works
+QLoRA (Quantized Low-Rank Adaptation) is an efficient fine-tuning technique for large models. Its key ideas:
+1. **4-bit quantization**: quantize the pretrained model to a 4-bit representation, drastically cutting VRAM usage
+2. **Low-rank adaptation**: update parameters efficiently via low-rank matrix decomposition (LoRA)
+3. **Memory optimization**: reduce VRAM needs further with techniques such as gradient checkpointing
+### Parameter Configuration
+- **Learning rate**: 2e-5 to 2e-4 recommended
+- **LoRA rank**: controls adapter capacity (16-64 recommended)
+- **LoRA alpha**: controls the magnitude of adapter updates (typically 1-2x the rank)
+### Training Monitoring
+- **TensorBoard integration**: view loss curves, learning rate, and other metrics in real time
+- **Logging**: detailed logs of the training run are saved
+- **Model checkpoints**: intermediate weights are saved periodically
+## Quick Start
+1. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+2. Launch the app:
+```bash
+python main.py
+```
+3. Open the web UI:
+```
+http://localhost:7860
+```
+## License
+MIT License
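The QLoRA section above names concrete knobs: 4-bit loading, a LoRA rank of 16-64, and alpha at 1-2x the rank. A minimal sketch of that setup using the unsloth API, assuming an illustrative base model and target-module list that are not taken from this repository:

```python
from unsloth import FastLanguageModel

# Load the base model already quantized to 4 bits (QLoRA's quantization step)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",  # assumed example model
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach low-rank adapters; rank and alpha follow the ranges recommended above
model = FastLanguageModel.get_peft_model(
    model,
    r=32,                                   # LoRA rank (16-64 recommended)
    lora_alpha=64,                          # typically 1-2x the rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed module list
    use_gradient_checkpointing="unsloth",   # the memory optimization mentioned above
)
```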

View File

@@ -1,11 +1,12 @@
 from .init_db import load_sqlite_engine, initialize_sqlite_db
 from .prompt_store import get_prompt_tinydb, initialize_prompt_store
-from .dataset_store import get_all_dataset
+from .dataset_store import get_all_dataset, save_dataset
 __all__ = [
     "load_sqlite_engine",
     "initialize_sqlite_db",
     "get_prompt_tinydb",
     "initialize_prompt_store",
-    "get_all_dataset"
+    "get_all_dataset",
+    "save_dataset"
 ]

View File

@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
 # Add the project root to sys.path so other project modules can be imported
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 def get_all_dataset(workdir: str) -> TinyDB:
     """
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
     return db
+def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
+    """
+    Save the datasets in TinyDB as individual JSON files.
+    Args:
+        db (TinyDB): TinyDB instance holding the dataset objects
+        workdir (str): working-directory path
+        name (str, optional): name of the dataset to save; None saves all
+    """
+    dataset_dir = os.path.join(workdir, "dataset")
+    os.makedirs(dataset_dir, exist_ok=True)
+    datasets = db.all() if name is None else db.search(Query().name == name)
+    for dataset in datasets:
+        try:
+            filename = f"{dataset.get('name', 'unknown')}.json"
+            filepath = os.path.join(dataset_dir, filename)
+            with open(filepath, "w", encoding="utf-8") as f:
+                json.dump(dataset, f, ensure_ascii=False, indent=2)
+        except Exception as e:
+            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
 if __name__ == "__main__":
     # Define the working-directory path
     workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -47,4 +71,11 @@ if __name__ == "__main__":
     # Print the results
     print(f"Found {len(datasets)} datasets:")
     for ds in datasets.all():
         print(f"- {ds['name']} (ID: {ds['id']})")
+    # Ask which dataset to save
+    name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
+    # Save the datasets to files
+    save_dataset(datasets, workdir, name)
+    print(f"Datasets {'all' if name is None else name} saved to json files")

View File

@@ -1,13 +1,18 @@
 import gradio as gr
 import sys
+import json
+from tinydb import Query
 from pathlib import Path
 from langchain.prompts import PromptTemplate
 from sqlmodel import Session, select
+from schema import Dataset, DatasetItem, Q_A
+from db.dataset_store import get_all_dataset
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import APIProvider
-from tools import call_openai_api
-from global_var import get_docs, get_prompt_store, get_sql_engine
+from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
+from db import save_dataset
+from tools import call_openai_api, process_markdown_file, generate_json_example
+from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
 def dataset_generate_page():
     with gr.Blocks() as demo:
@@ -16,13 +21,6 @@ def dataset_generate_page():
         with gr.Column(scale=1):
             docs_list = [str(doc.name) for doc in get_docs()]
             initial_doc = docs_list[0] if docs_list else None
-            doc_dropdown = gr.Dropdown(
-                choices=docs_list,
-                value=initial_doc,
-                label="选择文档",
-                interactive=True
-            )
-            doc_choice = gr.State(value=initial_doc)
             prompts = get_prompt_store().all()
             prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
             initial_prompt = prompt_list[0] if prompt_list else None
@@ -37,14 +35,6 @@ def dataset_generate_page():
             input_variables = prompt_template.input_variables
             input_variables.remove("document_slice")
             initial_dataframe_value = [[var, ""] for var in input_variables]
-            prompt_dropdown = gr.Dropdown(
-                choices=prompt_list,
-                value=initial_prompt,
-                label="选择模板",
-                interactive=True
-            )
-            prompt_choice = gr.State(value=initial_prompt)
             # Fetch the list of API providers from the database
             with Session(get_sql_engine()) as session:
                 providers = session.exec(select(APIProvider)).all()
@@ -57,8 +47,18 @@ def dataset_generate_page():
                 label="选择API",
                 interactive=True
             )
-            api_choice = gr.State(value=initial_api)
+            doc_dropdown = gr.Dropdown(
+                choices=docs_list,
+                value=initial_doc,
+                label="选择文档",
+                interactive=True
+            )
+            prompt_dropdown = gr.Dropdown(
+                choices=prompt_list,
+                value=initial_prompt,
+                label="选择模板",
+                interactive=True
+            )
             rounds_input = gr.Number(
                 value=1,
                 label="生成轮次",
@@ -67,10 +67,25 @@ def dataset_generate_page():
                 step=1,
                 interactive=True
             )
+            concurrency_input = gr.Number(
+                value=1,
+                label="并发数",
+                minimum=1,
+                maximum=10,
+                step=1,
+                interactive=True,
+                visible=False
+            )
+            dataset_name_input = gr.Textbox(
+                label="数据集名称",
+                placeholder="输入数据集保存名称",
+                interactive=True
+            )
+            prompt_choice = gr.State(value=initial_prompt)
             generate_button = gr.Button("生成数据集", variant="primary")
+            doc_choice = gr.State(value=initial_doc)
             output_text = gr.Textbox(label="生成结果", interactive=False)
+            api_choice = gr.State(value=initial_api)
         with gr.Column(scale=2):
             variables_dataframe = gr.Dataframe(
                 headers=["变量名", "变量值"],
@@ -79,8 +94,6 @@ def dataset_generate_page():
                 label="变量列表",
                 value=initial_dataframe_value  # set the initial data
             )
-
-
         def on_doc_change(selected_doc):
             return selected_doc
@@ -100,8 +113,14 @@ def dataset_generate_page():
             dataframe_value = [[var, ""] for var in input_variables]
             return selected_prompt, dataframe_value
-        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
-            doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
+        def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
+            dataset_db = get_datasets()
+            if dataset_db.search(Query().name == dataset_name):
+                raise gr.Error("数据集名称已存在")
+            doc = [i for i in get_docs() if i.name == doc_state][0]
+            doc_files = doc.markdown_files
+            document_slice_list = [process_markdown_file(doc) for doc in doc_files]
             prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
             prompt = PromptTemplate.from_template(prompt["content"])
             with Session(get_sql_engine()) as session:
@@ -114,23 +133,42 @@ def dataset_generate_page():
                 if var_name:
                     variables_dict[var_name] = var_value
+            # Inject every parameter except document_slice
             prompt = prompt.partial(**variables_dict)
-            print(doc)
-            print(prompt.format(document_slice="test"))
-            print(variables_dict)
-            import time
-            total_steps = rounds
-            for i in range(total_steps):
-                # Simulate the workload of each step
-                time.sleep(0.5)
-                current_progress = (i + 1) / total_steps
-                progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
-            return "all done"
+            dataset = Dataset(
+                name=dataset_name,
+                model_id=[api_provider.model_id],
+                source_doc=doc,
+                dataset_items=[]
+            )
+            total_slices = len(document_slice_list)
+            for i, document_slice in enumerate(document_slice_list):
+                progress((i + 1) / total_slices, desc=f"处理文档片段 {i + 1}/{total_slices}")
+                request = LLMRequest(api_provider=api_provider,
+                                     prompt=prompt.format(document_slice=document_slice),
+                                     format=generate_json_example(DatasetItem))
+                call_openai_api(request, rounds)
+                for resp in request.response:
+                    try:
+                        content = json.loads(resp.content)
+                        dataset_item = DatasetItem(
+                            message=[Q_A(
+                                question=content.get("question", ""),
+                                answer=content.get("answer", "")
+                            )]
+                        )
+                        dataset.dataset_items.append(dataset_item)
+                    except json.JSONDecodeError as e:
+                        print(f"Failed to parse response: {e}")
+            # Save the dataset to TinyDB
+            dataset_db.insert(dataset.model_dump())
+            save_dataset(dataset_db, get_workdir(), dataset_name)
+            return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
         doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
         prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
@@ -138,7 +176,7 @@ def dataset_generate_page():
         generate_button.click(
             on_generate_click,
-            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
+            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
             outputs=output_text
         )

View File

@@ -30,6 +30,11 @@ def model_manage_page():
         with gr.Row():
             with gr.Column(scale=3):
                 save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
+                save_method_dropdown = gr.Dropdown(
+                    choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
+                    label="保存模式",
+                    value="default"
+                )
             with gr.Column(scale=1):
                 save_button = gr.Button("保存模型", variant="secondary")
@@ -73,21 +78,12 @@ def model_manage_page():
         unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
-        def save_model(save_model_name):
-            try:
-                global model, tokenizer
-                if model is None:
-                    return "没有加载的模型可保存"
-                save_path = os.path.join(models_dir, save_model_name)
-                os.makedirs(save_path, exist_ok=True)
-                model.save_pretrained(save_path)
-                tokenizer.save_pretrained(save_path)
-                return f"模型已保存到 {save_path}"
-            except Exception as e:
-                return f"保存模型时出错: {str(e)}"
+        from train.save_model import save_model_to_dir
+        def save_model(save_model_name, save_method):
+            return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
-        save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
+        save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
         def refresh_model_list():
             try:

View File

@@ -74,7 +74,7 @@ def train_page():
         # Dynamically generate the TensorBoard iframe
         tensorboard_url = f"http://localhost:{tensorboard_port}"
-        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
+        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
         yield "训练开始...", tensorboard_iframe_value  # yield two values, for the textbox and the html
         try:
@@ -82,6 +82,8 @@ def train_page():
                 dataset, new_training_dir,
                 learning_rate, per_device_train_batch_size, epoch,
                 save_steps, lora_rank)
+        except Exception as e:
+            raise gr.Error(str(e))
+        finally:
             # Make sure the TensorBoard subprocess is terminated when training ends
             tensorboard_process.terminate()
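The pattern this fix completes: the TensorBoard child process must be killed whether training succeeds or raises, and failures should surface in the UI as gr.Error rather than disappear into the console. A minimal self-contained sketch of that lifecycle, assuming an injected training callable rather than this repository's exact code:

```python
from typing import Callable
import subprocess
import gradio as gr

def run_with_tensorboard(train_fn: Callable[[], None], logdir: str, port: int) -> None:
    # Launch TensorBoard as a child process alongside training
    tensorboard_process = subprocess.Popen(
        ["tensorboard", "--logdir", logdir, "--port", str(port)]
    )
    try:
        train_fn()
    except Exception as e:
        # Surface the failure in the Gradio UI instead of swallowing it
        raise gr.Error(str(e))
    finally:
        # Runs on success and on failure alike, so TensorBoard never leaks
        tensorboard_process.terminate()
```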

View File

@@ -37,10 +37,6 @@ def get_docs():
 def get_datasets():
     return _datasets
-def set_datasets(new_datasets):
-    global _datasets
-    _datasets = new_datasets
 def get_model():
     return _model

View File

@@ -26,4 +26,4 @@ if __name__ == "__main__":
         with gr.TabItem("设置"):
             setting_page()
-    app.launch()
+    app.launch(server_name="0.0.0.0")

View File

@@ -2,7 +2,7 @@ from typing import Optional
 from pydantic import BaseModel, Field
 from datetime import datetime, timezone
-class doc(BaseModel):
+class Doc(BaseModel):
     id: Optional[int] = Field(default=None, description="文档ID")
     name: str = Field(default="", description="文档名称")
     path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
     question: str = Field(default="", min_length=1, description="问题")
     answer: str = Field(default="", min_length=1, description="答案")
-class dataset_item(BaseModel):
+class DatasetItem(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集项ID")
     message: list[Q_A] = Field(description="数据集项内容")
-class dataset(BaseModel):
+class Dataset(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集ID")
     name: str = Field(default="", description="数据集名称")
     model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
-    source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
+    source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
     description: Optional[str] = Field(default="", description="数据集描述")
     created_at: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="记录创建时间"
     )
-    dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")
+    dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
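A small usage sketch of the renamed models (field values are illustrative). model_dump is the Pydantic V2 serializer the dataset generation page calls before inserting a record into TinyDB:

```python
from schema.dataset import Dataset, DatasetItem, Q_A

ds = Dataset(
    name="demo",
    dataset_items=[
        DatasetItem(message=[Q_A(question="What does X do?", answer="X does ...")])
    ],
)
record = ds.model_dump()  # plain dict, nested models included
print(record["name"], len(record["dataset_items"]))
```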

View File

@@ -33,7 +33,7 @@ class LLMResponse(SQLModel):
     )
     response_id: str = Field(..., description="响应的唯一ID")
     tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
-    response_content: dict = Field(default_factory=dict, description="API响应的内容")
+    content: str = Field(default="", description="API响应的内容")
     total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
     llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")

View File

@@ -1,5 +1,5 @@
 from .parse_markdown import *
 from .document import *
-from .json_example import generate_example_json
+from .json_example import generate_json_example
 from .port import *
 from .reasoning import call_openai_api

View File

@@ -1,19 +1,19 @@
 from typing import List
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 import json
-def convert_json_to_dataset(json_data: List[dict]) -> dataset:
+def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
     # Convert the JSON data into the Dataset format
     dataset_items = []
     item_id = 1  # auto-incrementing ID counter
     for item in json_data:
         qa = Q_A(question=item["question"], answer=item["answer"])
-        dataset_item_obj = dataset_item(id=item_id, message=[qa])
+        dataset_item_obj = DatasetItem(id=item_id, message=[qa])
         dataset_items.append(dataset_item_obj)
         item_id += 1  # increment the ID
     # Create the Dataset object
-    result_dataset = dataset(
+    result_dataset = Dataset(
         name="Converted Dataset",
         model_id=None,
         description="Dataset converted from JSON",

View File

@@ -4,7 +4,7 @@ from pathlib import Path
 # Add the project root to sys.path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import doc
+from schema import Doc
 def scan_docs_directory(workdir: str):
     docs_dir = os.path.join(workdir, "docs")
@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
             for file in files:
                 if file.endswith(".md"):
                     markdown_files.append(os.path.join(root, file))
-            to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
+            to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
     return to_return

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 import json
 from datetime import datetime, date
-def generate_example_json(model: type[BaseModel], include_optional: bool = False, list_length=2) -> str:
+def generate_json_example(model: type[BaseModel], include_optional: bool = False, list_length=2) -> str:
     """
     Generate an example JSON data structure from a Pydantic V2 model.
     """
@@ -37,14 +37,14 @@ def generate_example_json(model: type[BaseModel], include_optional: bool = False
     elif field_type is date:
         return date.today().isoformat()
     elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
-        return json.loads(generate_example_json(field_type, include_optional))
+        return json.loads(generate_json_example(field_type, include_optional))
     else:
         # Handle direct (non-generic) type annotations
         if field_type is type(None):
             return None
         try:
             if issubclass(field_type, BaseModel):
-                return json.loads(generate_example_json(field_type, include_optional))
+                return json.loads(generate_json_example(field_type, include_optional))
         except TypeError:
             pass
         return "unknown"
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     from pathlib import Path
     # Add the project root to sys.path
     sys.path.append(str(Path(__file__).resolve().parent.parent))
-    from schema import dataset
+    from schema import Dataset
     print("示例 JSON:")
-    print(generate_example_json(dataset))
+    print(generate_json_example(Dataset))

View File

@@ -68,7 +68,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
             llm_request.response.append(LLMResponse(
                 response_id=response.id,
                 tokens_usage=tokens_usage,
-                response_content={"content": response.choices[0].message.content},
+                content=response.choices[0].message.content,
                 total_duration=duration,
                 llm_parameters=llm_parameters
             ))
@@ -79,7 +79,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
             llm_request.response.append(LLMResponse(
                 response_id=f"error-round-{i+1}",
-                response_content={"error": str(e)},
+                content=str(e),
                 total_duration=duration
             ))
         if llm_request.error is None:
@@ -93,10 +93,10 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
     return llm_request
 if __name__ == "__main__":
-    from json_example import generate_example_json
+    from json_example import generate_json_example
     from sqlmodel import Session, select
     from global_var import get_sql_engine, init_global_var
-    from schema import dataset_item
+    from schema import DatasetItem
     init_global_var("workdir")
     api_state = "1 deepseek-chat"
@@ -105,7 +105,7 @@ if __name__ == "__main__":
     llm_request = LLMRequest(
         prompt="测试,随便说点什么",
         api_provider=api_provider,
-        format=generate_example_json(dataset_item)
+        format=generate_json_example(DatasetItem)
     )
     # # Single-call example
@@ -120,4 +120,4 @@ if __name__ == "__main__":
     print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
     print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
     for i, resp in enumerate(result.response, 1):
-        print(f"响应{i}: {resp.response_content}")
+        print(f"响应{i}: {resp.content}")

train/save_model.py Normal file
View File

@@ -0,0 +1,48 @@
+import os
+from global_var import get_model, get_tokenizer
+def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
+    """
+    Save a model to the given directory.
+    :param save_model_name: name to save the model under
+    :param models_dir: base directory for saved models
+    :param model: the model to save
+    :param tokenizer: the tokenizer to save
+    :param save_method: save-method option
+        - "default": default save
+        - "merged_16bit": merge to 16-bit
+        - "merged_4bit": merge to 4-bit
+        - "lora": LoRA adapters only
+        - "gguf": save in GGUF format
+        - "gguf_q4_k_m": save as q4_k_m GGUF
+        - "gguf_f16": save as 16-bit GGUF
+    :return: result message, or an error message
+    """
+    try:
+        if model is None:
+            return "没有加载的模型可保存"
+        save_path = os.path.join(models_dir, save_model_name)
+        os.makedirs(save_path, exist_ok=True)
+        if save_method == "default":
+            model.save_pretrained(save_path)
+            tokenizer.save_pretrained(save_path)
+        elif save_method == "merged_16bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
+        elif save_method == "merged_4bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
+        elif save_method == "lora":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
+        elif save_method == "gguf":
+            model.save_pretrained_gguf(save_path, tokenizer)
+        elif save_method == "gguf_q4_k_m":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
+        elif save_method == "gguf_f16":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
+        else:
+            return f"不支持的保存模式: {save_method}"
+        return f"模型已保存到 {save_path} (模式: {save_method})"
+    except Exception as e:
+        return f"保存模型时出错: {str(e)}"