优化了users的鉴权逻辑,使用了依赖注入的方式判断管理员
This commit is contained in:
parent
f1cdbab0f4
commit
a90838b79f
@ -2,6 +2,7 @@ from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing import Optional
|
||||
from schemas.auth import TokenData
|
||||
from schemas.user import UserRole
|
||||
from services.auth_service import verify_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
@ -16,3 +17,22 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData:
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return token_data
|
||||
|
||||
|
||||
|
||||
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData:
|
||||
"""获取当前用户"""
|
||||
token_data = verify_token(token)
|
||||
if token_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if token_data.role not in [UserRole.SYSTEM_ADMIN.value, UserRole.ADMIN.value]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You are not admin",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return token_data
|
@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from typing import List, Optional
|
||||
from schemas.auth import TokenData
|
||||
from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
from routes.depends import get_current_user
|
||||
from services.user_services import get_user_by_id
|
||||
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
||||
from routes.depends import get_current_user,get_current_admin
|
||||
from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
|
||||
from services.db import get_db_session
|
||||
|
||||
router = APIRouter(tags=["users"])
|
||||
|
||||
@ -20,68 +21,39 @@ async def get_users(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
if current_user.role not in ["system_admin", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admin can access user list"
|
||||
)
|
||||
# 实现获取用户列表逻辑
|
||||
pass
|
||||
async with get_db_session() as session:
|
||||
skip = (page - 1) * limit
|
||||
users = await get_users(session, skip=skip, limit=limit)
|
||||
if role:
|
||||
users = [user for user in users if user.role == role]
|
||||
return users
|
||||
|
||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
user_data: UserCreate,
|
||||
current_user_token: TokenData = Depends(get_current_user)
|
||||
current_user_token: TokenData = Depends(get_current_admin)
|
||||
):
|
||||
current_user = await get_user_by_id(current_user_token.id)
|
||||
if current_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
if current_user.role not in ["system_admin", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admin can create users"
|
||||
)
|
||||
# 实现创建用户逻辑
|
||||
pass
|
||||
async with get_db_session() as session:
|
||||
return await create_user(session, user_data)
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_data: UserUpdate,
|
||||
current_user_token: TokenData = Depends(get_current_user)
|
||||
current_user_token: TokenData = Depends(get_current_admin)
|
||||
):
|
||||
current_user = await get_user_by_id(current_user_token.id)
|
||||
if current_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
if current_user.role not in ["system_admin", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admin can update users"
|
||||
)
|
||||
# 实现更新用户逻辑
|
||||
pass
|
||||
async with get_db_session() as session:
|
||||
return await update_user(session, user_id, user_data)
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user_token: TokenData = Depends(get_current_user)
|
||||
current_user_token: TokenData = Depends(get_current_admin)
|
||||
):
|
||||
current_user = await get_user_by_id(current_user_token.id)
|
||||
if current_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
if current_user.role not in ["system_admin", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admin can delete users"
|
||||
)
|
||||
# 实现删除用户逻辑
|
||||
pass
|
||||
async with get_db_session() as session:
|
||||
success = await delete_user(session, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
@ -8,66 +8,65 @@ from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
class UserServices:
|
||||
@staticmethod
|
||||
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
|
||||
"""创建用户"""
|
||||
hashed_password = pwd_context.hash(user_data.password)
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
password=hashed_password,
|
||||
role=user_data.role,
|
||||
description=user_data.description
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
|
||||
"""创建用户"""
|
||||
hashed_password = pwd_context.hash(user_data.password)
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
password=hashed_password,
|
||||
role=user_data.role,
|
||||
description=user_data.description
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
@staticmethod
|
||||
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
|
||||
"""根据ID获取用户"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
return UserResponse.from_orm(user) if user else None
|
||||
|
||||
@staticmethod
|
||||
async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||
"""获取用户列表"""
|
||||
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||
users = result.scalars().all()
|
||||
return [UserResponse.from_orm(user) for user in users]
|
||||
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
|
||||
"""根据ID获取用户"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
return UserResponse.from_orm(user) if user else None
|
||||
|
||||
@staticmethod
|
||||
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
|
||||
"""更新用户信息"""
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(**user_data.dict(exclude_unset=True))
|
||||
)
|
||||
await session.commit()
|
||||
return await UserServices.get_user(session, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def delete_user(session: AsyncSession, user_id: int) -> bool:
|
||||
"""删除用户"""
|
||||
result = await session.execute(delete(User).where(User.id == user_id))
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||
"""获取用户列表"""
|
||||
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||
users = result.scalars().all()
|
||||
return [UserResponse.from_orm(user) for user in users]
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证输入的明文密码是否与存储的哈希密码匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
|
||||
"""验证用户登录"""
|
||||
result = await session.execute(select(User).where(User.username == username))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
if not UserServices.verify_password(password, user.password):
|
||||
return None
|
||||
return UserResponse.from_orm(user)
|
||||
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
|
||||
"""更新用户信息"""
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(**user_data.dict(exclude_unset=True))
|
||||
)
|
||||
await session.commit()
|
||||
return await get_user(session, user_id)
|
||||
|
||||
|
||||
async def delete_user(session: AsyncSession, user_id: int) -> bool:
|
||||
"""删除用户"""
|
||||
result = await session.execute(delete(User).where(User.id == user_id))
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证输入的明文密码是否与存储的哈希密码匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
|
||||
"""验证用户登录"""
|
||||
result = await session.execute(select(User).where(User.username == username))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.password):
|
||||
return None
|
||||
return UserResponse.from_orm(user)
|
||||
|
Loading…
x
Reference in New Issue
Block a user