重构了鉴权服务,重构payload

This commit is contained in:
carry 2025-01-21 23:38:52 +08:00
parent fc547eebe5
commit 2f1cf11d91
5 changed files with 50 additions and 49 deletions

View File

@ -1,15 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from schemas.auth import Token, LoginRequest, RefreshTokenRequest from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
from services.auth_service import create_tokens, verify_token, refresh_tokens from services.auth_service import create_tokens_response, verify_token, refresh_tokens
from services.user_services import authenticate_user from services.user_services import authenticate_user
from services.db import get_db from services.db import get_db
router = APIRouter(tags=["auth"]) router = APIRouter(tags=["auth"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@router.post("/login", response_model=Token) @router.post("/login", response_model=TokenResponse)
async def login( async def login(
login_data: LoginRequest, login_data: LoginRequest,
session: AsyncSession = Depends(get_db) session: AsyncSession = Depends(get_db)
@ -21,17 +21,10 @@ async def login(
detail="Invalid username or password", detail="Invalid username or password",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
return create_tokens(user.id, user.username, user.role) return create_tokens_response(user.id, user.username, user.role)
@router.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
token_data = verify_token(token)
if not token_data:
raise HTTPException(status_code=401, detail="Invalid token")
# TODO: 实现token黑名单
return {"message": "Successfully logged out"}
@router.post("/refresh", response_model=Token) @router.post("/refresh", response_model=TokenResponse)
async def refresh_token(refresh_data: RefreshTokenRequest): async def refresh_token(refresh_data: RefreshTokenRequest):
tokens = refresh_tokens(refresh_data.refresh_token) tokens = refresh_tokens(refresh_data.refresh_token)
if not tokens: if not tokens:

View File

@ -1,13 +1,13 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from typing import Optional from typing import Optional
from schemas.auth import TokenData from schemas.auth import TokenPayload
from schemas.user import UserRole from schemas.user import UserRole
from services.auth_service import verify_token from services.auth_service import verify_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def _get_token_data(token: str) -> TokenData: async def _get_token_data(token: str) -> TokenPayload:
"""验证并返回TokenData""" """验证并返回TokenData"""
token_data = verify_token(token) token_data = verify_token(token)
if token_data is None: if token_data is None:
@ -18,11 +18,11 @@ async def _get_token_data(token: str) -> TokenData:
) )
return token_data return token_data
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload:
"""获取当前用户""" """获取当前用户"""
return await _get_token_data(token) return await _get_token_data(token)
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload:
"""获取当前管理员用户""" """获取当前管理员用户"""
token_data = await _get_token_data(token) token_data = await _get_token_data(token)
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]: if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from typing import List, Optional from typing import List, Optional
from schemas.auth import TokenData from schemas.auth import TokenPayload
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
from routes.depends import get_current_user,get_current_admin from routes.depends import get_current_user,get_current_admin
from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
@ -13,7 +13,7 @@ async def get_users(
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
role: Optional[str] = None, role: Optional[str] = None,
current_user_token: TokenData = Depends(get_current_user) current_user_token: TokenPayload = Depends(get_current_user)
): ):
current_user = await get_user_by_id(current_user_token.id) current_user = await get_user_by_id(current_user_token.id)
if current_user is None: if current_user is None:
@ -31,7 +31,7 @@ async def get_users(
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user( async def create_user(
user_data: UserCreate, user_data: UserCreate,
current_user_token: TokenData = Depends(get_current_admin) current_user_token: TokenPayload = Depends(get_current_admin)
): ):
async with get_db_session() as session: async with get_db_session() as session:
return await create_user(session, user_data) return await create_user(session, user_data)
@ -40,7 +40,7 @@ async def create_user(
async def update_user( async def update_user(
user_id: int, user_id: int,
user_data: UserUpdate, user_data: UserUpdate,
current_user_token: TokenData = Depends(get_current_admin) current_user_token: TokenPayload = Depends(get_current_admin)
): ):
async with get_db_session() as session: async with get_db_session() as session:
return await update_user(session, user_id, user_data) return await update_user(session, user_id, user_data)
@ -48,7 +48,7 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user( async def delete_user(
user_id: int, user_id: int,
current_user_token: TokenData = Depends(get_current_admin) current_user_token: TokenPayload = Depends(get_current_admin)
): ):
async with get_db_session() as session: async with get_db_session() as session:
success = await delete_user(session, user_id) success = await delete_user(session, user_id)

View File

@ -1,18 +1,17 @@
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str access_token: str
refresh_token: str refresh_token: str
token_type: str token_type: str
access_token_exp: datetime access_token_exp: int
refresh_token_exp: datetime refresh_token_exp: int
class TokenPayload(BaseModel): class TokenPayload(BaseModel):
id: int id: int
username: str username: str
role: str role: str
exp: datetime exp: int
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
username: str username: str

View File

@ -1,17 +1,22 @@
from datetime import datetime
from typing import Optional from typing import Optional
import jwt import jwt
import time
from config import JWT_CONFIG from config import JWT_CONFIG
from schemas.auth import Token, TokenData from schemas.auth import TokenResponse, TokenPayload
SECRET_KEY = JWT_CONFIG['secret_key'] SECRET_KEY = JWT_CONFIG['secret_key']
ALGORITHM = JWT_CONFIG['algorithm'] ALGORITHM = JWT_CONFIG['algorithm']
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire'] ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire'] REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
def create_access_token(user_id: int, username: str, role: str) -> str:
"""创建access token""" def get_current_time() -> int:
expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE """获取当前UTC时间戳"""
return int(time.time())
def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str:
"""创建JWT token"""
expire = get_current_time() + expire_delta
to_encode = { to_encode = {
"id": user_id, "id": user_id,
"username": username, "username": username,
@ -20,33 +25,36 @@ def create_access_token(user_id: int, username: str, role: str) -> str:
} }
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(user_id: int, username: str, role: str) -> str:
"""创建access token"""
return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE)
def create_refresh_token(user_id: int, username: str, role: str) -> str: def create_refresh_token(user_id: int, username: str, role: str) -> str:
"""创建refresh token""" """创建refresh token"""
expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE)
to_encode = {
"id": user_id,
"username": username,
"role": role,
"exp": expire
}
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_tokens(user_id: int, username: str, role: str) -> Token: def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
"""创建access token和refresh token""" """创建access token和refresh token"""
access_token = create_access_token(user_id, username, role) access_token = create_access_token(user_id, username, role)
refresh_token = create_refresh_token(user_id, username, role) refresh_token = create_refresh_token(user_id, username, role)
return Token(
# 获取token的过期时间
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
return TokenResponse(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
token_type="bearer", token_type="bearer",
expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds()) access_token_exp=access_token_exp,
refresh_token_exp=refresh_token_exp
) )
def verify_token(token: str) -> Optional[TokenData]: def verify_token(token: str) -> Optional[TokenPayload]:
"""验证token有效性并返回payload如果token无效则返回None""" """验证token有效性并返回payload如果token无效则返回None"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return TokenData( return TokenPayload(
id=payload.get("id"), id=payload.get("id"),
username=payload.get("username"), username=payload.get("username"),
role=payload.get("role"), role=payload.get("role"),
@ -55,13 +63,14 @@ def verify_token(token: str) -> Optional[TokenData]:
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None return None
def refresh_tokens(refresh_token: str) -> Optional[Token]: def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
"""使用refresh token刷新access token如果refresh token无效则返回None""" """使用refresh token刷新access token如果refresh token无效则返回None"""
token_data = verify_token(refresh_token) token_data = verify_token(refresh_token)
if token_data is None: if token_data is None:
return None return None
return create_tokens( else:
user_id=token_data.id, return create_tokens_response(
username=token_data.username, user_id=token_data.id,
role=token_data.role username=token_data.username,
) role=token_data.role
)