docs(db): 修改了代码注释

This commit is contained in:
carry 2025-04-06 21:26:53 +08:00
parent b1e98ca913
commit 10b4c29bda

View File

@ -7,36 +7,60 @@ import sys
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 项目路径配置
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局引擎实例(可选)
# 全局变量,用于存储数据库引擎实例
_engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine:
"""
获取数据库引擎实例如果引擎尚未创建则创建一个新的引擎并返回
Args:
workdir (str): 工作目录路径用于确定数据库文件的存储位置
Returns:
Engine: SQLAlchemy 数据库引擎实例
"""
global _engine
if not _engine:
# 创建数据库目录(如果不存在)
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
# 定义数据库文件路径
db_path = os.path.join(db_dir, "db.sqlite")
# 创建数据库URL
db_url = f"sqlite:///{db_path}"
# 创建数据库引擎
_engine = create_engine(db_url)
return _engine
def initialize_db(engine: Engine) -> None:
"""
初始化数据库创建所有表结构并插入初始数据如果不存在
Args:
engine (Engine): SQLAlchemy 数据库引擎实例
"""
# 创建所有表结构
SQLModel.metadata.create_all(engine)
# 加载环境变量
load_dotenv()
# 从环境变量中获取API相关配置
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
# 如果所有必要的环境变量都存在,则插入初始数据
if api_key and base_url and model_id:
with Session(engine) as session:
# 使用新的 select() 语法查询
# 查询是否已存在APIProvider记录
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
# 如果不存在则插入新的APIProvider记录
if not existing_provider:
provider = APIProvider(
base_url=base_url,
@ -47,6 +71,9 @@ def initialize_db(engine: Engine) -> None:
session.commit()
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取数据库引擎
engine = get_engine(workdir)
# 初始化数据库
initialize_db(engine)