Compare commits

..

4 Commits

Author SHA1 Message Date
carry
67281fe06a feat(db): 添加 prompt 存储功能
- 新增 prompt_store 模块,使用 TinyDB 存储 prompt 模板
- 在全局变量中添加 prompt_store 实例
- 更新 main.py,初始化 prompt 存储
- 新增 prompt 模板的 Pydantic 模型
- 更新 requirements.txt,添加 tinydb 依赖
2025-04-09 09:58:42 +08:00
carry
2d905a0270 refactor(db): 调整导入模块顺序
- 将 os 和 sys 模块导入提前到文件顶部
- 优化代码结构,遵循常见的 Python 导入模块顺序
2025-04-09 09:57:20 +08:00
carry
374b124cf8 feat(setting_page): 添加供应商后清空输入框
- 修改 add_provider 函数,返回清空后的输入框值
- 更新 add_button.click 事件处理,添加清空输入框的输出
2025-04-09 08:17:43 +08:00
carry
74ae5e1426 refactor(db): 重命名数据库引擎获取函数
将 get_engine 函数重命名为 get_sqlite_engine,以更清晰地表示其功能和用途。
- 更新了 db/__init__.py 中的导入和 __all__ 列表
- 修改了 db/init_db.py 中的函数定义
- 更新了前端设置页面和全局变量中的导入和函数调用

此更改提高了代码的可读性和维护性,特别是在将来可能添加其他类型数据库引擎的情况下。
2025-04-09 08:12:59 +08:00
8 changed files with 97 additions and 15 deletions

View File

@ -1,3 +1,9 @@
from .init_db import get_engine, initialize_sqlite_db from .init_db import get_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
__all__ = ['get_engine', 'initialize_sqlite_db'] __all__ = [
"get_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store"
]

View File

@ -1,9 +1,9 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select from sqlmodel import select
from typing import Optional from typing import Optional
import os
from pathlib import Path from pathlib import Path
import sys
from dotenv import load_dotenv from dotenv import load_dotenv
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例 # 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None _engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine: def get_sqlite_engine(workdir: str) -> Engine:
""" """
获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回 获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回
@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎 # 获取数据库引擎
engine = get_engine(workdir) engine = get_sqlite_engine(workdir)
# 初始化数据库 # 初始化数据库
initialize_sqlite_db(engine) initialize_sqlite_db(engine)

60
db/prompt_store.py Normal file
View File

@ -0,0 +1,60 @@
import os
import json
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# 全局变量用于存储TinyDB实例
_db_instance: Optional[TinyDB] = None
# 自定义存储类用于格式化JSON数据
def get_prompt_tinydb(workdir: str) -> TinyDB:
"""
获取TinyDB实例如果实例尚未创建则创建一个新的并返回
Args:
workdir (str): 工作目录路径用于确定数据库文件的存储位置
Returns:
TinyDB: TinyDB数据库实例
"""
global _db_instance
if not _db_instance:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "prompts.json")
# 创建TinyDB实例
_db_instance = TinyDB(db_path)
return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
"""
初始化prompt模板存储
Args:
db (TinyDB): TinyDB数据库实例
"""
db.insert(promptTempleta(name="default",
description="默认提示词模板",
content="""项目名为:{ project_name }
请依据以下该项目官方文档的部分内容创造合适的对话数据集用于微调一个了解该项目的小模型的语料要求兼顾文档中间尽可能多的信息点使用中文
文档节选{ content }
按照如下json格式返回{ json }""").model_dump())
# TinyDB不需要显式创建表结构首次使用时自动创建
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库实例
db = get_prompt_tinydb(workdir)
# 初始化prompt存储
initialize_prompt_store(db)

View File

@ -1,7 +1,7 @@
import gradio as gr import gradio as gr
from typing import List from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from db import get_engine from db import get_sqlite_engine
from schema import APIProvider from schema import APIProvider
import os import os
from global_var import sql_engine from global_var import sql_engine
@ -30,7 +30,7 @@ def setting_page():
session.add(new_provider) session.add(new_provider)
session.commit() session.commit()
session.refresh(new_provider) session.refresh(new_provider)
return get_providers() return get_providers(), "", "", "" # 返回清空后的输入框值
except Exception as e: except Exception as e:
raise gr.Error(f"添加失败: {str(e)}") raise gr.Error(f"添加失败: {str(e)}")
@ -108,7 +108,7 @@ def setting_page():
add_button.click( add_button.click(
fn=add_provider, fn=add_provider,
inputs=[model_id_input, base_url_input, api_key_input], inputs=[model_id_input, base_url_input, api_key_input],
outputs=[provider_table] outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
) )
provider_table.select(select_record, [], [], show_progress="hidden") provider_table.select(select_record, [], [], show_progress="hidden")

View File

@ -1,3 +1,5 @@
from db import get_engine from db import get_sqlite_engine
from db import get_prompt_tinydb
sql_engine = get_engine("workdir") prompt_store = get_prompt_tinydb("workdir")
sql_engine = get_sqlite_engine("workdir")

View File

@ -1,12 +1,12 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page from frontend.setting_page import setting_page
from frontend import * from frontend import *
from db import initialize_sqlite_db from db import initialize_sqlite_db,initialize_prompt_store
from global_var import sql_engine from global_var import sql_engine,prompt_store
if __name__ == "__main__": if __name__ == "__main__":
initialize_sqlite_db(sql_engine) initialize_sqlite_db(sql_engine)
initialize_prompt_store(prompt_store)
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs(): with gr.Tabs():

View File

@ -3,3 +3,4 @@ python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=5.0.0 gradio>=5.0.0
langchain>=0.3 langchain>=0.3
tinydb>=4.0.0

13
schema/prompt.py Normal file
View File

@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime, timezone
class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID")
name: Optional[str] = Field(default="", description="模板名称")
description: Optional[str] = Field(default="", description="模板描述")
content: str = Field(default="", min_length=1, description="模板内容")
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间"
)