完了鉴权相关代码和鉴权依赖注入
This commit is contained in:
parent
3b7ac1f682
commit
f1cdbab0f4
18
routes/depends.py
Normal file
18
routes/depends.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from typing import Optional
|
||||||
|
from schemas.auth import TokenData
|
||||||
|
from services.auth_service import verify_token
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||||
|
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData:
|
||||||
|
"""获取当前用户"""
|
||||||
|
token_data = verify_token(token)
|
||||||
|
if token_data is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
return token_data
|
@ -1,7 +1,9 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from schemas.auth import TokenData
|
||||||
from schemas.user import UserCreate, UserUpdate, UserResponse
|
from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||||
from services.auth import get_current_user
|
from routes.depends import get_current_user
|
||||||
|
from services.user_services import get_user_by_id
|
||||||
|
|
||||||
router = APIRouter(tags=["users"])
|
router = APIRouter(tags=["users"])
|
||||||
|
|
||||||
@ -10,8 +12,14 @@ async def get_users(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
role: Optional[str] = None,
|
role: Optional[str] = None,
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
current_user_token: TokenData = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
|
current_user = await get_user_by_id(current_user_token.id)
|
||||||
|
if current_user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
if current_user.role not in ["system_admin", "admin"]:
|
if current_user.role not in ["system_admin", "admin"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
@ -23,8 +31,14 @@ async def get_users(
|
|||||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_user(
|
async def create_user(
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
current_user_token: TokenData = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
|
current_user = await get_user_by_id(current_user_token.id)
|
||||||
|
if current_user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
if current_user.role not in ["system_admin", "admin"]:
|
if current_user.role not in ["system_admin", "admin"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
@ -37,8 +51,14 @@ async def create_user(
|
|||||||
async def update_user(
|
async def update_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
user_data: UserUpdate,
|
user_data: UserUpdate,
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
current_user_token: TokenData = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
|
current_user = await get_user_by_id(current_user_token.id)
|
||||||
|
if current_user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
if current_user.role not in ["system_admin", "admin"]:
|
if current_user.role not in ["system_admin", "admin"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
@ -50,8 +70,14 @@ async def update_user(
|
|||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_user(
|
async def delete_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
current_user_token: TokenData = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
|
current_user = await get_user_by_id(current_user_token.id)
|
||||||
|
if current_user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
if current_user.role not in ["system_admin", "admin"]:
|
if current_user.role not in ["system_admin", "admin"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
67
services/auth_service.py
Normal file
67
services/auth_service.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
import jwt
|
||||||
|
from config import JWT_CONFIG
|
||||||
|
from schemas.auth import Token, TokenData
|
||||||
|
|
||||||
|
SECRET_KEY = JWT_CONFIG['secret_key']
|
||||||
|
ALGORITHM = JWT_CONFIG['algorithm']
|
||||||
|
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
||||||
|
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
||||||
|
|
||||||
|
def create_access_token(user_id: int, username: str, role: str) -> str:
|
||||||
|
"""创建access token"""
|
||||||
|
expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE
|
||||||
|
to_encode = {
|
||||||
|
"id": user_id,
|
||||||
|
"username": username,
|
||||||
|
"role": role,
|
||||||
|
"exp": expire
|
||||||
|
}
|
||||||
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
def create_refresh_token(user_id: int, username: str, role: str) -> str:
|
||||||
|
"""创建refresh token"""
|
||||||
|
expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE
|
||||||
|
to_encode = {
|
||||||
|
"id": user_id,
|
||||||
|
"username": username,
|
||||||
|
"role": role,
|
||||||
|
"exp": expire
|
||||||
|
}
|
||||||
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
def create_tokens(user_id: int, username: str, role: str) -> Token:
|
||||||
|
"""创建access token和refresh token"""
|
||||||
|
access_token = create_access_token(user_id, username, role)
|
||||||
|
refresh_token = create_refresh_token(user_id, username, role)
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
||||||
|
)
|
||||||
|
|
||||||
|
def verify_token(token: str) -> Optional[TokenData]:
|
||||||
|
"""验证token有效性并返回payload,如果token无效则返回None"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
return TokenData(
|
||||||
|
id=payload.get("id"),
|
||||||
|
username=payload.get("username"),
|
||||||
|
role=payload.get("role"),
|
||||||
|
exp=payload.get("exp")
|
||||||
|
)
|
||||||
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def refresh_tokens(refresh_token: str) -> Optional[Token]:
|
||||||
|
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
||||||
|
token_data = verify_token(refresh_token)
|
||||||
|
if token_data is None:
|
||||||
|
return None
|
||||||
|
return create_tokens(
|
||||||
|
user_id=token_data.id,
|
||||||
|
username=token_data.username,
|
||||||
|
role=token_data.role
|
||||||
|
)
|
@ -4,7 +4,7 @@ from ..models.user import Base, User, UserRole
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
|
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
|
||||||
from services.user_services import get_password_hash
|
from user_services import get_password_hash
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
# 全局数据库引擎实例
|
# 全局数据库引擎实例
|
||||||
|
Loading…
x
Reference in New Issue
Block a user