From 375077be693aaf1ec068e5711b17c67bc7512413 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 14 Feb 2025 16:49:22 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=8C=E6=88=90=E8=83=BD?= =?UTF-8?q?=E4=BD=BF=E7=94=A8refresh=20token=E8=AE=BF=E9=97=AE=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- schemas/auth.py | 1 + services/auth.py | 22 +++++++++------------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/schemas/auth.py b/schemas/auth.py index bde44ff..67fe815 100644 --- a/schemas/auth.py +++ b/schemas/auth.py @@ -12,6 +12,7 @@ class TokenPayload(BaseModel): username: str role: str exp: int + token_type: str class LoginRequest(BaseModel): username: str diff --git a/services/auth.py b/services/auth.py index 03352cc..3fb46bb 100644 --- a/services/auth.py +++ b/services/auth.py @@ -14,29 +14,24 @@ def get_current_time() -> int: """获取当前UTC时间戳""" return int(time.time()) -def create_token(user_id: int, username: str, role: str, expire_delta) -> str: +def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str: """创建JWT token""" + expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE expire = get_current_time() + int(expire_delta.total_seconds()) + to_encode = { "id": user_id, "username": username, "role": role, - "exp": expire + "exp": expire, + "token_type": token_type } return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) -def create_access_token(user_id: int, username: str, role: str) -> str: - """创建access token""" - return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE) - -def create_refresh_token(user_id: int, username: str, role: str) -> str: - """创建refresh token""" - return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE) - def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse: """创建access token和refresh token""" - access_token = create_access_token(user_id, username, role) - refresh_token = create_refresh_token(user_id, username, role) + access_token = create_token(user_id, username, role, "access") + refresh_token = create_token(user_id, username, role, "refresh") # 获取token的过期时间 access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds()) @@ -58,7 +53,8 @@ def verify_token(token: str) -> Optional[TokenPayload]: id=payload.get("id"), username=payload.get("username"), role=payload.get("role"), - exp=payload.get("exp") + exp=payload.get("exp"), + token_type=payload.get("token_type") ) except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): return None