优化了users的鉴权逻辑,使用了依赖注入的方式判断管理员

This commit is contained in:
carry 2025-01-21 22:14:03 +08:00
parent f1cdbab0f4
commit a90838b79f
3 changed files with 100 additions and 109 deletions

View File

@ -2,6 +2,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from typing import Optional from typing import Optional
from schemas.auth import TokenData from schemas.auth import TokenData
from schemas.user import UserRole
from services.auth_service import verify_token from services.auth_service import verify_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
@ -16,3 +17,22 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData:
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
return token_data return token_data
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData:
"""获取当前用户"""
token_data = verify_token(token)
if token_data is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
if token_data.role not in [UserRole.SYSTEM_ADMIN.value, UserRole.ADMIN.value]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You are not admin",
headers={"WWW-Authenticate": "Bearer"},
)
return token_data

View File

@ -1,9 +1,10 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from typing import List, Optional from typing import List, Optional
from schemas.auth import TokenData from schemas.auth import TokenData
from schemas.user import UserCreate, UserUpdate, UserResponse from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
from routes.depends import get_current_user from routes.depends import get_current_user,get_current_admin
from services.user_services import get_user_by_id from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
from services.db import get_db_session
router = APIRouter(tags=["users"]) router = APIRouter(tags=["users"])
@ -20,68 +21,39 @@ async def get_users(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="User not found" detail="User not found"
) )
if current_user.role not in ["system_admin", "admin"]: async with get_db_session() as session:
raise HTTPException( skip = (page - 1) * limit
status_code=status.HTTP_403_FORBIDDEN, users = await get_users(session, skip=skip, limit=limit)
detail="Only admin can access user list" if role:
) users = [user for user in users if user.role == role]
# 实现获取用户列表逻辑 return users
pass
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user( async def create_user(
user_data: UserCreate, user_data: UserCreate,
current_user_token: TokenData = Depends(get_current_user) current_user_token: TokenData = Depends(get_current_admin)
): ):
current_user = await get_user_by_id(current_user_token.id) async with get_db_session() as session:
if current_user is None: return await create_user(session, user_data)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if current_user.role not in ["system_admin", "admin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin can create users"
)
# 实现创建用户逻辑
pass
@router.put("/{user_id}", response_model=UserResponse) @router.put("/{user_id}", response_model=UserResponse)
async def update_user( async def update_user(
user_id: int, user_id: int,
user_data: UserUpdate, user_data: UserUpdate,
current_user_token: TokenData = Depends(get_current_user) current_user_token: TokenData = Depends(get_current_admin)
): ):
current_user = await get_user_by_id(current_user_token.id) async with get_db_session() as session:
if current_user is None: return await update_user(session, user_id, user_data)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if current_user.role not in ["system_admin", "admin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin can update users"
)
# 实现更新用户逻辑
pass
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user( async def delete_user(
user_id: int, user_id: int,
current_user_token: TokenData = Depends(get_current_user) current_user_token: TokenData = Depends(get_current_admin)
): ):
current_user = await get_user_by_id(current_user_token.id) async with get_db_session() as session:
if current_user is None: success = await delete_user(session, user_id)
raise HTTPException( if not success:
status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(
detail="User not found" status_code=status.HTTP_404_NOT_FOUND,
) detail="User not found"
if current_user.role not in ["system_admin", "admin"]: )
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin can delete users"
)
# 实现删除用户逻辑
pass

View File

@ -8,66 +8,65 @@ from schemas.user import UserCreate, UserUpdate, UserResponse
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法 # 创建一个密码上下文对象,指定使用 bcrypt 加密算法
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class UserServices:
@staticmethod async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: """创建用户"""
"""创建用户""" hashed_password = pwd_context.hash(user_data.password)
hashed_password = pwd_context.hash(user_data.password) user = User(
user = User( username=user_data.username,
username=user_data.username, password=hashed_password,
password=hashed_password, role=user_data.role,
role=user_data.role, description=user_data.description
description=user_data.description )
) session.add(user)
session.add(user) await session.commit()
await session.commit() await session.refresh(user)
await session.refresh(user) return UserResponse.from_orm(user)
return UserResponse.from_orm(user)
@staticmethod
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
"""根据ID获取用户"""
result = await session.execute(select(User).where(User.id == user_id))
user = result.scalars().first()
return UserResponse.from_orm(user) if user else None
@staticmethod async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: """根据ID获取用户"""
"""获取用户列表""" result = await session.execute(select(User).where(User.id == user_id))
result = await session.execute(select(User).offset(skip).limit(limit)) user = result.scalars().first()
users = result.scalars().all() return UserResponse.from_orm(user) if user else None
return [UserResponse.from_orm(user) for user in users]
@staticmethod
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
"""更新用户信息"""
await session.execute(
update(User)
.where(User.id == user_id)
.values(**user_data.dict(exclude_unset=True))
)
await session.commit()
return await UserServices.get_user(session, user_id)
@staticmethod async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
async def delete_user(session: AsyncSession, user_id: int) -> bool: """获取用户列表"""
"""删除用户""" result = await session.execute(select(User).offset(skip).limit(limit))
result = await session.execute(delete(User).where(User.id == user_id)) users = result.scalars().all()
await session.commit() return [UserResponse.from_orm(user) for user in users]
return result.rowcount > 0
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证输入的明文密码是否与存储的哈希密码匹配"""
return pwd_context.verify(plain_password, hashed_password)
@staticmethod async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: """更新用户信息"""
"""验证用户登录""" await session.execute(
result = await session.execute(select(User).where(User.username == username)) update(User)
user = result.scalars().first() .where(User.id == user_id)
if not user: .values(**user_data.dict(exclude_unset=True))
return None )
if not UserServices.verify_password(password, user.password): await session.commit()
return None return await get_user(session, user_id)
return UserResponse.from_orm(user)
async def delete_user(session: AsyncSession, user_id: int) -> bool:
"""删除用户"""
result = await session.execute(delete(User).where(User.id == user_id))
await session.commit()
return result.rowcount > 0
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证输入的明文密码是否与存储的哈希密码匹配"""
return pwd_context.verify(plain_password, hashed_password)
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
"""验证用户登录"""
result = await session.execute(select(User).where(User.username == username))
user = result.scalars().first()
if not user:
return None
if not verify_password(password, user.password):
return None
return UserResponse.from_orm(user)