From a90838b79fb6735066d51b5bc5b96aa97cf9fb5f Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 22:14:03 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86users=E7=9A=84?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E4=BA=86=E4=BE=9D=E8=B5=96=E6=B3=A8=E5=85=A5=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=E5=88=A4=E6=96=AD=E7=AE=A1=E7=90=86=E5=91=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/depends.py | 20 +++++++ routes/users.py | 76 ++++++++----------------- services/user_services.py | 113 +++++++++++++++++++------------------- 3 files changed, 100 insertions(+), 109 deletions(-) diff --git a/routes/depends.py b/routes/depends.py index 43aebad..1f2d261 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -2,6 +2,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from typing import Optional from schemas.auth import TokenData +from schemas.user import UserRole from services.auth_service import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") @@ -16,3 +17,22 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: headers={"WWW-Authenticate": "Bearer"}, ) return token_data + + + +async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: + """获取当前用户""" + token_data = verify_token(token) + if token_data is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + if token_data.role not in [UserRole.SYSTEM_ADMIN.value, UserRole.ADMIN.value]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You are not admin", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token_data \ No newline at end of file diff --git a/routes/users.py b/routes/users.py index 92a648d..7e598b6 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,9 +1,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional from schemas.auth import TokenData -from schemas.user import UserCreate, UserUpdate, UserResponse -from routes.depends import get_current_user -from services.user_services import get_user_by_id +from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole +from routes.depends import get_current_user,get_current_admin +from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user +from services.db import get_db_session router = APIRouter(tags=["users"]) @@ -20,68 +21,39 @@ async def get_users( status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can access user list" - ) - # 实现获取用户列表逻辑 - pass + async with get_db_session() as session: + skip = (page - 1) * limit + users = await get_users(session, skip=skip, limit=limit) + if role: + users = [user for user in users if user.role == role] + return users @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can create users" - ) - # 实现创建用户逻辑 - pass + async with get_db_session() as session: + return await create_user(session, user_data) @router.put("/{user_id}", response_model=UserResponse) async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can update users" - ) - # 实现更新用户逻辑 - pass + async with get_db_session() as session: + return await update_user(session, user_id, user_data) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can delete users" - ) - # 实现删除用户逻辑 - pass + async with get_db_session() as session: + success = await delete_user(session, user_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) diff --git a/services/user_services.py b/services/user_services.py index a9a77db..79e1ce5 100644 --- a/services/user_services.py +++ b/services/user_services.py @@ -8,66 +8,65 @@ from schemas.user import UserCreate, UserUpdate, UserResponse # 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -class UserServices: - @staticmethod - async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: - """创建用户""" - hashed_password = pwd_context.hash(user_data.password) - user = User( - username=user_data.username, - password=hashed_password, - role=user_data.role, - description=user_data.description - ) - session.add(user) - await session.commit() - await session.refresh(user) - return UserResponse.from_orm(user) + +async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: + """创建用户""" + hashed_password = pwd_context.hash(user_data.password) + user = User( + username=user_data.username, + password=hashed_password, + role=user_data.role, + description=user_data.description + ) + session.add(user) + await session.commit() + await session.refresh(user) + return UserResponse.from_orm(user) - @staticmethod - async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]: - """根据ID获取用户""" - result = await session.execute(select(User).where(User.id == user_id)) - user = result.scalars().first() - return UserResponse.from_orm(user) if user else None - @staticmethod - async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: - """获取用户列表""" - result = await session.execute(select(User).offset(skip).limit(limit)) - users = result.scalars().all() - return [UserResponse.from_orm(user) for user in users] +async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]: + """根据ID获取用户""" + result = await session.execute(select(User).where(User.id == user_id)) + user = result.scalars().first() + return UserResponse.from_orm(user) if user else None - @staticmethod - async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]: - """更新用户信息""" - await session.execute( - update(User) - .where(User.id == user_id) - .values(**user_data.dict(exclude_unset=True)) - ) - await session.commit() - return await UserServices.get_user(session, user_id) - @staticmethod - async def delete_user(session: AsyncSession, user_id: int) -> bool: - """删除用户""" - result = await session.execute(delete(User).where(User.id == user_id)) - await session.commit() - return result.rowcount > 0 +async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: + """获取用户列表""" + result = await session.execute(select(User).offset(skip).limit(limit)) + users = result.scalars().all() + return [UserResponse.from_orm(user) for user in users] - @staticmethod - def verify_password(plain_password: str, hashed_password: str) -> bool: - """验证输入的明文密码是否与存储的哈希密码匹配""" - return pwd_context.verify(plain_password, hashed_password) - @staticmethod - async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: - """验证用户登录""" - result = await session.execute(select(User).where(User.username == username)) - user = result.scalars().first() - if not user: - return None - if not UserServices.verify_password(password, user.password): - return None - return UserResponse.from_orm(user) +async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]: + """更新用户信息""" + await session.execute( + update(User) + .where(User.id == user_id) + .values(**user_data.dict(exclude_unset=True)) + ) + await session.commit() + return await get_user(session, user_id) + + +async def delete_user(session: AsyncSession, user_id: int) -> bool: + """删除用户""" + result = await session.execute(delete(User).where(User.id == user_id)) + await session.commit() + return result.rowcount > 0 + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证输入的明文密码是否与存储的哈希密码匹配""" + return pwd_context.verify(plain_password, hashed_password) + + +async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: + """验证用户登录""" + result = await session.execute(select(User).where(User.username == username)) + user = result.scalars().first() + if not user: + return None + if not verify_password(password, user.password): + return None + return UserResponse.from_orm(user)