
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine - 更新了相关模块中的导入和调用 - 这个改动是为了更好地反映函数的实际功能,提高代码可读性
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
import os
|
||
import sys
|
||
from sqlmodel import SQLModel, create_engine, Session
|
||
from sqlmodel import select
|
||
from typing import Optional
|
||
from pathlib import Path
|
||
from dotenv import load_dotenv
|
||
from sqlalchemy.engine import Engine
|
||
|
||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||
from schema.dataset_generation import APIProvider
|
||
|
||
# 全局变量,用于存储数据库引擎实例
|
||
_engine: Optional[Engine] = None
|
||
|
||
def load_sqlite_engine(workdir: str) -> Engine:
|
||
"""
|
||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||
|
||
Args:
|
||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||
|
||
Returns:
|
||
Engine: SQLAlchemy 数据库引擎实例。
|
||
"""
|
||
global _engine
|
||
if not _engine:
|
||
# 创建数据库目录(如果不存在)
|
||
db_dir = os.path.join(workdir, "db")
|
||
os.makedirs(db_dir, exist_ok=True)
|
||
# 定义数据库文件路径
|
||
db_path = os.path.join(db_dir, "db.sqlite")
|
||
# 创建数据库URL
|
||
db_url = f"sqlite:///{db_path}"
|
||
# 创建数据库引擎
|
||
_engine = create_engine(db_url)
|
||
return _engine
|
||
|
||
def initialize_sqlite_db(engine: Engine) -> None:
|
||
"""
|
||
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||
|
||
Args:
|
||
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||
"""
|
||
# 创建所有表结构
|
||
SQLModel.metadata.create_all(engine)
|
||
# 加载环境变量
|
||
load_dotenv()
|
||
# 从环境变量中获取API相关配置
|
||
api_key = os.getenv("API_KEY")
|
||
base_url = os.getenv("BASE_URL")
|
||
model_id = os.getenv("MODEL_ID")
|
||
|
||
# 如果所有必要的环境变量都存在,则插入初始数据
|
||
if api_key and base_url and model_id:
|
||
with Session(engine) as session:
|
||
# 查询是否已存在APIProvider记录
|
||
statement = select(APIProvider).limit(1)
|
||
existing_provider = session.exec(statement).first()
|
||
|
||
# 如果不存在,则插入新的APIProvider记录
|
||
if not existing_provider:
|
||
provider = APIProvider(
|
||
base_url=base_url,
|
||
model_id=model_id,
|
||
api_key=api_key
|
||
)
|
||
session.add(provider)
|
||
session.commit()
|
||
|
||
if __name__ == "__main__":
|
||
# 定义工作目录路径
|
||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||
# 获取数据库引擎
|
||
engine = load_sqlite_engine(workdir)
|
||
# 初始化数据库
|
||
initialize_sqlite_db(engine) |