gzhu-biyesheji/db/init_db.py
carry feaea1fb64 refactor(db): 重命名数据库引擎加载函数
- 将 get_sqlite_engine 函数重命名为 load_sqlite_engine
- 更新了相关模块中的导入和调用
- 这个改动是为了更好地反映函数的实际功能,提高代码可读性
2025-04-18 15:16:29 +08:00

79 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def load_sqlite_engine(workdir: str) -> Engine:
"""
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
Args:
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
Returns:
Engine: SQLAlchemy 数据库引擎实例。
"""
global _engine
if not _engine:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "db.sqlite")
# 创建数据库URL
db_url = f"sqlite:///{db_path}"
# 创建数据库引擎
_engine = create_engine(db_url)
return _engine
def initialize_sqlite_db(engine: Engine) -> None:
"""
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
Args:
engine (Engine): SQLAlchemy 数据库引擎实例。
"""
# 创建所有表结构
SQLModel.metadata.create_all(engine)
# 加载环境变量
load_dotenv()
# 从环境变量中获取API相关配置
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
# 如果所有必要的环境变量都存在,则插入初始数据
if api_key and base_url and model_id:
with Session(engine) as session:
# 查询是否已存在APIProvider记录
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
# 如果不存在则插入新的APIProvider记录
if not existing_provider:
provider = APIProvider(
base_url=base_url,
model_id=model_id,
api_key=api_key
)
session.add(provider)
session.commit()
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎
engine = load_sqlite_engine(workdir)
# 初始化数据库
initialize_sqlite_db(engine)