diff --git a/routes/auth.py b/routes/auth.py index c4bb51a..93016ec 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest from services.auth import create_tokens_response, verify_token, refresh_tokens from services.user import authenticate_user -from services.db import get_db_session +from services.db import get_db_session_dep router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, - session: AsyncSession = Depends(get_db_session) + session: AsyncSession = Depends(get_db_session_dep) ): user = await authenticate_user(session, login_data.username, login_data.password) if not user: diff --git a/routes/users.py b/routes/users.py index cbed7c9..bb67c61 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,59 +1,55 @@ from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin -from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id -from services.db import get_db_session +import services.user as user_service + +from services.db import get_db_session_dep router = APIRouter(tags=["users"]) @router.get("/", response_model=List[UserResponse]) -async def get_users( +async def get_users_list( page: int = 1, - limit: int = 10, + limit: int = 100, role: Optional[str] = None, - current_user_token: TokenPayload = Depends(get_current_user) + current_user_token: TokenPayload = Depends(get_current_user), + session: AsyncSession = Depends(get_db_session_dep) ): - # current_user = await get_user_by_id(current_user_token.id) - # if current_user is None: - # raise HTTPException( - # status_code=status.HTTP_404_NOT_FOUND, - # detail="User not found" - # ) - async with get_db_session() as session: - skip = (page - 1) * limit - users = await get_users(session, skip=skip, limit=limit) - if role: - users = [user for user in users if user.role == role] - return users + skip = (page - 1) * limit + users = await user_service.get_users_list(session, skip=skip, limit=limit) + if role: + users = [user for user in users if user.role == role] + return users @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - return await create_user(session, user_data) + return await user_service.create_user(session, user_data) @router.put("/{user_id}", response_model=UserResponse) async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - return await update_user(session, user_id, user_data) + return await user_service.update_user(session, user_id, user_data) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - success = await delete_user(session, user_id) - if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) + success = await user_service.delete_user(session, user_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) diff --git a/services/auth.py b/services/auth.py index 30a9b44..03352cc 100644 --- a/services/auth.py +++ b/services/auth.py @@ -14,9 +14,9 @@ def get_current_time() -> int: """获取当前UTC时间戳""" return int(time.time()) -def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str: +def create_token(user_id: int, username: str, role: str, expire_delta) -> str: """创建JWT token""" - expire = get_current_time() + expire_delta + expire = get_current_time() + int(expire_delta.total_seconds()) to_encode = { "id": user_id, "username": username, diff --git a/services/db.py b/services/db.py index 6e3f9f6..eda231f 100644 --- a/services/db.py +++ b/services/db.py @@ -24,8 +24,7 @@ def get_db_engine() -> AsyncEngine: raise RuntimeError("Database engine not initialized") return _engine -@asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession, None]: +async def get_db_session() -> AsyncSession: """获取数据库会话""" async with AsyncSession(get_db_engine()) as session: try: @@ -35,6 +34,16 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: await session.rollback() raise +async def get_db_session_dep() -> AsyncSession: + """FastAPI依赖注入使用的数据库会话获取函数""" + async with AsyncSession(get_db_engine()) as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + async def init_db(engine: AsyncEngine): """初始化数据库""" global _engine diff --git a/services/user.py b/services/user.py index 90c6828..c432427 100644 --- a/services/user.py +++ b/services/user.py @@ -30,7 +30,7 @@ async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse return UserResponse.from_orm(user) if user else None -async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: +async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: """获取用户列表""" result = await session.execute(select(User).offset(skip).limit(limit)) users = result.scalars().all()