重构了鉴权服务,重构payload

This commit is contained in:
carry 2025-01-21 23:38:52 +08:00
parent fc547eebe5
commit 2f1cf11d91
5 changed files with 50 additions and 49 deletions

View File

@ -1,15 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from schemas.auth import Token, LoginRequest, RefreshTokenRequest
from services.auth_service import create_tokens, verify_token, refresh_tokens
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
from services.auth_service import create_tokens_response, verify_token, refresh_tokens
from services.user_services import authenticate_user
from services.db import get_db
router = APIRouter(tags=["auth"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@router.post("/login", response_model=Token)
@router.post("/login", response_model=TokenResponse)
async def login(
login_data: LoginRequest,
session: AsyncSession = Depends(get_db)
@ -21,17 +21,10 @@ async def login(
detail="Invalid username or password",
headers={"WWW-Authenticate": "Bearer"},
)
return create_tokens(user.id, user.username, user.role)
return create_tokens_response(user.id, user.username, user.role)
@router.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
token_data = verify_token(token)
if not token_data:
raise HTTPException(status_code=401, detail="Invalid token")
# TODO: 实现token黑名单
return {"message": "Successfully logged out"}
@router.post("/refresh", response_model=Token)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(refresh_data: RefreshTokenRequest):
tokens = refresh_tokens(refresh_data.refresh_token)
if not tokens:

View File

@ -1,13 +1,13 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from typing import Optional
from schemas.auth import TokenData
from schemas.auth import TokenPayload
from schemas.user import UserRole
from services.auth_service import verify_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
async def _get_token_data(token: str) -> TokenData:
async def _get_token_data(token: str) -> TokenPayload:
"""验证并返回TokenData"""
token_data = verify_token(token)
if token_data is None:
@ -18,11 +18,11 @@ async def _get_token_data(token: str) -> TokenData:
)
return token_data
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData:
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload:
"""获取当前用户"""
return await _get_token_data(token)
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData:
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload:
"""获取当前管理员用户"""
token_data = await _get_token_data(token)
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status
from typing import List, Optional
from schemas.auth import TokenData
from schemas.auth import TokenPayload
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
from routes.depends import get_current_user,get_current_admin
from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
@ -13,7 +13,7 @@ async def get_users(
page: int = 1,
limit: int = 10,
role: Optional[str] = None,
current_user_token: TokenData = Depends(get_current_user)
current_user_token: TokenPayload = Depends(get_current_user)
):
current_user = await get_user_by_id(current_user_token.id)
if current_user is None:
@ -31,7 +31,7 @@ async def get_users(
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
user_data: UserCreate,
current_user_token: TokenData = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin)
):
async with get_db_session() as session:
return await create_user(session, user_data)
@ -40,7 +40,7 @@ async def create_user(
async def update_user(
user_id: int,
user_data: UserUpdate,
current_user_token: TokenData = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin)
):
async with get_db_session() as session:
return await update_user(session, user_id, user_data)
@ -48,7 +48,7 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
user_id: int,
current_user_token: TokenData = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin)
):
async with get_db_session() as session:
success = await delete_user(session, user_id)

View File

@ -1,18 +1,17 @@
from pydantic import BaseModel
from datetime import datetime
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str
access_token_exp: datetime
refresh_token_exp: datetime
access_token_exp: int
refresh_token_exp: int
class TokenPayload(BaseModel):
id: int
username: str
role: str
exp: datetime
exp: int
class LoginRequest(BaseModel):
username: str

View File

@ -1,17 +1,22 @@
from datetime import datetime
from typing import Optional
import jwt
import time
from config import JWT_CONFIG
from schemas.auth import Token, TokenData
from schemas.auth import TokenResponse, TokenPayload
SECRET_KEY = JWT_CONFIG['secret_key']
ALGORITHM = JWT_CONFIG['algorithm']
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
def create_access_token(user_id: int, username: str, role: str) -> str:
"""创建access token"""
expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE
def get_current_time() -> int:
"""获取当前UTC时间戳"""
return int(time.time())
def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str:
"""创建JWT token"""
expire = get_current_time() + expire_delta
to_encode = {
"id": user_id,
"username": username,
@ -20,33 +25,36 @@ def create_access_token(user_id: int, username: str, role: str) -> str:
}
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(user_id: int, username: str, role: str) -> str:
"""创建access token"""
return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE)
def create_refresh_token(user_id: int, username: str, role: str) -> str:
"""创建refresh token"""
expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE
to_encode = {
"id": user_id,
"username": username,
"role": role,
"exp": expire
}
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE)
def create_tokens(user_id: int, username: str, role: str) -> Token:
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
"""创建access token和refresh token"""
access_token = create_access_token(user_id, username, role)
refresh_token = create_refresh_token(user_id, username, role)
return Token(
# 获取token的过期时间
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds())
access_token_exp=access_token_exp,
refresh_token_exp=refresh_token_exp
)
def verify_token(token: str) -> Optional[TokenData]:
def verify_token(token: str) -> Optional[TokenPayload]:
"""验证token有效性并返回payload如果token无效则返回None"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return TokenData(
return TokenPayload(
id=payload.get("id"),
username=payload.get("username"),
role=payload.get("role"),
@ -55,13 +63,14 @@ def verify_token(token: str) -> Optional[TokenData]:
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
def refresh_tokens(refresh_token: str) -> Optional[Token]:
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
"""使用refresh token刷新access token如果refresh token无效则返回None"""
token_data = verify_token(refresh_token)
if token_data is None:
return None
return create_tokens(
user_id=token_data.id,
username=token_data.username,
role=token_data.role
)
else:
return create_tokens_response(
user_id=token_data.id,
username=token_data.username,
role=token_data.role
)