diff --git a/main.py b/main.py index 6704193..0346ec8 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,8 @@ import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from config import DATABASE_CONFIG, JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG -from sqlalchemy.ext.asyncio import create_async_engine -from services.init_db import init_db +from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG +from services.db import create_db_engine, init_db, close_db_connection from routes.auth import auth_router from routes.users import users_router @@ -22,20 +21,20 @@ app = FastAPI( version="1.0.0" ) -# 创建数据库引擎 -engine = create_async_engine( - f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@" - f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}", - echo=True -) - -# 初始化数据库 +# 数据库初始化 @app.on_event("startup") async def startup_event(): logger.info("Initializing database...") + engine = create_db_engine() await init_db(engine) logger.info("Database initialized successfully") +@app.on_event("shutdown") +async def shutdown_event(): + logger.info("Closing database connections...") + await close_db_connection() + logger.info("Database connections closed") + # 注册路由 app.include_router(auth_router, prefix="/api/auth", tags=["auth"]) app.include_router(users_router, prefix="/api/users", tags=["users"]) diff --git a/services/db.py b/services/db.py new file mode 100644 index 0000000..3f7c6cb --- /dev/null +++ b/services/db.py @@ -0,0 +1,67 @@ +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from ..models.user import Base, User, UserRole +from sqlalchemy import select +from contextlib import asynccontextmanager +from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG +from services.auth import get_password_hash +from typing import AsyncGenerator + +# 全局数据库引擎实例 +_engine: AsyncEngine | None = None + +def create_db_engine() -> AsyncEngine: + """创建数据库引擎""" + return create_async_engine( + f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@" + f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}", + echo=True + ) + +def get_db_engine() -> AsyncEngine: + """获取全局数据库引擎实例""" + if _engine is None: + raise RuntimeError("Database engine not initialized") + return _engine + +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话""" + async with AsyncSession(get_db_engine()) as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + +async def init_db(engine: AsyncEngine): + """初始化数据库""" + global _engine + _engine = engine + + async with engine.begin() as conn: + # 创建所有表 + await conn.run_sync(Base.metadata.create_all) + + # 检查系统管理员是否存在 + result = await conn.execute( + select(User).where(User.role == UserRole.SYSTEM_ADMIN) + ) + if not result.scalars().first(): + # 创建默认系统管理员 + admin = User( + username=SYSTEM_ADMIN_CONFIG['username'], + password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), + role=UserRole.SYSTEM_ADMIN, + description=SYSTEM_ADMIN_CONFIG['description'] + ) + conn.add(admin) + await conn.commit() + +async def close_db_connection(): + """关闭数据库连接""" + global _engine + if _engine is not None: + await _engine.dispose() + _engine = None diff --git a/services/init_db.py b/services/init_db.py deleted file mode 100644 index 23c8125..0000000 --- a/services/init_db.py +++ /dev/null @@ -1,26 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncEngine -from ..models.user import Base, User, UserRole -from sqlalchemy import select -from config import SYSTEM_ADMIN_CONFIG -from services.auth import get_password_hash - -async def init_db(engine: AsyncEngine): - """初始化数据库""" - async with engine.begin() as conn: - # 创建所有表 - await conn.run_sync(Base.metadata.create_all) - - # 检查系统管理员是否存在 - result = await conn.execute( - select(User).where(User.role == UserRole.SYSTEM_ADMIN) - ) - if not result.scalars().first(): - # 创建默认系统管理员 - admin = User( - username=SYSTEM_ADMIN_CONFIG['username'], - password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), - role=UserRole.SYSTEM_ADMIN, - description=SYSTEM_ADMIN_CONFIG['description'] - ) - conn.add(admin) - await conn.commit()