docs(db): 修改了代码注释
This commit is contained in:
parent
b1e98ca913
commit
10b4c29bda
@ -7,36 +7,60 @@ import sys
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
# 项目路径配置
|
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from schema.dataset_generation import APIProvider
|
from schema.dataset_generation import APIProvider
|
||||||
|
|
||||||
# 全局引擎实例(可选)
|
# 全局变量,用于存储数据库引擎实例
|
||||||
_engine: Optional[Engine] = None
|
_engine: Optional[Engine] = None
|
||||||
|
|
||||||
def get_engine(workdir: str) -> Engine:
|
def get_engine(workdir: str) -> Engine:
|
||||||
|
"""
|
||||||
|
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Engine: SQLAlchemy 数据库引擎实例。
|
||||||
|
"""
|
||||||
global _engine
|
global _engine
|
||||||
if not _engine:
|
if not _engine:
|
||||||
|
# 创建数据库目录(如果不存在)
|
||||||
db_dir = os.path.join(workdir, "db")
|
db_dir = os.path.join(workdir, "db")
|
||||||
os.makedirs(db_dir, exist_ok=True)
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
# 定义数据库文件路径
|
||||||
db_path = os.path.join(db_dir, "db.sqlite")
|
db_path = os.path.join(db_dir, "db.sqlite")
|
||||||
|
# 创建数据库URL
|
||||||
db_url = f"sqlite:///{db_path}"
|
db_url = f"sqlite:///{db_path}"
|
||||||
|
# 创建数据库引擎
|
||||||
_engine = create_engine(db_url)
|
_engine = create_engine(db_url)
|
||||||
return _engine
|
return _engine
|
||||||
|
|
||||||
def initialize_db(engine: Engine) -> None:
|
def initialize_db(engine: Engine) -> None:
|
||||||
|
"""
|
||||||
|
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||||||
|
"""
|
||||||
|
# 创建所有表结构
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
# 从环境变量中获取API相关配置
|
||||||
api_key = os.getenv("API_KEY")
|
api_key = os.getenv("API_KEY")
|
||||||
base_url = os.getenv("BASE_URL")
|
base_url = os.getenv("BASE_URL")
|
||||||
model_id = os.getenv("MODEL_ID")
|
model_id = os.getenv("MODEL_ID")
|
||||||
|
|
||||||
|
# 如果所有必要的环境变量都存在,则插入初始数据
|
||||||
if api_key and base_url and model_id:
|
if api_key and base_url and model_id:
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
# 使用新的 select() 语法查询
|
# 查询是否已存在APIProvider记录
|
||||||
statement = select(APIProvider).limit(1)
|
statement = select(APIProvider).limit(1)
|
||||||
existing_provider = session.exec(statement).first()
|
existing_provider = session.exec(statement).first()
|
||||||
|
|
||||||
|
# 如果不存在,则插入新的APIProvider记录
|
||||||
if not existing_provider:
|
if not existing_provider:
|
||||||
provider = APIProvider(
|
provider = APIProvider(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
@ -47,6 +71,9 @@ def initialize_db(engine: Engine) -> None:
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# 定义工作目录路径
|
||||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||||
|
# 获取数据库引擎
|
||||||
engine = get_engine(workdir)
|
engine = get_engine(workdir)
|
||||||
|
# 初始化数据库
|
||||||
initialize_db(engine)
|
initialize_db(engine)
|
Loading…
x
Reference in New Issue
Block a user