Compare commits

..

No commits in common. "67281fe06a07e7803d5d229d5ada7ac373d55174" and "0a6ae7a4ee7e818a7268400dd9b82044db2237f3" have entirely different histories.

8 changed files with 15 additions and 97 deletions

View File

@ -1,9 +1,3 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db from .init_db import get_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
__all__ = [ __all__ = ['get_engine', 'initialize_sqlite_db']
"get_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store"
]

View File

@ -1,9 +1,9 @@
import os
import sys
from sqlmodel import SQLModel, create_engine, Session from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select from sqlmodel import select
from typing import Optional from typing import Optional
import os
from pathlib import Path from pathlib import Path
import sys
from dotenv import load_dotenv from dotenv import load_dotenv
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例 # 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None _engine: Optional[Engine] = None
def get_sqlite_engine(workdir: str) -> Engine: def get_engine(workdir: str) -> Engine:
""" """
获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回 获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回
@ -74,6 +74,6 @@ if __name__ == "__main__":
# 定义工作目录路径 # 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎 # 获取数据库引擎
engine = get_sqlite_engine(workdir) engine = get_engine(workdir)
# 初始化数据库 # 初始化数据库
initialize_sqlite_db(engine) initialize_sqlite_db(engine)

View File

@ -1,60 +0,0 @@
import os
import json
import sys
from typing import Optional
from pathlib import Path
from datetime import datetime, timezone
from tinydb import TinyDB, Query
from tinydb.storages import JSONStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.prompt import promptTempleta
# 全局变量用于存储TinyDB实例
_db_instance: Optional[TinyDB] = None
# 自定义存储类用于格式化JSON数据
def get_prompt_tinydb(workdir: str) -> TinyDB:
"""
获取TinyDB实例如果实例尚未创建则创建一个新的并返回
Args:
workdir (str): 工作目录路径用于确定数据库文件的存储位置
Returns:
TinyDB: TinyDB数据库实例
"""
global _db_instance
if not _db_instance:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "prompts.json")
# 创建TinyDB实例
_db_instance = TinyDB(db_path)
return _db_instance
def initialize_prompt_store(db: TinyDB) -> None:
"""
初始化prompt模板存储
Args:
db (TinyDB): TinyDB数据库实例
"""
db.insert(promptTempleta(name="default",
description="默认提示词模板",
content="""项目名为:{ project_name }
请依据以下该项目官方文档的部分内容创造合适的对话数据集用于微调一个了解该项目的小模型的语料要求兼顾文档中间尽可能多的信息点使用中文
文档节选{ content }
按照如下json格式返回{ json }""").model_dump())
# TinyDB不需要显式创建表结构首次使用时自动创建
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库实例
db = get_prompt_tinydb(workdir)
# 初始化prompt存储
initialize_prompt_store(db)

View File

@ -1,7 +1,7 @@
import gradio as gr import gradio as gr
from typing import List from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from db import get_sqlite_engine from db import get_engine
from schema import APIProvider from schema import APIProvider
import os import os
from global_var import sql_engine from global_var import sql_engine
@ -30,7 +30,7 @@ def setting_page():
session.add(new_provider) session.add(new_provider)
session.commit() session.commit()
session.refresh(new_provider) session.refresh(new_provider)
return get_providers(), "", "", "" # 返回清空后的输入框值 return get_providers()
except Exception as e: except Exception as e:
raise gr.Error(f"添加失败: {str(e)}") raise gr.Error(f"添加失败: {str(e)}")
@ -108,7 +108,7 @@ def setting_page():
add_button.click( add_button.click(
fn=add_provider, fn=add_provider,
inputs=[model_id_input, base_url_input, api_key_input], inputs=[model_id_input, base_url_input, api_key_input],
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出 outputs=[provider_table]
) )
provider_table.select(select_record, [], [], show_progress="hidden") provider_table.select(select_record, [], [], show_progress="hidden")

View File

@ -1,5 +1,3 @@
from db import get_sqlite_engine from db import get_engine
from db import get_prompt_tinydb
prompt_store = get_prompt_tinydb("workdir") sql_engine = get_engine("workdir")
sql_engine = get_sqlite_engine("workdir")

View File

@ -1,12 +1,12 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page from frontend.setting_page import setting_page
from frontend import * from frontend import *
from db import initialize_sqlite_db,initialize_prompt_store from db import initialize_sqlite_db
from global_var import sql_engine,prompt_store from global_var import sql_engine
if __name__ == "__main__": if __name__ == "__main__":
initialize_sqlite_db(sql_engine) initialize_sqlite_db(sql_engine)
initialize_prompt_store(prompt_store)
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs(): with gr.Tabs():

View File

@ -2,5 +2,4 @@ openai>=1.0.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic>=2.0.0 pydantic>=2.0.0
gradio>=5.0.0 gradio>=5.0.0
langchain>=0.3 langchain>=0.3
tinydb>=4.0.0

View File

@ -1,13 +0,0 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime, timezone
class promptTempleta(BaseModel):
id: Optional[int] = Field(default=None, description="模板ID")
name: Optional[str] = Field(default="", description="模板名称")
description: Optional[str] = Field(default="", description="模板描述")
content: str = Field(default="", min_length=1, description="模板内容")
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="记录创建时间"
)