From 67281fe06a07e7803d5d229d5ada7ac373d55174 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 9 Apr 2025 09:58:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(db):=20=E6=B7=BB=E5=8A=A0=20prompt=20?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 prompt_store 模块,使用 TinyDB 存储 prompt 模板 - 在全局变量中添加 prompt_store 实例 - 更新 main.py,初始化 prompt 存储 - 新增 prompt 模板的 Pydantic 模型 - 更新 requirements.txt,添加 tinydb 依赖 --- db/__init__.py | 8 ++++++- db/prompt_store.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++ global_var.py | 2 ++ main.py | 6 ++--- requirements.txt | 3 ++- schema/prompt.py | 13 ++++++++++ 6 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 db/prompt_store.py create mode 100644 schema/prompt.py diff --git a/db/__init__.py b/db/__init__.py index 8fe1308..06161f0 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -1,3 +1,9 @@ from .init_db import get_sqlite_engine, initialize_sqlite_db +from .prompt_store import get_prompt_tinydb, initialize_prompt_store -__all__ = ['get_sqlite_engine', 'initialize_sqlite_db'] \ No newline at end of file +__all__ = [ + "get_sqlite_engine", + "initialize_sqlite_db", + "get_prompt_tinydb", + "initialize_prompt_store" +] \ No newline at end of file diff --git a/db/prompt_store.py b/db/prompt_store.py new file mode 100644 index 0000000..85a0481 --- /dev/null +++ b/db/prompt_store.py @@ -0,0 +1,60 @@ +import os +import json +import sys +from typing import Optional +from pathlib import Path +from datetime import datetime, timezone +from tinydb import TinyDB, Query +from tinydb.storages import JSONStorage + +# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 +sys.path.append(str(Path(__file__).resolve().parent.parent)) +from schema.prompt import promptTempleta + +# 全局变量,用于存储TinyDB实例 +_db_instance: Optional[TinyDB] = None +# 自定义存储类,用于格式化JSON数据 +def get_prompt_tinydb(workdir: str) -> TinyDB: + """ + 获取TinyDB实例。如果实例尚未创建,则创建一个新的并返回。 + + Args: + workdir (str): 工作目录路径,用于确定数据库文件的存储位置。 + + Returns: + TinyDB: TinyDB数据库实例 + """ + global _db_instance + if not _db_instance: + # 创建数据库目录(如果不存在) + db_dir = os.path.join(workdir, "db") + os.makedirs(db_dir, exist_ok=True) + # 定义数据库文件路径 + db_path = os.path.join(db_dir, "prompts.json") + # 创建TinyDB实例 + _db_instance = TinyDB(db_path) + return _db_instance + +def initialize_prompt_store(db: TinyDB) -> None: + """ + 初始化prompt模板存储 + + Args: + db (TinyDB): TinyDB数据库实例 + """ + db.insert(promptTempleta(name="default", + description="默认提示词模板", + content="""项目名为:{ project_name } +请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文 +文档节选:{ content } + 按照如下json格式返回:{ json }""").model_dump()) + # TinyDB不需要显式创建表结构,首次使用时自动创建 + + +if __name__ == "__main__": + # 定义工作目录路径 + workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") + # 获取数据库实例 + db = get_prompt_tinydb(workdir) + # 初始化prompt存储 + initialize_prompt_store(db) \ No newline at end of file diff --git a/global_var.py b/global_var.py index ee0ef40..67c140f 100644 --- a/global_var.py +++ b/global_var.py @@ -1,3 +1,5 @@ from db import get_sqlite_engine +from db import get_prompt_tinydb +prompt_store = get_prompt_tinydb("workdir") sql_engine = get_sqlite_engine("workdir") \ No newline at end of file diff --git a/main.py b/main.py index 9a7b9fc..c3437e5 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,12 @@ import gradio as gr from frontend.setting_page import setting_page from frontend import * -from db import initialize_sqlite_db -from global_var import sql_engine +from db import initialize_sqlite_db,initialize_prompt_store +from global_var import sql_engine,prompt_store if __name__ == "__main__": initialize_sqlite_db(sql_engine) - + initialize_prompt_store(prompt_store) with gr.Blocks() as app: gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") with gr.Tabs(): diff --git a/requirements.txt b/requirements.txt index fe9aed6..2bd21f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ openai>=1.0.0 python-dotenv>=1.0.0 pydantic>=2.0.0 gradio>=5.0.0 -langchain>=0.3 \ No newline at end of file +langchain>=0.3 +tinydb>=4.0.0 \ No newline at end of file diff --git a/schema/prompt.py b/schema/prompt.py new file mode 100644 index 0000000..a6143e3 --- /dev/null +++ b/schema/prompt.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime, timezone + +class promptTempleta(BaseModel): + id: Optional[int] = Field(default=None, description="模板ID") + name: Optional[str] = Field(default="", description="模板名称") + description: Optional[str] = Field(default="", description="模板描述") + content: str = Field(default="", min_length=1, description="模板内容") + created_at: str = Field( + default_factory=lambda: datetime.now(timezone.utc).isoformat(), + description="记录创建时间" + ) \ No newline at end of file