From fb41b442da6274bb19178bc8f330cb1eb4bc30c3 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 19 Jan 2025 23:44:11 +0800 Subject: [PATCH 01/26] =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- devdoc.md | 293 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 devdoc.md diff --git a/devdoc.md b/devdoc.md new file mode 100644 index 0000000..6fbd059 --- /dev/null +++ b/devdoc.md @@ -0,0 +1,293 @@ +## 项目概述 + +本项目是一个简单的用户管理系统,支持三种用户角色:系统管理员、管理员和普通用户。系统管理员在项目启动时自动生成,且唯一。所有账户必须由管理员手动创建。用户表包含密码、ID、角色、描述等信息,管理员可以对用户表进行增删改查,普通用户只能读取用户表。 + +## 技术栈 + +* **后端技术栈**:FastAPI、SQLAlchemy、MySQL、bcrypt、PyJWT、logging +* **前端技术栈**:Vue、Vue Router、Pinia、Vite、Axios、Element Plus + +## 数据库设计 + +### 用户表结构 + +| 字段名 | 数据类型 | 约束条件 | 默认值 | 说明 | +| -------- | ---------- | ---------- | -------- | ---------------------------------------------------- | +| `id` | `INT` | `PRIMARY KEY`, `AUTO_INCREMENT` | - | 用户ID,主键,自增 | +| `username` | `VARCHAR(50)` | `NOT NULL`, `UNIQUE` | - | 用户名,唯一 | +| `password` | `VARCHAR(255)` | `NOT NULL` | - | 用户密码,使用哈希加密存储 | +| `role` | `ENUM` | `NOT NULL` | `'user'` | 用户角色,枚举类型,可选值为 `system_admin`, `admin`, `user` | +| `description` | `TEXT` | - | - | 用户描述,可选 | +| `created_at` | `DATETIME` | - | `CURRENT_TIMESTAMP` | 用户创建时间,默认值为当前时间 | +| `updated_at` | `DATETIME` | - | `CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP` | 用户信息更新时间,默认值为当前时间,更新时自动更新 | + +### 数据库初始化 + +* 系统初始化时,后端控制层会检查表是否存在,若不存在则自动创建表。 +* 检查表中是否存在系统管理员,若不存在则创建一个名为`admin`、密码为`password`、描述为`default system admin`的系统管理员。 + +## 鉴权设计 + +### 角色权限 + +* **系统管理员、管理员**:可以访问所有 API,包括用户增删改查。 +* **普通用户**:只能访问用户列表(只读)。 + +### JWT 鉴权 + +JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以下信息: + +* **Payload**:用户 ID、用户名、角色、Token 过期时间等。 +* **签名**:使用后端密钥对 Payload 进行签名,确保 Token 的完整性和安全性。 + +### JWT 自动过期机制 + +* **Access Token**:用于常规 API 请求,有效期较短(如 30 分钟)。 +* **Refresh Token**:用于刷新 Access Token,有效期较长(如 7 天)。 + +#### Token 刷新流程 + +1. 客户端使用 Refresh Token 请求 `/api/auth/refresh` 接口。 +2. 服务端验证 Refresh Token 的有效性。 +3. 服务端生成新的 Access Token 和 Refresh Token,并返回给客户端。 +4. 客户端更新本地存储的 Token。 + +## API 设计 + +### 用户登录 + +* **URL**: `/api/auth/login` +* **HTTP 方法**: POST +* **请求体**: + + ```json + { + "username": "admin", + "password": "password123" + } + ``` +* **响应**: + + * 成功: + + ```json + { + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "bearer", + "expires_in": 1800 + } + ``` + * 失败: + + ```json + { + "detail": "Invalid username or password" + } + ``` + +### 用户登出 + +* **URL**: `/api/auth/logout` +* **HTTP 方法**: POST +* **请求头**: + + * `Authorization: Bearer ` +* **响应**: + + * 成功: + + ```json + { + "message": "Successfully logged out" + } + ``` + +### 刷新 JWT Token + +* **URL**: `/api/auth/refresh` +* **HTTP 方法**: POST +* **请求体**: + + ```json + { + "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + } + ``` +* **响应**: + + * 成功: + + ```json + { + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "bearer", + "expires_in": 1800 + } + ``` + * 失败: + + ```json + { + "detail": "Invalid refresh token" + } + ``` + +### 获取用户列表 + +* **URL**: `/api/users` +* **HTTP 方法**: GET +* **请求头**: + + * `Authorization: Bearer ` +* **查询参数**: + + * `page`:页码(默认 1) + * `limit`:每页条数(默认 10) + * `role`:按角色过滤(可选) +* **响应**: + + ```json + { + "total": 100, + "users": [ + { + "id": 1, + "username": "admin", + "role": "system_admin", + "description": "Default system admin", + "created_at": "2023-10-01T12:00:00Z", + "updated_at": "2023-10-01T12:00:00Z" + }, + ... + ] + } + ``` + +### 创建用户 + +* **URL**: `/api/users` +* **HTTP 方法**: POST +* **请求头**: + + * `Authorization: Bearer ` +* **请求体**: + + ```json + { + "username": "new_user", + "password": "new_password123", + "role": "user", + "description": "New user description" + } + ``` +* **响应**: + + * 成功: + + ```json + { + "id": 2, + "username": "new_user", + "role": "user", + "description": "New user description", + "created_at": "2023-10-01T12:00:00Z", + "updated_at": "2023-10-01T12:00:00Z" + } + ``` + +### 更新用户信息 + +* **URL**: `/api/users/{user_id}` +* **HTTP 方法**: PUT +* **请求头**: + + * `Authorization: Bearer ` +* **请求体**: + + ```json + { + "username": "updated_user", + "role": "admin", + "description": "Updated user description" + } + ``` +* **响应**: + + * 成功: + + ```json + { + "id": 1, + "username": "updated_user", + "role": "admin", + "description": "Updated user description", + "created_at": "2023-10-01T12:00:00Z", + "updated_at": "2023-10-01T12:30:00Z" + } + ``` + +### 删除用户 + +* **URL**: `/api/users/{user_id}` +* **HTTP 方法**: DELETE +* **请求头**: + + * `Authorization: Bearer ` +* **响应**: + + * 成功: + + ```json + { + "message": "User deleted successfully" + } + ``` + +## 单元测试 + +### 测试框架 + +* 使用 `pytest` 进行单元测试。 +* 使用 `requests` 库模拟 API 请求。 + +### 测试用例 + +1. **用户登录**: + + * 测试正确的用户名和密码。 + * 测试错误的用户名和密码。 +2. **用户登出**: + + * 测试已登录用户登出。 + * 测试未登录用户登出。 +3. **刷新 JWT Token**: + + * 测试有效的 Refresh Token。 + * 测试无效的 Refresh Token。 +4. **获取用户列表**: + + * 测试不同角色的用户访问权限。 + * 测试分页和过滤功能。 +5. **创建用户**: + + * 测试管理员创建用户。 + * 测试普通用户尝试创建用户。 +6. **更新用户信息**: + + * 测试管理员更新用户信息。 + * 测试普通用户尝试更新用户信息。 +7. **删除用户**: + + * 测试管理员删除用户。 + * 测试普通用户尝试删除用户。 + +## 日志记录 + +* 使用 `logging` 模块记录系统日志。 +* 日志级别包括 `INFO`、`WARNING`、`ERROR`。 +* 日志内容包括用户操作、API 请求、错误信息等。 + +## 后续开发 +* **部署文档**:包括 Docker 容器化、CI/CD 流程、环境变量配置等。 \ No newline at end of file From 287a1e5629bfcdbd22e02b5d40bacbe5be028a0b Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 19 Jan 2025 23:45:13 +0800 Subject: [PATCH 02/26] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=9F=BA=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/__init__.py | 5 +++++ models/auth.py | 11 +++++++++++ models/database.py | 25 +++++++++++++++++++++++++ models/user.py | 25 +++++++++++++++++++++++++ requirements.txt | 8 ++++++++ 5 files changed, 74 insertions(+) create mode 100644 models/__init__.py create mode 100644 models/auth.py create mode 100644 models/database.py create mode 100644 models/user.py create mode 100644 requirements.txt diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..a1ec5f8 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,5 @@ +from .user import User +from .auth import get_password_hash, verify_password +from .database import init_db + +__all__ = ["User", "get_password_hash", "verify_password", "init_db"] diff --git a/models/auth.py b/models/auth.py new file mode 100644 index 0000000..9f22dee --- /dev/null +++ b/models/auth.py @@ -0,0 +1,11 @@ +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +def get_password_hash(password: str) -> str: + """Generate password hash using bcrypt""" + return pwd_context.hash(password) + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify password against stored hash""" + return pwd_context.verify(plain_password, hashed_password) diff --git a/models/database.py b/models/database.py new file mode 100644 index 0000000..de2f5f9 --- /dev/null +++ b/models/database.py @@ -0,0 +1,25 @@ +from sqlalchemy.ext.asyncio import AsyncEngine +from .user import Base, User, UserRole +from sqlalchemy import select +from .auth import get_password_hash + +async def init_db(engine: AsyncEngine): + """Initialize database""" + async with engine.begin() as conn: + # Create all tables + await conn.run_sync(Base.metadata.create_all) + + # Check if system admin exists + result = await conn.execute( + select(User).where(User.role == UserRole.SYSTEM_ADMIN) + ) + if not result.scalars().first(): + # Create default system admin + admin = User( + username="admin", + password=get_password_hash("password"), + role=UserRole.SYSTEM_ADMIN, + description="default system admin" + ) + conn.add(admin) + await conn.commit() diff --git a/models/user.py b/models/user.py new file mode 100644 index 0000000..e36e3dd --- /dev/null +++ b/models/user.py @@ -0,0 +1,25 @@ +from datetime import datetime, timezone +from enum import Enum +from sqlalchemy import Column, Integer, String, Text, DateTime, Enum as SQLEnum +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class UserRole(str, Enum): + SYSTEM_ADMIN = "system_admin" + ADMIN = "admin" + USER = "user" + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String(50), unique=True, nullable=False) + password = Column(String(255), nullable=False) + role = Column(SQLEnum(UserRole), nullable=False, default=UserRole.USER) + description = Column(Text) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + def __repr__(self): + return f"" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f333860 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +fastapi>=0.95.2 +sqlalchemy>=2.0.15 +passlib>=1.7.4 +bcrypt>=4.0.1 +pyjwt>=2.6.0 +pytest>=7.3.1 +requests>=2.28.2 +uvicorn[standard]>=0.21.1 From 287e38aa48d3130c0c70f921ed723864750b00b2 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 20 Jan 2025 14:18:33 +0800 Subject: [PATCH 03/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86config?= =?UTF-8?q?=E6=96=87=E4=BB=B6=EF=BC=8C=E5=88=9D=E6=AD=A5=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=E4=BA=86model=E7=9A=84=E4=B8=80=E4=BA=9B=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 35 +++++++++++++++++++++++++++++++++++ models/__init__.py | 3 +-- {models => service}/auth.py | 0 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 config.py rename {models => service}/auth.py (100%) diff --git a/config.py b/config.py new file mode 100644 index 0000000..10e3097 --- /dev/null +++ b/config.py @@ -0,0 +1,35 @@ +import os +from datetime import timedelta + +# 数据库配置 +DATABASE_CONFIG = { + 'host': os.getenv('DB_HOST', 'localhost'), + 'port': int(os.getenv('DB_PORT', '3306')), + 'user': os.getenv('DB_USER', 'root'), + 'password': os.getenv('DB_PASSWORD', 'password'), + 'database': os.getenv('DB_NAME', 'user_manage'), + 'charset': 'utf8mb4' +} + +# JWT配置 +JWT_CONFIG = { + 'secret_key': os.getenv('JWT_SECRET_KEY', 'your-secret-key'), + 'algorithm': 'HS256', + 'access_token_expire': timedelta(minutes=int(os.getenv('JWT_ACCESS_EXPIRE_MINUTES', '30'))), + 'refresh_token_expire': timedelta(days=int(os.getenv('JWT_REFRESH_EXPIRE_DAYS', '7'))) +} + +# 日志配置 +LOGGING_CONFIG = { + 'level': os.getenv('LOG_LEVEL', 'INFO'), + 'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + 'filename': os.getenv('LOG_FILE', 'app.log') +} + +# 系统管理员初始配置 +SYSTEM_ADMIN_CONFIG = { + 'username': os.getenv('ADMIN_USERNAME', 'admin'), + 'password': os.getenv('ADMIN_PASSWORD', 'password'), + 'role': os.getenv('ADMIN_ROLE', 'system_admin'), + 'description': os.getenv('ADMIN_DESCRIPTION', 'default system admin') +} diff --git a/models/__init__.py b/models/__init__.py index a1ec5f8..74df867 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,5 +1,4 @@ from .user import User -from .auth import get_password_hash, verify_password from .database import init_db -__all__ = ["User", "get_password_hash", "verify_password", "init_db"] +__all__ = ["User", "init_db"] diff --git a/models/auth.py b/service/auth.py similarity index 100% rename from models/auth.py rename to service/auth.py From c73edc794f020030fd1376abfaed8d527d5dd5e1 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 20 Jan 2025 14:23:36 +0800 Subject: [PATCH 04/26] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BA=86=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/database.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models/database.py b/models/database.py index de2f5f9..706098b 100644 --- a/models/database.py +++ b/models/database.py @@ -1,7 +1,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine from .user import Base, User, UserRole from sqlalchemy import select -from .auth import get_password_hash +from config import SYSTEM_ADMIN_CONFIG +from service.auth import get_password_hash async def init_db(engine: AsyncEngine): """Initialize database""" @@ -16,10 +17,10 @@ async def init_db(engine: AsyncEngine): if not result.scalars().first(): # Create default system admin admin = User( - username="admin", - password=get_password_hash("password"), + username=SYSTEM_ADMIN_CONFIG['username'], + password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), role=UserRole.SYSTEM_ADMIN, - description="default system admin" + description=SYSTEM_ADMIN_CONFIG['description'] ) conn.add(admin) await conn.commit() From 01ff669e8d8f540d1d590112d8633c6f67154a0a Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 20 Jan 2025 15:26:40 +0800 Subject: [PATCH 05/26] =?UTF-8?q?=E5=81=9A=E4=BA=86=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E7=9A=84=E8=B0=83=E6=95=B4=EF=BC=8C=E6=96=B0?= =?UTF-8?q?=E5=BB=BA=E4=BA=86service=E6=96=87=E4=BB=B6=E5=A4=B9=E5=AD=98?= =?UTF-8?q?=E6=94=BE=E5=86=85=E9=83=A8=E6=9C=8D=E5=8A=A1=E3=80=82=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E4=BA=86git=20ignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 30 +++++++++++++++++++++++ models/__init__.py | 3 +-- {service => services}/auth.py | 0 models/database.py => services/init_db.py | 4 +-- 4 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 .gitignore rename {service => services}/auth.py (100%) rename models/database.py => services/init_db.py (90%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5a87672 --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environments +venv/ +env/ +.venv/ + +# IDE specific files +.vscode/ +.idea/ + +# Environment variables +.env +.env.local + +# Logs and databases +*.log +*.sqlite3 + +# Testing +.coverage +htmlcov/ + +# Python packaging +*.egg-info/ +dist/ +build/ diff --git a/models/__init__.py b/models/__init__.py index 74df867..35b01f3 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1,3 @@ from .user import User -from .database import init_db -__all__ = ["User", "init_db"] +__all__ = ["User"] diff --git a/service/auth.py b/services/auth.py similarity index 100% rename from service/auth.py rename to services/auth.py diff --git a/models/database.py b/services/init_db.py similarity index 90% rename from models/database.py rename to services/init_db.py index 706098b..d285adb 100644 --- a/models/database.py +++ b/services/init_db.py @@ -1,8 +1,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine -from .user import Base, User, UserRole +from ..models.user import Base, User, UserRole from sqlalchemy import select from config import SYSTEM_ADMIN_CONFIG -from service.auth import get_password_hash +from services.auth import get_password_hash async def init_db(engine: AsyncEngine): """Initialize database""" From e9be684b2e6e6f7c94276999d491a61c25547c72 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 20 Jan 2025 22:15:34 +0800 Subject: [PATCH 06/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- schemas/auth.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ schemas/user.py | 43 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 schemas/auth.py create mode 100644 schemas/user.py diff --git a/schemas/auth.py b/schemas/auth.py new file mode 100644 index 0000000..14ecef6 --- /dev/null +++ b/schemas/auth.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel +from datetime import datetime + +class Token(BaseModel): + """ + 表示一个JWT令牌的模型类。 + + Attributes: + access_token (str): 访问令牌,用于身份验证和授权。 + refresh_token (str): 刷新令牌,用于获取新的访问令牌。 + token_type (str): 令牌类型,通常是"Bearer"。 + expires_in (int): 访问令牌的有效期,以秒为单位。 + """ + access_token: str + refresh_token: str + token_type: str + expires_in: int + +class TokenData(BaseModel): + """ + 表示JWT令牌中存储的数据的模型类。 + + Attributes: + id (int): 用户的唯一标识符。 + username (str): 用户的用户名。 + role (str): 用户的角色或权限。 + exp (datetime): 令牌的过期时间。 + """ + id: int + username: str + role: str + exp: datetime + +class LoginRequest(BaseModel): + """ + 表示用户登录请求的模型类。 + + Attributes: + username (str): 用户登录时输入的用户名。 + password (str): 用户登录时输入的密码。 + """ + username: str + password: str + +class RefreshTokenRequest(BaseModel): + """ + 表示刷新令牌请求的模型类。 + + Attributes: + refresh_token (str): 用于刷新访问令牌的刷新令牌。 + """ + refresh_token: str \ No newline at end of file diff --git a/schemas/user.py b/schemas/user.py new file mode 100644 index 0000000..46dc5f3 --- /dev/null +++ b/schemas/user.py @@ -0,0 +1,43 @@ +from datetime import datetime +from enum import Enum +from pydantic import BaseModel, Field, root_validator +from typing import Optional + +# 用户角色枚举 +class UserRole(str, Enum): + SYSTEM_ADMIN = "system_admin" + ADMIN = "admin" + USER = "user" + +# 基础用户模型 +class UserBase(BaseModel): + username: str = Field(..., max_length=50, description="用户名") + role: UserRole = Field(default=UserRole.USER, description="用户角色") + description: Optional[str] = Field(None, max_length=255, description="用户描述") + +# 用户创建模型 +class UserCreate(UserBase): + password: str = Field(..., min_length=6, max_length=255, description="用户密码") + + +# 用户更新模型 +class UserUpdate(BaseModel): + username: Optional[str] = Field(None, max_length=50, description="用户名") + role: Optional[UserRole] = Field(None, description="用户角色") + description: Optional[str] = Field(None, max_length=255, description="用户描述") + + # 可选:确保至少更新一个字段 + @root_validator + def validate_at_least_one_field(cls, values): + if not any(values.values()): + raise ValueError("至少需要更新一个字段") + return values + +# 用户响应模型 +class UserResponse(UserBase): + id: int = Field(..., description="用户ID") + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + class Config: + orm_mode = True # 允许从 ORM 对象加载数据 \ No newline at end of file From 7b7828daa0018cc5eb4f4891cf07041895207a3e Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 20 Jan 2025 22:16:45 +0800 Subject: [PATCH 07/26] =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/auth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/services/auth.py b/services/auth.py index 9f22dee..1b6218b 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,11 +1,12 @@ from passlib.context import CryptContext +# 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def get_password_hash(password: str) -> str: - """Generate password hash using bcrypt""" + """生成密码的哈希值,使用 bcrypt 算法进行加密""" return pwd_context.hash(password) def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify password against stored hash""" + """验证输入的明文密码是否与存储的哈希密码匹配""" return pwd_context.verify(plain_password, hashed_password) From 1c82ab47ddd0ae58816e807a30fac7329d5431be Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 13:30:33 +0800 Subject: [PATCH 08/26] =?UTF-8?q?=E4=B8=BAinitdb=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=B8=AD=E6=96=87=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/init_db.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/services/init_db.py b/services/init_db.py index d285adb..23c8125 100644 --- a/services/init_db.py +++ b/services/init_db.py @@ -5,17 +5,17 @@ from config import SYSTEM_ADMIN_CONFIG from services.auth import get_password_hash async def init_db(engine: AsyncEngine): - """Initialize database""" + """初始化数据库""" async with engine.begin() as conn: - # Create all tables + # 创建所有表 await conn.run_sync(Base.metadata.create_all) - # Check if system admin exists + # 检查系统管理员是否存在 result = await conn.execute( select(User).where(User.role == UserRole.SYSTEM_ADMIN) ) if not result.scalars().first(): - # Create default system admin + # 创建默认系统管理员 admin = User( username=SYSTEM_ADMIN_CONFIG['username'], password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), From b58fd6a8c6c9d5711732c1ee0288515a3e6ab580 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 13:41:01 +0800 Subject: [PATCH 09/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E8=B7=AF=E7=94=B1?= =?UTF-8?q?=E5=92=8Capi=E9=AA=A8=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ routes/auth.py | 17 +++++++++++++++++ routes/users.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 main.py create mode 100644 routes/auth.py create mode 100644 routes/users.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..6704193 --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +import logging +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from config import DATABASE_CONFIG, JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG +from sqlalchemy.ext.asyncio import create_async_engine +from services.init_db import init_db +from routes.auth import auth_router +from routes.users import users_router + +# 配置日志 +logging.basicConfig( + level=LOGGING_CONFIG['level'], + format=LOGGING_CONFIG['format'], + filename=LOGGING_CONFIG['filename'] +) +logger = logging.getLogger(__name__) + +# 初始化FastAPI应用 +app = FastAPI( + title="User Management System", + description="API for managing users with role-based access control", + version="1.0.0" +) + +# 创建数据库引擎 +engine = create_async_engine( + f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@" + f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}", + echo=True +) + +# 初始化数据库 +@app.on_event("startup") +async def startup_event(): + logger.info("Initializing database...") + await init_db(engine) + logger.info("Database initialized successfully") + +# 注册路由 +app.include_router(auth_router, prefix="/api/auth", tags=["auth"]) +app.include_router(users_router, prefix="/api/users", tags=["users"]) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/routes/auth.py b/routes/auth.py new file mode 100644 index 0000000..d23d65c --- /dev/null +++ b/routes/auth.py @@ -0,0 +1,17 @@ +from schemas.auth import Token +from fastapi import APIRouter + + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + +@router.post("/login", response_model=Token) +async def login(): + pass + +@router.post("/logout") +async def logout(): + pass + +@router.post("/refresh", response_model=Token) +async def refresh_token(): + pass diff --git a/routes/users.py b/routes/users.py new file mode 100644 index 0000000..5ebdcdb --- /dev/null +++ b/routes/users.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from typing import List, Optional +from models.user import User +from schemas.user import UserCreate, UserUpdate, UserResponse +from services.auth import get_current_user + +router = APIRouter() + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/login") + +@router.get("/", response_model=List[UserResponse]) +async def get_users(): + # 实现获取用户列表逻辑 + pass + +@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def create_user(): + # 实现创建用户逻辑 + pass + +@router.put("/{user_id}", response_model=UserResponse) +async def update_user(): + # 实现更新用户逻辑 + pass + +@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_user(): + # 实现删除用户逻辑 + pass From ed9a00cc54e664486f805d2a0dd8d47a42aa5ed8 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 14:22:09 +0800 Subject: [PATCH 10/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 21 +++++++------- services/db.py | 67 +++++++++++++++++++++++++++++++++++++++++++++ services/init_db.py | 26 ------------------ 3 files changed, 77 insertions(+), 37 deletions(-) create mode 100644 services/db.py delete mode 100644 services/init_db.py diff --git a/main.py b/main.py index 6704193..0346ec8 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,8 @@ import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from config import DATABASE_CONFIG, JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG -from sqlalchemy.ext.asyncio import create_async_engine -from services.init_db import init_db +from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG +from services.db import create_db_engine, init_db, close_db_connection from routes.auth import auth_router from routes.users import users_router @@ -22,20 +21,20 @@ app = FastAPI( version="1.0.0" ) -# 创建数据库引擎 -engine = create_async_engine( - f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@" - f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}", - echo=True -) - -# 初始化数据库 +# 数据库初始化 @app.on_event("startup") async def startup_event(): logger.info("Initializing database...") + engine = create_db_engine() await init_db(engine) logger.info("Database initialized successfully") +@app.on_event("shutdown") +async def shutdown_event(): + logger.info("Closing database connections...") + await close_db_connection() + logger.info("Database connections closed") + # 注册路由 app.include_router(auth_router, prefix="/api/auth", tags=["auth"]) app.include_router(users_router, prefix="/api/users", tags=["users"]) diff --git a/services/db.py b/services/db.py new file mode 100644 index 0000000..3f7c6cb --- /dev/null +++ b/services/db.py @@ -0,0 +1,67 @@ +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from ..models.user import Base, User, UserRole +from sqlalchemy import select +from contextlib import asynccontextmanager +from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG +from services.auth import get_password_hash +from typing import AsyncGenerator + +# 全局数据库引擎实例 +_engine: AsyncEngine | None = None + +def create_db_engine() -> AsyncEngine: + """创建数据库引擎""" + return create_async_engine( + f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@" + f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}", + echo=True + ) + +def get_db_engine() -> AsyncEngine: + """获取全局数据库引擎实例""" + if _engine is None: + raise RuntimeError("Database engine not initialized") + return _engine + +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话""" + async with AsyncSession(get_db_engine()) as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + +async def init_db(engine: AsyncEngine): + """初始化数据库""" + global _engine + _engine = engine + + async with engine.begin() as conn: + # 创建所有表 + await conn.run_sync(Base.metadata.create_all) + + # 检查系统管理员是否存在 + result = await conn.execute( + select(User).where(User.role == UserRole.SYSTEM_ADMIN) + ) + if not result.scalars().first(): + # 创建默认系统管理员 + admin = User( + username=SYSTEM_ADMIN_CONFIG['username'], + password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), + role=UserRole.SYSTEM_ADMIN, + description=SYSTEM_ADMIN_CONFIG['description'] + ) + conn.add(admin) + await conn.commit() + +async def close_db_connection(): + """关闭数据库连接""" + global _engine + if _engine is not None: + await _engine.dispose() + _engine = None diff --git a/services/init_db.py b/services/init_db.py deleted file mode 100644 index 23c8125..0000000 --- a/services/init_db.py +++ /dev/null @@ -1,26 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncEngine -from ..models.user import Base, User, UserRole -from sqlalchemy import select -from config import SYSTEM_ADMIN_CONFIG -from services.auth import get_password_hash - -async def init_db(engine: AsyncEngine): - """初始化数据库""" - async with engine.begin() as conn: - # 创建所有表 - await conn.run_sync(Base.metadata.create_all) - - # 检查系统管理员是否存在 - result = await conn.execute( - select(User).where(User.role == UserRole.SYSTEM_ADMIN) - ) - if not result.scalars().first(): - # 创建默认系统管理员 - admin = User( - username=SYSTEM_ADMIN_CONFIG['username'], - password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']), - role=UserRole.SYSTEM_ADMIN, - description=SYSTEM_ADMIN_CONFIG['description'] - ) - conn.add(admin) - await conn.commit() From 093f3e75e1bb0dfdbd84d5adc066b3fc4a9014ef Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 15:06:34 +0800 Subject: [PATCH 11/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86=E7=94=A8?= =?UTF-8?q?=E6=88=B7crud=E7=9A=84=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/auth.py | 12 ------- services/user_service.py | 73 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 12 deletions(-) delete mode 100644 services/auth.py create mode 100644 services/user_service.py diff --git a/services/auth.py b/services/auth.py deleted file mode 100644 index 1b6218b..0000000 --- a/services/auth.py +++ /dev/null @@ -1,12 +0,0 @@ -from passlib.context import CryptContext - -# 创建一个密码上下文对象,指定使用 bcrypt 加密算法 -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -def get_password_hash(password: str) -> str: - """生成密码的哈希值,使用 bcrypt 算法进行加密""" - return pwd_context.hash(password) - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """验证输入的明文密码是否与存储的哈希密码匹配""" - return pwd_context.verify(plain_password, hashed_password) diff --git a/services/user_service.py b/services/user_service.py new file mode 100644 index 0000000..984d0f0 --- /dev/null +++ b/services/user_service.py @@ -0,0 +1,73 @@ +from typing import List, Optional +from sqlalchemy import select, update, delete +from sqlalchemy.ext.asyncio import AsyncSession +from passlib.context import CryptContext +from models.user import User +from schemas.user import UserCreate, UserUpdate, UserResponse + +# 创建一个密码上下文对象,指定使用 bcrypt 加密算法 +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +class UserService: + @staticmethod + async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: + """创建用户""" + hashed_password = pwd_context.hash(user_data.password) + user = User( + username=user_data.username, + password=hashed_password, + role=user_data.role, + description=user_data.description + ) + session.add(user) + await session.commit() + await session.refresh(user) + return UserResponse.from_orm(user) + + @staticmethod + async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]: + """根据ID获取用户""" + result = await session.execute(select(User).where(User.id == user_id)) + user = result.scalars().first() + return UserResponse.from_orm(user) if user else None + + @staticmethod + async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: + """获取用户列表""" + result = await session.execute(select(User).offset(skip).limit(limit)) + users = result.scalars().all() + return [UserResponse.from_orm(user) for user in users] + + @staticmethod + async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]: + """更新用户信息""" + await session.execute( + update(User) + .where(User.id == user_id) + .values(**user_data.dict(exclude_unset=True)) + ) + await session.commit() + return await UserService.get_user(session, user_id) + + @staticmethod + async def delete_user(session: AsyncSession, user_id: int) -> bool: + """删除用户""" + result = await session.execute(delete(User).where(User.id == user_id)) + await session.commit() + return result.rowcount > 0 + + @staticmethod + def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证输入的明文密码是否与存储的哈希密码匹配""" + return pwd_context.verify(plain_password, hashed_password) + + @staticmethod + async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: + """验证用户登录""" + result = await session.execute(select(User).where(User.username == username)) + user = result.scalars().first() + if not user: + return None + if not UserService.verify_password(password, user.password): + return None + return UserResponse.from_orm(user) From d9152e85e531f40456c53c195df914abd17a9f6c Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 15:23:44 +0800 Subject: [PATCH 12/26] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BA=86api=E8=B7=AF?= =?UTF-8?q?=E7=94=B1=E7=9A=84=E4=B8=80=E4=BA=9B=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/auth.py | 2 +- routes/users.py | 49 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/routes/auth.py b/routes/auth.py index d23d65c..2e0965a 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -2,7 +2,7 @@ from schemas.auth import Token from fastapi import APIRouter -router = APIRouter(prefix="/api/auth", tags=["auth"]) +router = APIRouter(tags=["auth"]) @router.post("/login", response_model=Token) async def login(): diff --git a/routes/users.py b/routes/users.py index 5ebdcdb..4caf448 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,30 +1,61 @@ from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer from typing import List, Optional -from models.user import User from schemas.user import UserCreate, UserUpdate, UserResponse from services.auth import get_current_user -router = APIRouter() - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/login") +router = APIRouter(tags=["users"]) @router.get("/", response_model=List[UserResponse]) -async def get_users(): +async def get_users( + page: int = 1, + limit: int = 10, + role: Optional[str] = None, + current_user: UserResponse = Depends(get_current_user) +): + if current_user.role not in ["system_admin", "admin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admin can access user list" + ) # 实现获取用户列表逻辑 pass @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) -async def create_user(): +async def create_user( + user_data: UserCreate, + current_user: UserResponse = Depends(get_current_user) +): + if current_user.role not in ["system_admin", "admin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admin can create users" + ) # 实现创建用户逻辑 pass @router.put("/{user_id}", response_model=UserResponse) -async def update_user(): +async def update_user( + user_id: int, + user_data: UserUpdate, + current_user: UserResponse = Depends(get_current_user) +): + if current_user.role not in ["system_admin", "admin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admin can update users" + ) # 实现更新用户逻辑 pass @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_user(): +async def delete_user( + user_id: int, + current_user: UserResponse = Depends(get_current_user) +): + if current_user.role not in ["system_admin", "admin"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admin can delete users" + ) # 实现删除用户逻辑 pass From f172c5d2bc8542687d642b29805eb87d19e9669a Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 20:11:31 +0800 Subject: [PATCH 13/26] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=86user=5Fservice?= =?UTF-8?q?=E7=9A=84=E5=90=8D=E5=AD=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/{user_service.py => user_services.py} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename services/{user_service.py => user_services.py} (94%) diff --git a/services/user_service.py b/services/user_services.py similarity index 94% rename from services/user_service.py rename to services/user_services.py index 984d0f0..a9a77db 100644 --- a/services/user_service.py +++ b/services/user_services.py @@ -8,7 +8,7 @@ from schemas.user import UserCreate, UserUpdate, UserResponse # 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -class UserService: +class UserServices: @staticmethod async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: """创建用户""" @@ -47,7 +47,7 @@ class UserService: .values(**user_data.dict(exclude_unset=True)) ) await session.commit() - return await UserService.get_user(session, user_id) + return await UserServices.get_user(session, user_id) @staticmethod async def delete_user(session: AsyncSession, user_id: int) -> bool: @@ -68,6 +68,6 @@ class UserService: user = result.scalars().first() if not user: return None - if not UserService.verify_password(password, user.password): + if not UserServices.verify_password(password, user.password): return None return UserResponse.from_orm(user) From 3b7ac1f6820969fcdc2ac4c9781b399ef2de8b20 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 20:14:15 +0800 Subject: [PATCH 14/26] =?UTF-8?q?=E5=AE=8C=E6=88=90=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=90=8E=E7=9A=84=E5=AF=BC=E5=85=A5=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/db.py b/services/db.py index 3f7c6cb..acca4d0 100644 --- a/services/db.py +++ b/services/db.py @@ -4,7 +4,7 @@ from ..models.user import Base, User, UserRole from sqlalchemy import select from contextlib import asynccontextmanager from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG -from services.auth import get_password_hash +from services.user_services import get_password_hash from typing import AsyncGenerator # 全局数据库引擎实例 From f1cdbab0f484517776d9b6c949ae3e66b6fe9999 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 21:28:18 +0800 Subject: [PATCH 15/26] =?UTF-8?q?=E5=AE=8C=E4=BA=86=E9=89=B4=E6=9D=83?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81=E5=92=8C=E9=89=B4=E6=9D=83?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/depends.py | 18 +++++++++++ routes/users.py | 36 ++++++++++++++++++--- services/auth_service.py | 67 ++++++++++++++++++++++++++++++++++++++++ services/db.py | 2 +- 4 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 routes/depends.py create mode 100644 services/auth_service.py diff --git a/routes/depends.py b/routes/depends.py new file mode 100644 index 0000000..43aebad --- /dev/null +++ b/routes/depends.py @@ -0,0 +1,18 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from typing import Optional +from schemas.auth import TokenData +from services.auth_service import verify_token + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") + +async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: + """获取当前用户""" + token_data = verify_token(token) + if token_data is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token_data diff --git a/routes/users.py b/routes/users.py index 4caf448..92a648d 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,7 +1,9 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional +from schemas.auth import TokenData from schemas.user import UserCreate, UserUpdate, UserResponse -from services.auth import get_current_user +from routes.depends import get_current_user +from services.user_services import get_user_by_id router = APIRouter(tags=["users"]) @@ -10,8 +12,14 @@ async def get_users( page: int = 1, limit: int = 10, role: Optional[str] = None, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -23,8 +31,14 @@ async def get_users( @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -37,8 +51,14 @@ async def create_user( async def update_user( user_id: int, user_data: UserUpdate, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -50,8 +70,14 @@ async def update_user( @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user: UserResponse = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_user) ): + current_user = await get_user_by_id(current_user_token.id) + if current_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) if current_user.role not in ["system_admin", "admin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/services/auth_service.py b/services/auth_service.py new file mode 100644 index 0000000..7549224 --- /dev/null +++ b/services/auth_service.py @@ -0,0 +1,67 @@ +from datetime import datetime +from typing import Optional +import jwt +from config import JWT_CONFIG +from schemas.auth import Token, TokenData + +SECRET_KEY = JWT_CONFIG['secret_key'] +ALGORITHM = JWT_CONFIG['algorithm'] +ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire'] +REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire'] + +def create_access_token(user_id: int, username: str, role: str) -> str: + """创建access token""" + expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE + to_encode = { + "id": user_id, + "username": username, + "role": role, + "exp": expire + } + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + +def create_refresh_token(user_id: int, username: str, role: str) -> str: + """创建refresh token""" + expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE + to_encode = { + "id": user_id, + "username": username, + "role": role, + "exp": expire + } + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + +def create_tokens(user_id: int, username: str, role: str) -> Token: + """创建access token和refresh token""" + access_token = create_access_token(user_id, username, role) + refresh_token = create_refresh_token(user_id, username, role) + return Token( + access_token=access_token, + refresh_token=refresh_token, + token_type="bearer", + expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds()) + ) + +def verify_token(token: str) -> Optional[TokenData]: + """验证token有效性并返回payload,如果token无效则返回None""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + return TokenData( + id=payload.get("id"), + username=payload.get("username"), + role=payload.get("role"), + exp=payload.get("exp") + ) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return None + +def refresh_tokens(refresh_token: str) -> Optional[Token]: + """使用refresh token刷新access token,如果refresh token无效则返回None""" + token_data = verify_token(refresh_token) + if token_data is None: + return None + return create_tokens( + user_id=token_data.id, + username=token_data.username, + role=token_data.role + ) diff --git a/services/db.py b/services/db.py index acca4d0..74777bf 100644 --- a/services/db.py +++ b/services/db.py @@ -4,7 +4,7 @@ from ..models.user import Base, User, UserRole from sqlalchemy import select from contextlib import asynccontextmanager from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG -from services.user_services import get_password_hash +from user_services import get_password_hash from typing import AsyncGenerator # 全局数据库引擎实例 From a90838b79fb6735066d51b5bc5b96aa97cf9fb5f Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 22:14:03 +0800 Subject: [PATCH 16/26] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86users=E7=9A=84?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E4=BA=86=E4=BE=9D=E8=B5=96=E6=B3=A8=E5=85=A5=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=E5=88=A4=E6=96=AD=E7=AE=A1=E7=90=86=E5=91=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/depends.py | 20 +++++++ routes/users.py | 76 ++++++++----------------- services/user_services.py | 113 +++++++++++++++++++------------------- 3 files changed, 100 insertions(+), 109 deletions(-) diff --git a/routes/depends.py b/routes/depends.py index 43aebad..1f2d261 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -2,6 +2,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from typing import Optional from schemas.auth import TokenData +from schemas.user import UserRole from services.auth_service import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") @@ -16,3 +17,22 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: headers={"WWW-Authenticate": "Bearer"}, ) return token_data + + + +async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: + """获取当前用户""" + token_data = verify_token(token) + if token_data is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + if token_data.role not in [UserRole.SYSTEM_ADMIN.value, UserRole.ADMIN.value]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You are not admin", + headers={"WWW-Authenticate": "Bearer"}, + ) + return token_data \ No newline at end of file diff --git a/routes/users.py b/routes/users.py index 92a648d..7e598b6 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,9 +1,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional from schemas.auth import TokenData -from schemas.user import UserCreate, UserUpdate, UserResponse -from routes.depends import get_current_user -from services.user_services import get_user_by_id +from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole +from routes.depends import get_current_user,get_current_admin +from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user +from services.db import get_db_session router = APIRouter(tags=["users"]) @@ -20,68 +21,39 @@ async def get_users( status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can access user list" - ) - # 实现获取用户列表逻辑 - pass + async with get_db_session() as session: + skip = (page - 1) * limit + users = await get_users(session, skip=skip, limit=limit) + if role: + users = [user for user in users if user.role == role] + return users @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can create users" - ) - # 实现创建用户逻辑 - pass + async with get_db_session() as session: + return await create_user(session, user_data) @router.put("/{user_id}", response_model=UserResponse) async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can update users" - ) - # 实现更新用户逻辑 - pass + async with get_db_session() as session: + return await update_user(session, user_id, user_data) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenData = Depends(get_current_admin) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - if current_user.role not in ["system_admin", "admin"]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin can delete users" - ) - # 实现删除用户逻辑 - pass + async with get_db_session() as session: + success = await delete_user(session, user_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) diff --git a/services/user_services.py b/services/user_services.py index a9a77db..79e1ce5 100644 --- a/services/user_services.py +++ b/services/user_services.py @@ -8,66 +8,65 @@ from schemas.user import UserCreate, UserUpdate, UserResponse # 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -class UserServices: - @staticmethod - async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: - """创建用户""" - hashed_password = pwd_context.hash(user_data.password) - user = User( - username=user_data.username, - password=hashed_password, - role=user_data.role, - description=user_data.description - ) - session.add(user) - await session.commit() - await session.refresh(user) - return UserResponse.from_orm(user) + +async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: + """创建用户""" + hashed_password = pwd_context.hash(user_data.password) + user = User( + username=user_data.username, + password=hashed_password, + role=user_data.role, + description=user_data.description + ) + session.add(user) + await session.commit() + await session.refresh(user) + return UserResponse.from_orm(user) - @staticmethod - async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]: - """根据ID获取用户""" - result = await session.execute(select(User).where(User.id == user_id)) - user = result.scalars().first() - return UserResponse.from_orm(user) if user else None - @staticmethod - async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: - """获取用户列表""" - result = await session.execute(select(User).offset(skip).limit(limit)) - users = result.scalars().all() - return [UserResponse.from_orm(user) for user in users] +async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]: + """根据ID获取用户""" + result = await session.execute(select(User).where(User.id == user_id)) + user = result.scalars().first() + return UserResponse.from_orm(user) if user else None - @staticmethod - async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]: - """更新用户信息""" - await session.execute( - update(User) - .where(User.id == user_id) - .values(**user_data.dict(exclude_unset=True)) - ) - await session.commit() - return await UserServices.get_user(session, user_id) - @staticmethod - async def delete_user(session: AsyncSession, user_id: int) -> bool: - """删除用户""" - result = await session.execute(delete(User).where(User.id == user_id)) - await session.commit() - return result.rowcount > 0 +async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: + """获取用户列表""" + result = await session.execute(select(User).offset(skip).limit(limit)) + users = result.scalars().all() + return [UserResponse.from_orm(user) for user in users] - @staticmethod - def verify_password(plain_password: str, hashed_password: str) -> bool: - """验证输入的明文密码是否与存储的哈希密码匹配""" - return pwd_context.verify(plain_password, hashed_password) - @staticmethod - async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: - """验证用户登录""" - result = await session.execute(select(User).where(User.username == username)) - user = result.scalars().first() - if not user: - return None - if not UserServices.verify_password(password, user.password): - return None - return UserResponse.from_orm(user) +async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]: + """更新用户信息""" + await session.execute( + update(User) + .where(User.id == user_id) + .values(**user_data.dict(exclude_unset=True)) + ) + await session.commit() + return await get_user(session, user_id) + + +async def delete_user(session: AsyncSession, user_id: int) -> bool: + """删除用户""" + result = await session.execute(delete(User).where(User.id == user_id)) + await session.commit() + return result.rowcount > 0 + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证输入的明文密码是否与存储的哈希密码匹配""" + return pwd_context.verify(plain_password, hashed_password) + + +async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: + """验证用户登录""" + result = await session.execute(select(User).where(User.username == username)) + user = result.scalars().first() + if not user: + return None + if not verify_password(password, user.password): + return None + return UserResponse.from_orm(user) From 1915bcfaa9431b10cf33eb63ccf18e49b7184a67 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 22:32:26 +0800 Subject: [PATCH 17/26] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E7=9B=B8=E5=85=B3=E7=9A=84=E4=BE=9D=E8=B5=96=E6=B3=A8?= =?UTF-8?q?=E5=85=A5=EF=BC=8C=E5=AE=8C=E6=88=90=E4=BA=86auth=20api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/auth.py | 40 +++++++++++++++++++++++++++++++--------- routes/depends.py | 24 ++++++++++-------------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/routes/auth.py b/routes/auth.py index 2e0965a..1ba2498 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -1,17 +1,39 @@ -from schemas.auth import Token -from fastapi import APIRouter - +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.ext.asyncio import AsyncSession +from schemas.auth import Token, LoginRequest, RefreshTokenRequest +from services.auth_service import create_tokens, verify_token, refresh_tokens +from services.user_services import authenticate_user +from services.db import get_db router = APIRouter(tags=["auth"]) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @router.post("/login", response_model=Token) -async def login(): - pass +async def login( + login_data: LoginRequest, + session: AsyncSession = Depends(get_db) +): + user = await authenticate_user(session, login_data.username, login_data.password) + if not user: + raise HTTPException( + status_code=401, + detail="Invalid username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + return create_tokens(user.id, user.username, user.role) @router.post("/logout") -async def logout(): - pass +async def logout(token: str = Depends(oauth2_scheme)): + token_data = verify_token(token) + if not token_data: + raise HTTPException(status_code=401, detail="Invalid token") + # TODO: 实现token黑名单 + return {"message": "Successfully logged out"} @router.post("/refresh", response_model=Token) -async def refresh_token(): - pass +async def refresh_token(refresh_data: RefreshTokenRequest): + tokens = refresh_tokens(refresh_data.refresh_token) + if not tokens: + raise HTTPException(status_code=401, detail="Invalid refresh token") + return tokens diff --git a/routes/depends.py b/routes/depends.py index 1f2d261..af1e0fe 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -7,32 +7,28 @@ from services.auth_service import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") -async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: - """获取当前用户""" +async def _get_token_data(token: str) -> TokenData: + """验证并返回TokenData""" token_data = verify_token(token) if token_data is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", + detail="Invalid or expired authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) return token_data - +async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: + """获取当前用户""" + return await _get_token_data(token) async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: - """获取当前用户""" - token_data = verify_token(token) - if token_data is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - if token_data.role not in [UserRole.SYSTEM_ADMIN.value, UserRole.ADMIN.value]: + """获取当前管理员用户""" + token_data = await _get_token_data(token) + if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You are not admin", + detail="Access denied: Insufficient privileges for this operation", headers={"WWW-Authenticate": "Bearer"}, ) return token_data \ No newline at end of file From fc547eebe573a8aac20d2c24d3efd8c5ef31c697 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 23:11:28 +0800 Subject: [PATCH 18/26] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=86=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- schemas/auth.py | 40 +++++----------------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/schemas/auth.py b/schemas/auth.py index 14ecef6..f6679ea 100644 --- a/schemas/auth.py +++ b/schemas/auth.py @@ -1,52 +1,22 @@ from pydantic import BaseModel from datetime import datetime -class Token(BaseModel): - """ - 表示一个JWT令牌的模型类。 - - Attributes: - access_token (str): 访问令牌,用于身份验证和授权。 - refresh_token (str): 刷新令牌,用于获取新的访问令牌。 - token_type (str): 令牌类型,通常是"Bearer"。 - expires_in (int): 访问令牌的有效期,以秒为单位。 - """ +class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str - expires_in: int + access_token_exp: datetime + refresh_token_exp: datetime -class TokenData(BaseModel): - """ - 表示JWT令牌中存储的数据的模型类。 - - Attributes: - id (int): 用户的唯一标识符。 - username (str): 用户的用户名。 - role (str): 用户的角色或权限。 - exp (datetime): 令牌的过期时间。 - """ +class TokenPayload(BaseModel): id: int username: str role: str exp: datetime class LoginRequest(BaseModel): - """ - 表示用户登录请求的模型类。 - - Attributes: - username (str): 用户登录时输入的用户名。 - password (str): 用户登录时输入的密码。 - """ username: str password: str class RefreshTokenRequest(BaseModel): - """ - 表示刷新令牌请求的模型类。 - - Attributes: - refresh_token (str): 用于刷新访问令牌的刷新令牌。 - """ - refresh_token: str \ No newline at end of file + refresh_token: str From 2f1cf11d91dd45b72bcea54f9cd343118b915017 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 21 Jan 2025 23:38:52 +0800 Subject: [PATCH 19/26] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=86=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=9C=8D=E5=8A=A1=EF=BC=8C=E9=87=8D=E6=9E=84payload?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/auth.py | 17 ++++-------- routes/depends.py | 8 +++--- routes/users.py | 10 +++---- schemas/auth.py | 7 +++-- services/auth_service.py | 57 +++++++++++++++++++++++----------------- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/routes/auth.py b/routes/auth.py index 1ba2498..52bd654 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -1,15 +1,15 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession -from schemas.auth import Token, LoginRequest, RefreshTokenRequest -from services.auth_service import create_tokens, verify_token, refresh_tokens +from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest +from services.auth_service import create_tokens_response, verify_token, refresh_tokens from services.user_services import authenticate_user from services.db import get_db router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -@router.post("/login", response_model=Token) +@router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, session: AsyncSession = Depends(get_db) @@ -21,17 +21,10 @@ async def login( detail="Invalid username or password", headers={"WWW-Authenticate": "Bearer"}, ) - return create_tokens(user.id, user.username, user.role) + return create_tokens_response(user.id, user.username, user.role) -@router.post("/logout") -async def logout(token: str = Depends(oauth2_scheme)): - token_data = verify_token(token) - if not token_data: - raise HTTPException(status_code=401, detail="Invalid token") - # TODO: 实现token黑名单 - return {"message": "Successfully logged out"} -@router.post("/refresh", response_model=Token) +@router.post("/refresh", response_model=TokenResponse) async def refresh_token(refresh_data: RefreshTokenRequest): tokens = refresh_tokens(refresh_data.refresh_token) if not tokens: diff --git a/routes/depends.py b/routes/depends.py index af1e0fe..bdcb70e 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -1,13 +1,13 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from typing import Optional -from schemas.auth import TokenData +from schemas.auth import TokenPayload from schemas.user import UserRole from services.auth_service import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") -async def _get_token_data(token: str) -> TokenData: +async def _get_token_data(token: str) -> TokenPayload: """验证并返回TokenData""" token_data = verify_token(token) if token_data is None: @@ -18,11 +18,11 @@ async def _get_token_data(token: str) -> TokenData: ) return token_data -async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenData: +async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload: """获取当前用户""" return await _get_token_data(token) -async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenData: +async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload: """获取当前管理员用户""" token_data = await _get_token_data(token) if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]: diff --git a/routes/users.py b/routes/users.py index 7e598b6..2e863a2 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException, status from typing import List, Optional -from schemas.auth import TokenData +from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user @@ -13,7 +13,7 @@ async def get_users( page: int = 1, limit: int = 10, role: Optional[str] = None, - current_user_token: TokenData = Depends(get_current_user) + current_user_token: TokenPayload = Depends(get_current_user) ): current_user = await get_user_by_id(current_user_token.id) if current_user is None: @@ -31,7 +31,7 @@ async def get_users( @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: return await create_user(session, user_data) @@ -40,7 +40,7 @@ async def create_user( async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: return await update_user(session, user_id, user_data) @@ -48,7 +48,7 @@ async def update_user( @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenData = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin) ): async with get_db_session() as session: success = await delete_user(session, user_id) diff --git a/schemas/auth.py b/schemas/auth.py index f6679ea..bde44ff 100644 --- a/schemas/auth.py +++ b/schemas/auth.py @@ -1,18 +1,17 @@ from pydantic import BaseModel -from datetime import datetime class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str - access_token_exp: datetime - refresh_token_exp: datetime + access_token_exp: int + refresh_token_exp: int class TokenPayload(BaseModel): id: int username: str role: str - exp: datetime + exp: int class LoginRequest(BaseModel): username: str diff --git a/services/auth_service.py b/services/auth_service.py index 7549224..30a9b44 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -1,17 +1,22 @@ -from datetime import datetime from typing import Optional import jwt +import time from config import JWT_CONFIG -from schemas.auth import Token, TokenData +from schemas.auth import TokenResponse, TokenPayload SECRET_KEY = JWT_CONFIG['secret_key'] ALGORITHM = JWT_CONFIG['algorithm'] ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire'] REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire'] -def create_access_token(user_id: int, username: str, role: str) -> str: - """创建access token""" - expire = datetime.utcnow() + ACCESS_TOKEN_EXPIRE + +def get_current_time() -> int: + """获取当前UTC时间戳""" + return int(time.time()) + +def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str: + """创建JWT token""" + expire = get_current_time() + expire_delta to_encode = { "id": user_id, "username": username, @@ -20,33 +25,36 @@ def create_access_token(user_id: int, username: str, role: str) -> str: } return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) +def create_access_token(user_id: int, username: str, role: str) -> str: + """创建access token""" + return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE) + def create_refresh_token(user_id: int, username: str, role: str) -> str: """创建refresh token""" - expire = datetime.utcnow() + REFRESH_TOKEN_EXPIRE - to_encode = { - "id": user_id, - "username": username, - "role": role, - "exp": expire - } - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE) -def create_tokens(user_id: int, username: str, role: str) -> Token: +def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse: """创建access token和refresh token""" access_token = create_access_token(user_id, username, role) refresh_token = create_refresh_token(user_id, username, role) - return Token( + + # 获取token的过期时间 + access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds()) + refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds()) + + return TokenResponse( access_token=access_token, refresh_token=refresh_token, token_type="bearer", - expires_in=int(ACCESS_TOKEN_EXPIRE.total_seconds()) + access_token_exp=access_token_exp, + refresh_token_exp=refresh_token_exp ) -def verify_token(token: str) -> Optional[TokenData]: +def verify_token(token: str) -> Optional[TokenPayload]: """验证token有效性并返回payload,如果token无效则返回None""" try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - return TokenData( + return TokenPayload( id=payload.get("id"), username=payload.get("username"), role=payload.get("role"), @@ -55,13 +63,14 @@ def verify_token(token: str) -> Optional[TokenData]: except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): return None -def refresh_tokens(refresh_token: str) -> Optional[Token]: +def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]: """使用refresh token刷新access token,如果refresh token无效则返回None""" token_data = verify_token(refresh_token) if token_data is None: return None - return create_tokens( - user_id=token_data.id, - username=token_data.username, - role=token_data.role - ) + else: + return create_tokens_response( + user_id=token_data.id, + username=token_data.username, + role=token_data.role + ) From bb314a2c6bec5946352f86046e29bc244e150783 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 22 Jan 2025 13:48:01 +0800 Subject: [PATCH 20/26] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=86.env=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 4 ++-- requirements.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index 10e3097..26c2986 100644 --- a/config.py +++ b/config.py @@ -15,8 +15,8 @@ DATABASE_CONFIG = { JWT_CONFIG = { 'secret_key': os.getenv('JWT_SECRET_KEY', 'your-secret-key'), 'algorithm': 'HS256', - 'access_token_expire': timedelta(minutes=int(os.getenv('JWT_ACCESS_EXPIRE_MINUTES', '30'))), - 'refresh_token_expire': timedelta(days=int(os.getenv('JWT_REFRESH_EXPIRE_DAYS', '7'))) + 'access_token_expire': timedelta(minutes=int(os.getenv('JWT_ACCESS_EXPIRE_MINUTES', '10'))), + 'refresh_token_expire': timedelta(days=int(os.getenv('JWT_REFRESH_EXPIRE_DAYS', '3'))) } # 日志配置 diff --git a/requirements.txt b/requirements.txt index f333860..acec49f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fastapi>=0.95.2 +python-dotenv>=1.0.0 sqlalchemy>=2.0.15 passlib>=1.7.4 bcrypt>=4.0.1 From bf856af9f9b685792273502b68999aaf1769a093 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 22 Jan 2025 13:58:39 +0800 Subject: [PATCH 21/26] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BA=86=E7=A8=8B?= =?UTF-8?q?=E5=BA=8F=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BA=86=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 7 +++++-- routes/auth.py | 8 ++++---- routes/depends.py | 2 +- routes/users.py | 14 +++++++------- services/{auth_service.py => auth.py} | 0 services/db.py | 11 ++++++----- services/{user_services.py => user.py} | 4 +++- 7 files changed, 26 insertions(+), 20 deletions(-) rename services/{auth_service.py => auth.py} (100%) rename services/{user_services.py => user.py} (95%) diff --git a/main.py b/main.py index 0346ec8..6a9e6f6 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,13 @@ +from dotenv import load_dotenv +load_dotenv() + import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG from services.db import create_db_engine, init_db, close_db_connection -from routes.auth import auth_router -from routes.users import users_router +from routes.auth import router as auth_router +from routes.users import router as users_router # 配置日志 logging.basicConfig( diff --git a/routes/auth.py b/routes/auth.py index 52bd654..c4bb51a 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest -from services.auth_service import create_tokens_response, verify_token, refresh_tokens -from services.user_services import authenticate_user -from services.db import get_db +from services.auth import create_tokens_response, verify_token, refresh_tokens +from services.user import authenticate_user +from services.db import get_db_session router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, - session: AsyncSession = Depends(get_db) + session: AsyncSession = Depends(get_db_session) ): user = await authenticate_user(session, login_data.username, login_data.password) if not user: diff --git a/routes/depends.py b/routes/depends.py index bdcb70e..34e0e18 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -3,7 +3,7 @@ from fastapi.security import OAuth2PasswordBearer from typing import Optional from schemas.auth import TokenPayload from schemas.user import UserRole -from services.auth_service import verify_token +from services.auth import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") diff --git a/routes/users.py b/routes/users.py index 2e863a2..cbed7c9 100644 --- a/routes/users.py +++ b/routes/users.py @@ -3,7 +3,7 @@ from typing import List, Optional from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin -from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user +from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id from services.db import get_db_session router = APIRouter(tags=["users"]) @@ -15,12 +15,12 @@ async def get_users( role: Optional[str] = None, current_user_token: TokenPayload = Depends(get_current_user) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) + # current_user = await get_user_by_id(current_user_token.id) + # if current_user is None: + # raise HTTPException( + # status_code=status.HTTP_404_NOT_FOUND, + # detail="User not found" + # ) async with get_db_session() as session: skip = (page - 1) * limit users = await get_users(session, skip=skip, limit=limit) diff --git a/services/auth_service.py b/services/auth.py similarity index 100% rename from services/auth_service.py rename to services/auth.py diff --git a/services/db.py b/services/db.py index 74777bf..6e3f9f6 100644 --- a/services/db.py +++ b/services/db.py @@ -1,10 +1,10 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker -from ..models.user import Base, User, UserRole +from models.user import Base, User, UserRole from sqlalchemy import select from contextlib import asynccontextmanager from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG -from user_services import get_password_hash +from services.user import get_password_hash from typing import AsyncGenerator # 全局数据库引擎实例 @@ -44,8 +44,9 @@ async def init_db(engine: AsyncEngine): # 创建所有表 await conn.run_sync(Base.metadata.create_all) + async with AsyncSession(engine) as session: # 检查系统管理员是否存在 - result = await conn.execute( + result = await session.execute( select(User).where(User.role == UserRole.SYSTEM_ADMIN) ) if not result.scalars().first(): @@ -56,8 +57,8 @@ async def init_db(engine: AsyncEngine): role=UserRole.SYSTEM_ADMIN, description=SYSTEM_ADMIN_CONFIG['description'] ) - conn.add(admin) - await conn.commit() + session.add(admin) + await session.commit() async def close_db_connection(): """关闭数据库连接""" diff --git a/services/user_services.py b/services/user.py similarity index 95% rename from services/user_services.py rename to services/user.py index 79e1ce5..90c6828 100644 --- a/services/user_services.py +++ b/services/user.py @@ -7,7 +7,6 @@ from schemas.user import UserCreate, UserUpdate, UserResponse # 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: """创建用户""" @@ -60,6 +59,9 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """验证输入的明文密码是否与存储的哈希密码匹配""" return pwd_context.verify(plain_password, hashed_password) +def get_password_hash(password: str) -> str: + """生成使用 bcrypt 的密码哈希""" + return pwd_context.hash(password) async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: """验证用户登录""" From 1fd8af3be941a07831e6fd9269ae20a3e5857d9d Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 22 Jan 2025 15:02:07 +0800 Subject: [PATCH 22/26] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E8=8B=A5=E5=B9=B2bug?= =?UTF-8?q?=EF=BC=8C=E5=90=8E=E7=AB=AF=E5=9F=BA=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/auth.py | 4 ++-- routes/users.py | 58 ++++++++++++++++++++++-------------------------- services/auth.py | 4 ++-- services/db.py | 13 +++++++++-- services/user.py | 2 +- 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/routes/auth.py b/routes/auth.py index c4bb51a..93016ec 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest from services.auth import create_tokens_response, verify_token, refresh_tokens from services.user import authenticate_user -from services.db import get_db_session +from services.db import get_db_session_dep router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, - session: AsyncSession = Depends(get_db_session) + session: AsyncSession = Depends(get_db_session_dep) ): user = await authenticate_user(session, login_data.username, login_data.password) if not user: diff --git a/routes/users.py b/routes/users.py index cbed7c9..bb67c61 100644 --- a/routes/users.py +++ b/routes/users.py @@ -1,59 +1,55 @@ from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin -from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id -from services.db import get_db_session +import services.user as user_service + +from services.db import get_db_session_dep router = APIRouter(tags=["users"]) @router.get("/", response_model=List[UserResponse]) -async def get_users( +async def get_users_list( page: int = 1, - limit: int = 10, + limit: int = 100, role: Optional[str] = None, - current_user_token: TokenPayload = Depends(get_current_user) + current_user_token: TokenPayload = Depends(get_current_user), + session: AsyncSession = Depends(get_db_session_dep) ): - # current_user = await get_user_by_id(current_user_token.id) - # if current_user is None: - # raise HTTPException( - # status_code=status.HTTP_404_NOT_FOUND, - # detail="User not found" - # ) - async with get_db_session() as session: - skip = (page - 1) * limit - users = await get_users(session, skip=skip, limit=limit) - if role: - users = [user for user in users if user.role == role] - return users + skip = (page - 1) * limit + users = await user_service.get_users_list(session, skip=skip, limit=limit) + if role: + users = [user for user in users if user.role == role] + return users @router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED) async def create_user( user_data: UserCreate, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - return await create_user(session, user_data) + return await user_service.create_user(session, user_data) @router.put("/{user_id}", response_model=UserResponse) async def update_user( user_id: int, user_data: UserUpdate, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - return await update_user(session, user_id, user_data) + return await user_service.update_user(session, user_id, user_data) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, - current_user_token: TokenPayload = Depends(get_current_admin) + current_user_token: TokenPayload = Depends(get_current_admin), + session: AsyncSession = Depends(get_db_session_dep) ): - async with get_db_session() as session: - success = await delete_user(session, user_id) - if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) + success = await user_service.delete_user(session, user_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) diff --git a/services/auth.py b/services/auth.py index 30a9b44..03352cc 100644 --- a/services/auth.py +++ b/services/auth.py @@ -14,9 +14,9 @@ def get_current_time() -> int: """获取当前UTC时间戳""" return int(time.time()) -def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str: +def create_token(user_id: int, username: str, role: str, expire_delta) -> str: """创建JWT token""" - expire = get_current_time() + expire_delta + expire = get_current_time() + int(expire_delta.total_seconds()) to_encode = { "id": user_id, "username": username, diff --git a/services/db.py b/services/db.py index 6e3f9f6..eda231f 100644 --- a/services/db.py +++ b/services/db.py @@ -24,8 +24,7 @@ def get_db_engine() -> AsyncEngine: raise RuntimeError("Database engine not initialized") return _engine -@asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession, None]: +async def get_db_session() -> AsyncSession: """获取数据库会话""" async with AsyncSession(get_db_engine()) as session: try: @@ -35,6 +34,16 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: await session.rollback() raise +async def get_db_session_dep() -> AsyncSession: + """FastAPI依赖注入使用的数据库会话获取函数""" + async with AsyncSession(get_db_engine()) as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + async def init_db(engine: AsyncEngine): """初始化数据库""" global _engine diff --git a/services/user.py b/services/user.py index 90c6828..c432427 100644 --- a/services/user.py +++ b/services/user.py @@ -30,7 +30,7 @@ async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse return UserResponse.from_orm(user) if user else None -async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: +async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]: """获取用户列表""" result = await session.execute(select(User).offset(skip).limit(limit)) users = result.scalars().all() From 4ba89c8ccc0894256715f1fd036cbcfa005d8986 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 14 Feb 2025 16:19:34 +0800 Subject: [PATCH 23/26] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E5=8D=95=E4=B8=AA=E7=94=A8=E6=88=B7=E7=9A=84api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/users.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/routes/users.py b/routes/users.py index bb67c61..72f47d9 100644 --- a/routes/users.py +++ b/routes/users.py @@ -41,6 +41,20 @@ async def update_user( ): return await user_service.update_user(session, user_id, user_data) +@router.get("/{user_id}", response_model=UserResponse) +async def get_user( + user_id: int, + current_user_token: TokenPayload = Depends(get_current_user), + session: AsyncSession = Depends(get_db_session_dep) +): + user = await user_service.get_user(session, user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + return user + @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: int, From 375077be693aaf1ec068e5711b17c67bc7512413 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 14 Feb 2025 16:49:22 +0800 Subject: [PATCH 24/26] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=8C=E6=88=90?= =?UTF-8?q?=E8=83=BD=E4=BD=BF=E7=94=A8refresh=20token=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- schemas/auth.py | 1 + services/auth.py | 22 +++++++++------------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/schemas/auth.py b/schemas/auth.py index bde44ff..67fe815 100644 --- a/schemas/auth.py +++ b/schemas/auth.py @@ -12,6 +12,7 @@ class TokenPayload(BaseModel): username: str role: str exp: int + token_type: str class LoginRequest(BaseModel): username: str diff --git a/services/auth.py b/services/auth.py index 03352cc..3fb46bb 100644 --- a/services/auth.py +++ b/services/auth.py @@ -14,29 +14,24 @@ def get_current_time() -> int: """获取当前UTC时间戳""" return int(time.time()) -def create_token(user_id: int, username: str, role: str, expire_delta) -> str: +def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str: """创建JWT token""" + expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE expire = get_current_time() + int(expire_delta.total_seconds()) + to_encode = { "id": user_id, "username": username, "role": role, - "exp": expire + "exp": expire, + "token_type": token_type } return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) -def create_access_token(user_id: int, username: str, role: str) -> str: - """创建access token""" - return create_token(user_id, username, role, ACCESS_TOKEN_EXPIRE) - -def create_refresh_token(user_id: int, username: str, role: str) -> str: - """创建refresh token""" - return create_token(user_id, username, role, REFRESH_TOKEN_EXPIRE) - def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse: """创建access token和refresh token""" - access_token = create_access_token(user_id, username, role) - refresh_token = create_refresh_token(user_id, username, role) + access_token = create_token(user_id, username, role, "access") + refresh_token = create_token(user_id, username, role, "refresh") # 获取token的过期时间 access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds()) @@ -58,7 +53,8 @@ def verify_token(token: str) -> Optional[TokenPayload]: id=payload.get("id"), username=payload.get("username"), role=payload.get("role"), - exp=payload.get("exp") + exp=payload.get("exp"), + token_type=payload.get("token_type") ) except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): return None From b76d72168038b469a55f1b9085fe20a137089970 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 14 Feb 2025 16:59:41 +0800 Subject: [PATCH 25/26] =?UTF-8?q?=E5=AE=8C=E6=88=90refresh=20token?= =?UTF-8?q?=E8=83=BD=E8=AE=BF=E9=97=AE=E7=9A=84bug=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- routes/auth.py | 2 +- routes/depends.py | 4 ++-- services/auth.py | 24 +++++++++++++++++++++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/routes/auth.py b/routes/auth.py index 93016ec..3ad1498 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest -from services.auth import create_tokens_response, verify_token, refresh_tokens +from services.auth import create_tokens_response, refresh_tokens from services.user import authenticate_user from services.db import get_db_session_dep diff --git a/routes/depends.py b/routes/depends.py index 34e0e18..9fa5edb 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -3,13 +3,13 @@ from fastapi.security import OAuth2PasswordBearer from typing import Optional from schemas.auth import TokenPayload from schemas.user import UserRole -from services.auth import verify_token +from services.auth import verify_access_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") async def _get_token_data(token: str) -> TokenPayload: """验证并返回TokenData""" - token_data = verify_token(token) + token_data = verify_access_token(token) if token_data is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/services/auth.py b/services/auth.py index 3fb46bb..32c64fd 100644 --- a/services/auth.py +++ b/services/auth.py @@ -45,10 +45,28 @@ def create_tokens_response(user_id: int, username: str, role: str) -> TokenRespo refresh_token_exp=refresh_token_exp ) -def verify_token(token: str) -> Optional[TokenPayload]: - """验证token有效性并返回payload,如果token无效则返回None""" +def verify_access_token(token: str) -> Optional[TokenPayload]: + """验证access token有效性并返回payload,如果token无效或类型不匹配则返回None""" try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("token_type") != "access": + return None + return TokenPayload( + id=payload.get("id"), + username=payload.get("username"), + role=payload.get("role"), + exp=payload.get("exp"), + token_type=payload.get("token_type") + ) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return None + +def verify_refresh_token(token: str) -> Optional[TokenPayload]: + """验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("token_type") != "refresh": + return None return TokenPayload( id=payload.get("id"), username=payload.get("username"), @@ -61,7 +79,7 @@ def verify_token(token: str) -> Optional[TokenPayload]: def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]: """使用refresh token刷新access token,如果refresh token无效则返回None""" - token_data = verify_token(refresh_token) + token_data = verify_refresh_token(refresh_token) if token_data is None: return None else: From 48a644fb354d6c6efcbd12bc1b4a2cb83137b68e Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Mon, 17 Feb 2025 15:22:09 +0800 Subject: [PATCH 26/26] =?UTF-8?q?=E4=BC=98=E5=8C=96devdoc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- devdoc.md | 220 +++++------------------------------------------------- 1 file changed, 18 insertions(+), 202 deletions(-) diff --git a/devdoc.md b/devdoc.md index 6fbd059..0e05f84 100644 --- a/devdoc.md +++ b/devdoc.md @@ -16,10 +16,10 @@ | `id` | `INT` | `PRIMARY KEY`, `AUTO_INCREMENT` | - | 用户ID,主键,自增 | | `username` | `VARCHAR(50)` | `NOT NULL`, `UNIQUE` | - | 用户名,唯一 | | `password` | `VARCHAR(255)` | `NOT NULL` | - | 用户密码,使用哈希加密存储 | -| `role` | `ENUM` | `NOT NULL` | `'user'` | 用户角色,枚举类型,可选值为 `system_admin`, `admin`, `user` | +| `role` | `ENUM` | `NOT NULL` | `UserRole.USER` | 用户角色,枚举类型,可选值为 `UserRole.SYSTEM_ADMIN`, `UserRole.ADMIN`, `UserRole.USER` | | `description` | `TEXT` | - | - | 用户描述,可选 | -| `created_at` | `DATETIME` | - | `CURRENT_TIMESTAMP` | 用户创建时间,默认值为当前时间 | -| `updated_at` | `DATETIME` | - | `CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP` | 用户信息更新时间,默认值为当前时间,更新时自动更新 | +| `created_at` | `DATETIME` | - | `当前UTC时间` | 用户创建时间,默认值为当前UTC时间 | +| `updated_at` | `DATETIME` | - | `当前UTC时间,更新时自动更新` | 用户信息更新时间,默认值为当前UTC时间,更新时自动更新 | ### 数据库初始化 @@ -37,9 +37,20 @@ JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以下信息: -* **Payload**:用户 ID、用户名、角色、Token 过期时间等。 +* **Payload**:用户 ID、用户名、角色、Token 过期时间、token_type等。 * **签名**:使用后端密钥对 Payload 进行签名,确保 Token 的完整性和安全性。 +### JWT Token 数据结构 + +```python +class TokenPayload: + id: int # 用户ID + username: str # 用户名 + role: str # 用户角色 + exp: int # Token过期时间 + token_type: str # Token类型(access或refresh) +``` + ### JWT 自动过期机制 * **Access Token**:用于常规 API 请求,有效期较短(如 30 分钟)。 @@ -52,198 +63,10 @@ JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以 3. 服务端生成新的 Access Token 和 Refresh Token,并返回给客户端。 4. 客户端更新本地存储的 Token。 -## API 设计 +### Token 验证 -### 用户登录 - -* **URL**: `/api/auth/login` -* **HTTP 方法**: POST -* **请求体**: - - ```json - { - "username": "admin", - "password": "password123" - } - ``` -* **响应**: - - * 成功: - - ```json - { - "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "token_type": "bearer", - "expires_in": 1800 - } - ``` - * 失败: - - ```json - { - "detail": "Invalid username or password" - } - ``` - -### 用户登出 - -* **URL**: `/api/auth/logout` -* **HTTP 方法**: POST -* **请求头**: - - * `Authorization: Bearer ` -* **响应**: - - * 成功: - - ```json - { - "message": "Successfully logged out" - } - ``` - -### 刷新 JWT Token - -* **URL**: `/api/auth/refresh` -* **HTTP 方法**: POST -* **请求体**: - - ```json - { - "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." - } - ``` -* **响应**: - - * 成功: - - ```json - { - "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "token_type": "bearer", - "expires_in": 1800 - } - ``` - * 失败: - - ```json - { - "detail": "Invalid refresh token" - } - ``` - -### 获取用户列表 - -* **URL**: `/api/users` -* **HTTP 方法**: GET -* **请求头**: - - * `Authorization: Bearer ` -* **查询参数**: - - * `page`:页码(默认 1) - * `limit`:每页条数(默认 10) - * `role`:按角色过滤(可选) -* **响应**: - - ```json - { - "total": 100, - "users": [ - { - "id": 1, - "username": "admin", - "role": "system_admin", - "description": "Default system admin", - "created_at": "2023-10-01T12:00:00Z", - "updated_at": "2023-10-01T12:00:00Z" - }, - ... - ] - } - ``` - -### 创建用户 - -* **URL**: `/api/users` -* **HTTP 方法**: POST -* **请求头**: - - * `Authorization: Bearer ` -* **请求体**: - - ```json - { - "username": "new_user", - "password": "new_password123", - "role": "user", - "description": "New user description" - } - ``` -* **响应**: - - * 成功: - - ```json - { - "id": 2, - "username": "new_user", - "role": "user", - "description": "New user description", - "created_at": "2023-10-01T12:00:00Z", - "updated_at": "2023-10-01T12:00:00Z" - } - ``` - -### 更新用户信息 - -* **URL**: `/api/users/{user_id}` -* **HTTP 方法**: PUT -* **请求头**: - - * `Authorization: Bearer ` -* **请求体**: - - ```json - { - "username": "updated_user", - "role": "admin", - "description": "Updated user description" - } - ``` -* **响应**: - - * 成功: - - ```json - { - "id": 1, - "username": "updated_user", - "role": "admin", - "description": "Updated user description", - "created_at": "2023-10-01T12:00:00Z", - "updated_at": "2023-10-01T12:30:00Z" - } - ``` - -### 删除用户 - -* **URL**: `/api/users/{user_id}` -* **HTTP 方法**: DELETE -* **请求头**: - - * `Authorization: Bearer ` -* **响应**: - - * 成功: - - ```json - { - "message": "User deleted successfully" - } - ``` +* **verify_access_token**:验证access token有效性并返回payload,如果token无效或类型不匹配则返回None +* **verify_refresh_token**:验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None ## 单元测试 @@ -255,31 +78,24 @@ JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以 ### 测试用例 1. **用户登录**: - * 测试正确的用户名和密码。 * 测试错误的用户名和密码。 2. **用户登出**: - * 测试已登录用户登出。 * 测试未登录用户登出。 3. **刷新 JWT Token**: - * 测试有效的 Refresh Token。 * 测试无效的 Refresh Token。 4. **获取用户列表**: - * 测试不同角色的用户访问权限。 * 测试分页和过滤功能。 5. **创建用户**: - * 测试管理员创建用户。 * 测试普通用户尝试创建用户。 6. **更新用户信息**: - * 测试管理员更新用户信息。 * 测试普通用户尝试更新用户信息。 7. **删除用户**: - * 测试管理员删除用户。 * 测试普通用户尝试删除用户。