From f1cdbab0f484517776d9b6c949ae3e66b6fe9999 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 21:28:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E4=BA=86=E9=89=B4=E6=9D=83=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E4=BB=A3=E7=A0=81=E5=92=8C=E9=89=B4=E6=9D=83=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/depends.py | 18 +++++++++++ routes/users.py | 36 ++++++++++++++++++--- services/auth_service.py | 67 ++++++++++++++++++++++++++++++++++++++++ services/db.py | 2 +- 4 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 routes/depends.py create mode 100644 services/auth_service.py diff --git a/routes/depends.py b/routes/depends.py new file mode 100644 index 0000000..43aebad --- /dev/null +++ b/routes/depends.py @@ -0,0 +1,18 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from typing import Optional +from schemas.auth import TokenData +from services.auth_service import verify_token + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") + +async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: + """获取当前用户""" + token_data = verify_token(token) + if token_data is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token_data diff --git a/routes/users.py b/routes/users.py index 4caf448..92a648d 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,7 +1,9 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional +from schemas.auth import TokenData from schemas.user import UserCreate, UserUpdate, UserResponse -from services.auth import get_current_user +from routes.depends import get_current_user +from services.user_services import get_user_by_id router = APIRouter(tags=["users"]) @@ -10,8 +12,14 @@ async def get_users( page: int = 1, limit: int = 10, role: Optional[str] = None, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -23,8 +31,14 @@ async def get_users( @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -37,8 +51,14 @@ async def create_user( async def update_user( user_id: int, user_data: UserUpdate, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -50,8 +70,14 @@ async def update_user( @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/services/auth_service.py b/services/auth_service.py new file mode 100644 index 0000000..7549224 --- /dev/null +++ b/services/auth_service.py @@ -0,0 +1,67 @@ +from datetime import datetime +from typing import Optional +import jwt +from config import JWT_CONFIG +from schemas.auth import Token, TokenData + +SECRET_KEY = JWT_CONFIG['secret_key'] +ALGORITHM = JWT_CONFIG['algorithm'] +ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire'] +REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire'] + +def create_access_token(user_id: int, username: str, role: str) -> str: + """创建access token""" + expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE + to_encode = { + "id": user_id, + "username": username, + "role": role, + "exp": expire + } + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + +def create_refresh_token(user_id: int, username: str, role: str) -> str: + """创建refresh token""" + expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE + to_encode = { + "id": user_id, + "username": username, + "role": role, + "exp": expire + } + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + +def create_tokens(user_id: int, username: str, role: str) -> Token: + """创建access token和refresh token""" + access_token = create_access_token(user_id, username, role) + refresh_token = create_refresh_token(user_id, username, role) + return Token( + access_token=access_token, + refresh_token=refresh_token, + token_type="bearer", + expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds()) + ) + +def verify_token(token: str) -> Optional[TokenData]: + """验证token有效性并返回payload,如果token无效则返回None""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + return TokenData( + id=payload.get("id"), + username=payload.get("username"), + role=payload.get("role"), + exp=payload.get("exp") + ) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return None + +def refresh_tokens(refresh_token: str) -> Optional[Token]: + """使用refresh token刷新access token,如果refresh token无效则返回None""" + token_data = verify_token(refresh_token) + if token_data is None: + return None + return create_tokens( + user_id=token_data.id, + username=token_data.username, + role=token_data.role + ) diff --git a/services/db.py b/services/db.py index acca4d0..74777bf 100644 --- a/services/db.py +++ b/services/db.py @@ -4,7 +4,7 @@ from ..models.user import Base, User, UserRole from sqlalchemy import select from contextlib import asynccontextmanager from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG -from services.user_services import get_password_hash +from user_services import get_password_hash from typing import AsyncGenerator # 全局数据库引擎实例