docs(db): 修改了代码注释
This commit is contained in:
parent
b1e98ca913
commit
10b4c29bda
@ -7,36 +7,60 @@ import sys
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# 项目路径配置
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset_generation import APIProvider
|
||||
|
||||
# 全局引擎实例(可选)
|
||||
# 全局变量,用于存储数据库引擎实例
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
def get_engine(workdir: str) -> Engine:
|
||||
"""
|
||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
Engine: SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
global _engine
|
||||
if not _engine:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "db.sqlite")
|
||||
# 创建数据库URL
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
# 创建数据库引擎
|
||||
_engine = create_engine(db_url)
|
||||
return _engine
|
||||
|
||||
def initialize_db(engine: Engine) -> None:
|
||||
"""
|
||||
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||||
|
||||
Args:
|
||||
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
# 创建所有表结构
|
||||
SQLModel.metadata.create_all(engine)
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# 从环境变量中获取API相关配置
|
||||
api_key = os.getenv("API_KEY")
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model_id = os.getenv("MODEL_ID")
|
||||
|
||||
# 如果所有必要的环境变量都存在,则插入初始数据
|
||||
if api_key and base_url and model_id:
|
||||
with Session(engine) as session:
|
||||
# 使用新的 select() 语法查询
|
||||
# 查询是否已存在APIProvider记录
|
||||
statement = select(APIProvider).limit(1)
|
||||
existing_provider = session.exec(statement).first()
|
||||
|
||||
# 如果不存在,则插入新的APIProvider记录
|
||||
if not existing_provider:
|
||||
provider = APIProvider(
|
||||
base_url=base_url,
|
||||
@ -47,6 +71,9 @@ def initialize_db(engine: Engine) -> None:
|
||||
session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库引擎
|
||||
engine = get_engine(workdir)
|
||||
# 初始化数据库
|
||||
initialize_db(engine)
|
Loading…
x
Reference in New Issue
Block a user