重构了鉴权服务,重构payload
This commit is contained in:
parent
fc547eebe5
commit
2f1cf11d91
@ -1,15 +1,15 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from schemas.auth import Token, LoginRequest, RefreshTokenRequest
|
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
|
||||||
from services.auth_service import create_tokens, verify_token, refresh_tokens
|
from services.auth_service import create_tokens_response, verify_token, refresh_tokens
|
||||||
from services.user_services import authenticate_user
|
from services.user_services import authenticate_user
|
||||||
from services.db import get_db
|
from services.db import get_db
|
||||||
|
|
||||||
router = APIRouter(tags=["auth"])
|
router = APIRouter(tags=["auth"])
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
@router.post("/login", response_model=Token)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
async def login(
|
async def login(
|
||||||
login_data: LoginRequest,
|
login_data: LoginRequest,
|
||||||
session: AsyncSession = Depends(get_db)
|
session: AsyncSession = Depends(get_db)
|
||||||
@ -21,17 +21,10 @@ async def login(
|
|||||||
detail="Invalid username or password",
|
detail="Invalid username or password",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
return create_tokens(user.id, user.username, user.role)
|
return create_tokens_response(user.id, user.username, user.role)
|
||||||
|
|
||||||
@router.post("/logout")
|
|
||||||
async def logout(token: str = Depends(oauth2_scheme)):
|
|
||||||
token_data = verify_token(token)
|
|
||||||
if not token_data:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
|
||||||
# TODO: 实现token黑名单
|
|
||||||
return {"message": "Successfully logged out"}
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=Token)
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
async def refresh_token(refresh_data: RefreshTokenRequest):
|
async def refresh_token(refresh_data: RefreshTokenRequest):
|
||||||
tokens = refresh_tokens(refresh_data.refresh_token)
|
tokens = refresh_tokens(refresh_data.refresh_token)
|
||||||
if not tokens:
|
if not tokens:
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from schemas.auth import TokenData
|
from schemas.auth import TokenPayload
|
||||||
from schemas.user import UserRole
|
from schemas.user import UserRole
|
||||||
from services.auth_service import verify_token
|
from services.auth_service import verify_token
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||||
|
|
||||||
async def _get_token_data(token: str) -> TokenData:
|
async def _get_token_data(token: str) -> TokenPayload:
|
||||||
"""验证并返回TokenData"""
|
"""验证并返回TokenData"""
|
||||||
token_data = verify_token(token)
|
token_data = verify_token(token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
@ -18,11 +18,11 @@ async def _get_token_data(token: str) -> TokenData:
|
|||||||
)
|
)
|
||||||
return token_data
|
return token_data
|
||||||
|
|
||||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData:
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||||
"""获取当前用户"""
|
"""获取当前用户"""
|
||||||
return await _get_token_data(token)
|
return await _get_token_data(token)
|
||||||
|
|
||||||
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData:
|
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||||
"""获取当前管理员用户"""
|
"""获取当前管理员用户"""
|
||||||
token_data = await _get_token_data(token)
|
token_data = await _get_token_data(token)
|
||||||
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:
|
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from schemas.auth import TokenData
|
from schemas.auth import TokenPayload
|
||||||
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
||||||
from routes.depends import get_current_user,get_current_admin
|
from routes.depends import get_current_user,get_current_admin
|
||||||
from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
|
from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user
|
||||||
@ -13,7 +13,7 @@ async def get_users(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
role: Optional[str] = None,
|
role: Optional[str] = None,
|
||||||
current_user_token: TokenData = Depends(get_current_user)
|
current_user_token: TokenPayload = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
current_user = await get_user_by_id(current_user_token.id)
|
current_user = await get_user_by_id(current_user_token.id)
|
||||||
if current_user is None:
|
if current_user is None:
|
||||||
@ -31,7 +31,7 @@ async def get_users(
|
|||||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_user(
|
async def create_user(
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
current_user_token: TokenData = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
return await create_user(session, user_data)
|
return await create_user(session, user_data)
|
||||||
@ -40,7 +40,7 @@ async def create_user(
|
|||||||
async def update_user(
|
async def update_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
user_data: UserUpdate,
|
user_data: UserUpdate,
|
||||||
current_user_token: TokenData = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
return await update_user(session, user_id, user_data)
|
return await update_user(session, user_id, user_data)
|
||||||
@ -48,7 +48,7 @@ async def update_user(
|
|||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_user(
|
async def delete_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user_token: TokenData = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
success = await delete_user(session, user_id)
|
success = await delete_user(session, user_id)
|
||||||
|
@ -1,18 +1,17 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
access_token_exp: datetime
|
access_token_exp: int
|
||||||
refresh_token_exp: datetime
|
refresh_token_exp: int
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
class TokenPayload(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
username: str
|
username: str
|
||||||
role: str
|
role: str
|
||||||
exp: datetime
|
exp: int
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
|
@ -1,17 +1,22 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import jwt
|
import jwt
|
||||||
|
import time
|
||||||
from config import JWT_CONFIG
|
from config import JWT_CONFIG
|
||||||
from schemas.auth import Token, TokenData
|
from schemas.auth import TokenResponse, TokenPayload
|
||||||
|
|
||||||
SECRET_KEY = JWT_CONFIG['secret_key']
|
SECRET_KEY = JWT_CONFIG['secret_key']
|
||||||
ALGORITHM = JWT_CONFIG['algorithm']
|
ALGORITHM = JWT_CONFIG['algorithm']
|
||||||
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
||||||
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
||||||
|
|
||||||
def create_access_token(user_id: int, username: str, role: str) -> str:
|
|
||||||
"""创建access token"""
|
def get_current_time() -> int:
|
||||||
expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE
|
"""获取当前UTC时间戳"""
|
||||||
|
return int(time.time())
|
||||||
|
|
||||||
|
def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str:
|
||||||
|
"""创建JWT token"""
|
||||||
|
expire = get_current_time() + expire_delta
|
||||||
to_encode = {
|
to_encode = {
|
||||||
"id": user_id,
|
"id": user_id,
|
||||||
"username": username,
|
"username": username,
|
||||||
@ -20,33 +25,36 @@ def create_access_token(user_id: int, username: str, role: str) -> str:
|
|||||||
}
|
}
|
||||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
def create_access_token(user_id: int, username: str, role: str) -> str:
|
||||||
|
"""创建access token"""
|
||||||
|
return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE)
|
||||||
|
|
||||||
def create_refresh_token(user_id: int, username: str, role: str) -> str:
|
def create_refresh_token(user_id: int, username: str, role: str) -> str:
|
||||||
"""创建refresh token"""
|
"""创建refresh token"""
|
||||||
expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE
|
return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE)
|
||||||
to_encode = {
|
|
||||||
"id": user_id,
|
|
||||||
"username": username,
|
|
||||||
"role": role,
|
|
||||||
"exp": expire
|
|
||||||
}
|
|
||||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
||||||
|
|
||||||
def create_tokens(user_id: int, username: str, role: str) -> Token:
|
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
|
||||||
"""创建access token和refresh token"""
|
"""创建access token和refresh token"""
|
||||||
access_token = create_access_token(user_id, username, role)
|
access_token = create_access_token(user_id, username, role)
|
||||||
refresh_token = create_refresh_token(user_id, username, role)
|
refresh_token = create_refresh_token(user_id, username, role)
|
||||||
return Token(
|
|
||||||
|
# 获取token的过期时间
|
||||||
|
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
||||||
|
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
token_type="bearer",
|
token_type="bearer",
|
||||||
expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
access_token_exp=access_token_exp,
|
||||||
|
refresh_token_exp=refresh_token_exp
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify_token(token: str) -> Optional[TokenData]:
|
def verify_token(token: str) -> Optional[TokenPayload]:
|
||||||
"""验证token有效性并返回payload,如果token无效则返回None"""
|
"""验证token有效性并返回payload,如果token无效则返回None"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
return TokenData(
|
return TokenPayload(
|
||||||
id=payload.get("id"),
|
id=payload.get("id"),
|
||||||
username=payload.get("username"),
|
username=payload.get("username"),
|
||||||
role=payload.get("role"),
|
role=payload.get("role"),
|
||||||
@ -55,13 +63,14 @@ def verify_token(token: str) -> Optional[TokenData]:
|
|||||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def refresh_tokens(refresh_token: str) -> Optional[Token]:
|
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
|
||||||
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
||||||
token_data = verify_token(refresh_token)
|
token_data = verify_token(refresh_token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
return None
|
return None
|
||||||
return create_tokens(
|
else:
|
||||||
user_id=token_data.id,
|
return create_tokens_response(
|
||||||
username=token_data.username,
|
user_id=token_data.id,
|
||||||
role=token_data.role
|
username=token_data.username,
|
||||||
)
|
role=token_data.role
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user