diff --git a/routes/auth.py b/routes/auth.py index 1ba2498..52bd654 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -1,15 +1,15 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession -from schemas.auth import Token, LoginRequest, RefreshTokenRequest -from services.auth_service import create_tokens, verify_token, refresh_tokens +from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest +from services.auth_service import create_tokens_response, verify_token, refresh_tokens from services.user_services import authenticate_user from services.db import get_db router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -@router.post("/login", response_model=Token) +@router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, session: AsyncSession = Depends(get_db) @@ -21,17 +21,10 @@ async def login( detail="Invalid username or password", headers={"WWW-Authenticate": "Bearer"}, ) - return create_tokens(user.id, user.username, user.role) + return create_tokens_response(user.id, user.username, user.role) -@router.post("/logout") -async def logout(token: str = Depends(oauth2_scheme)): - token_data = verify_token(token) - if not token_data: - raise HTTPException(status_code=401, detail="Invalid token") - # TODO: 实现token黑名单 - return {"message": "Successfully logged out"} -@router.post("/refresh", response_model=Token) +@router.post("/refresh", response_model=TokenResponse) async def refresh_token(refresh_data: RefreshTokenRequest): tokens = refresh_tokens(refresh_data.refresh_token) if not tokens: diff --git a/routes/depends.py b/routes/depends.py index af1e0fe..bdcb70e 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -1,13 +1,13 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from typing import Optional -from schemas.auth import TokenData +from schemas.auth import TokenPayload from schemas.user import UserRole from services.auth_service import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") -async def _get_token_data(token: str) -> TokenData: +async def _get_token_data(token: str) -> TokenPayload: """验证并返回TokenData""" token_data = verify_token(token) if token_data is None: @@ -18,11 +18,11 @@ async def _get_token_data(token: str) -> TokenData: ) return token_data -async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: +async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload: """获取当前用户""" return await _get_token_data(token) -async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: +async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload: """获取当前管理员用户""" token_data = await _get_token_data(token) if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]: diff --git a/routes/users.py b/routes/users.py index 7e598b6..2e863a2 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional -from schemas.auth import TokenData +from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user @@ -13,7 +13,7 @@ async def get_users( page: int = 1, limit: int = 10, role: Optional[str] = None, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenPayload = Depends(get_current_user) ): current_user = await get_user_by_id(current_user_token.id) if current_user is None: @@ -31,7 +31,7 @@ async def get_users( @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: return await create_user(session, user_data) @@ -40,7 +40,7 @@ async def create_user( async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: return await update_user(session, user_id, user_data) @@ -48,7 +48,7 @@ async def update_user( @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: success = await delete_user(session, user_id) diff --git a/schemas/auth.py b/schemas/auth.py index f6679ea..bde44ff 100644 --- a/schemas/auth.py +++ b/schemas/auth.py @@ -1,18 +1,17 @@ from pydantic import BaseModel -from datetime import datetime class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str - access_token_exp: datetime - refresh_token_exp: datetime + access_token_exp: int + refresh_token_exp: int class TokenPayload(BaseModel): id: int username: str role: str - exp: datetime + exp: int class LoginRequest(BaseModel): username: str diff --git a/services/auth_service.py b/services/auth_service.py index 7549224..30a9b44 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -1,17 +1,22 @@ -from datetime import datetime from typing import Optional import jwt +import time from config import JWT_CONFIG -from schemas.auth import Token, TokenData +from schemas.auth import TokenResponse, TokenPayload SECRET_KEY = JWT_CONFIG['secret_key'] ALGORITHM = JWT_CONFIG['algorithm'] ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire'] REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire'] -def create_access_token(user_id: int, username: str, role: str) -> str: - """创建access token""" - expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE + +def get_current_time() -> int: + """获取当前UTC时间戳""" + return int(time.time()) + +def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str: + """创建JWT token""" + expire = get_current_time() + expire_delta to_encode = { "id": user_id, "username": username, @@ -20,33 +25,36 @@ def create_access_token(user_id: int, username: str, role: str) -> str: } return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) +def create_access_token(user_id: int, username: str, role: str) -> str: + """创建access token""" + return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE) + def create_refresh_token(user_id: int, username: str, role: str) -> str: """创建refresh token""" - expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE - to_encode = { - "id": user_id, - "username": username, - "role": role, - "exp": expire - } - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE) -def create_tokens(user_id: int, username: str, role: str) -> Token: +def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse: """创建access token和refresh token""" access_token = create_access_token(user_id, username, role) refresh_token = create_refresh_token(user_id, username, role) - return Token( + + # 获取token的过期时间 + access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds()) + refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds()) + + return TokenResponse( access_token=access_token, refresh_token=refresh_token, token_type="bearer", - expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds()) + access_token_exp=access_token_exp, + refresh_token_exp=refresh_token_exp ) -def verify_token(token: str) -> Optional[TokenData]: +def verify_token(token: str) -> Optional[TokenPayload]: """验证token有效性并返回payload,如果token无效则返回None""" try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - return TokenData( + return TokenPayload( id=payload.get("id"), username=payload.get("username"), role=payload.get("role"), @@ -55,13 +63,14 @@ def verify_token(token: str) -> Optional[TokenData]: except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): return None -def refresh_tokens(refresh_token: str) -> Optional[Token]: +def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]: """使用refresh token刷新access token,如果refresh token无效则返回None""" token_data = verify_token(refresh_token) if token_data is None: return None - return create_tokens( - user_id=token_data.id, - username=token_data.username, - role=token_data.role - ) + else: + return create_tokens_response( + user_id=token_data.id, + username=token_data.username, + role=token_data.role + )