From 10b4c29bda4cbd7f7496883a51001b8514ce820f Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 6 Apr 2025 21:26:53 +0800 Subject: [PATCH] =?UTF-8?q?docs(db):=20=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- db/init_db.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/db/init_db.py b/db/init_db.py index 02dc0a0..caf9274 100644 --- a/db/init_db.py +++ b/db/init_db.py @@ -7,36 +7,60 @@ import sys from dotenv import load_dotenv from sqlalchemy.engine import Engine -# 项目路径配置 +# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 sys.path.append(str(Path(__file__).resolve().parent.parent)) from schema.dataset_generation import APIProvider -# 全局引擎实例(可选) +# 全局变量,用于存储数据库引擎实例 _engine: Optional[Engine] = None def get_engine(workdir: str) -> Engine: + """ + 获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。 + + Args: + workdir (str): 工作目录路径,用于确定数据库文件的存储位置。 + + Returns: + Engine: SQLAlchemy 数据库引擎实例。 + """ global _engine if not _engine: + # 创建数据库目录(如果不存在) db_dir = os.path.join(workdir, "db") os.makedirs(db_dir, exist_ok=True) + # 定义数据库文件路径 db_path = os.path.join(db_dir, "db.sqlite") + # 创建数据库URL db_url = f"sqlite:///{db_path}" + # 创建数据库引擎 _engine = create_engine(db_url) return _engine def initialize_db(engine: Engine) -> None: + """ + 初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。 + + Args: + engine (Engine): SQLAlchemy 数据库引擎实例。 + """ + # 创建所有表结构 SQLModel.metadata.create_all(engine) + # 加载环境变量 load_dotenv() + # 从环境变量中获取API相关配置 api_key = os.getenv("API_KEY") base_url = os.getenv("BASE_URL") model_id = os.getenv("MODEL_ID") + # 如果所有必要的环境变量都存在,则插入初始数据 if api_key and base_url and model_id: with Session(engine) as session: - # 使用新的 select() 语法查询 + # 查询是否已存在APIProvider记录 statement = select(APIProvider).limit(1) existing_provider = session.exec(statement).first() + # 如果不存在,则插入新的APIProvider记录 if not existing_provider: provider = APIProvider( base_url=base_url, @@ -47,6 +71,9 @@ def initialize_db(engine: Engine) -> None: session.commit() if __name__ == "__main__": + # 定义工作目录路径 workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") + # 获取数据库引擎 engine = get_engine(workdir) + # 初始化数据库 initialize_db(engine) \ No newline at end of file