完成refresh token能访问的bug修正

This commit is contained in:
carry 2025-02-14 16:59:41 +08:00
parent 375077be69
commit b76d721680
3 changed files with 24 additions and 6 deletions

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
from services.auth import create_tokens_response, verify_token, refresh_tokens from services.auth import create_tokens_response, refresh_tokens
from services.user import authenticate_user from services.user import authenticate_user
from services.db import get_db_session_dep from services.db import get_db_session_dep

View File

@ -3,13 +3,13 @@ from fastapi.security import OAuth2PasswordBearer
from typing import Optional from typing import Optional
from schemas.auth import TokenPayload from schemas.auth import TokenPayload
from schemas.user import UserRole from schemas.user import UserRole
from services.auth import verify_token from services.auth import verify_access_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def _get_token_data(token: str) -> TokenPayload: async def _get_token_data(token: str) -> TokenPayload:
"""验证并返回TokenData""" """验证并返回TokenData"""
token_data = verify_token(token) token_data = verify_access_token(token)
if token_data is None: if token_data is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View File

@ -45,10 +45,28 @@ def create_tokens_response(user_id: int, username: str, role: str) -> TokenRespo
refresh_token_exp=refresh_token_exp refresh_token_exp=refresh_token_exp
) )
def verify_token(token: str) -> Optional[TokenPayload]: def verify_access_token(token: str) -> Optional[TokenPayload]:
"""验证token有效性并返回payload如果token无效则返回None""" """验证access token有效性并返回payload如果token无效或类型不匹配则返回None"""
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
if payload.get("token_type") != "access":
return None
return TokenPayload(
id=payload.get("id"),
username=payload.get("username"),
role=payload.get("role"),
exp=payload.get("exp"),
token_type=payload.get("token_type")
)
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
"""验证refresh token有效性并返回payload如果token无效或类型不匹配则返回None"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
if payload.get("token_type") != "refresh":
return None
return TokenPayload( return TokenPayload(
id=payload.get("id"), id=payload.get("id"),
username=payload.get("username"), username=payload.get("username"),
@ -61,7 +79,7 @@ def verify_token(token: str) -> Optional[TokenPayload]:
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]: def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
"""使用refresh token刷新access token如果refresh token无效则返回None""" """使用refresh token刷新access token如果refresh token无效则返回None"""
token_data = verify_token(refresh_token) token_data = verify_refresh_token(refresh_token)
if token_data is None: if token_data is None:
return None return None
else: else: