Add 'backend/' from commit '48a644fb354d6c6efcbd12bc1b4a2cb83137b68e'
git-subtree-dir: backend git-subtree-mainline: 545699d16fda1029201c9bfadbfb8d5c7ffe2464 git-subtree-split: 48a644fb354d6c6efcbd12bc1b4a2cb83137b68e
This commit is contained in:
commit
3f114b2cc3
30
backend/.gitignore
vendored
Normal file
30
backend/.gitignore
vendored
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
# IDE specific files
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Logs and databases
|
||||||
|
*.log
|
||||||
|
*.sqlite3
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# Python packaging
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
35
backend/config.py
Normal file
35
backend/config.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import os
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
DATABASE_CONFIG = {
|
||||||
|
'host': os.getenv('DB_HOST', 'localhost'),
|
||||||
|
'port': int(os.getenv('DB_PORT', '3306')),
|
||||||
|
'user': os.getenv('DB_USER', 'root'),
|
||||||
|
'password': os.getenv('DB_PASSWORD', 'password'),
|
||||||
|
'database': os.getenv('DB_NAME', 'user_manage'),
|
||||||
|
'charset': 'utf8mb4'
|
||||||
|
}
|
||||||
|
|
||||||
|
# JWT配置
|
||||||
|
JWT_CONFIG = {
|
||||||
|
'secret_key': os.getenv('JWT_SECRET_KEY', 'your-secret-key'),
|
||||||
|
'algorithm': 'HS256',
|
||||||
|
'access_token_expire': timedelta(minutes=int(os.getenv('JWT_ACCESS_EXPIRE_MINUTES', '10'))),
|
||||||
|
'refresh_token_expire': timedelta(days=int(os.getenv('JWT_REFRESH_EXPIRE_DAYS', '3')))
|
||||||
|
}
|
||||||
|
|
||||||
|
# 日志配置
|
||||||
|
LOGGING_CONFIG = {
|
||||||
|
'level': os.getenv('LOG_LEVEL', 'INFO'),
|
||||||
|
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
'filename': os.getenv('LOG_FILE', 'app.log')
|
||||||
|
}
|
||||||
|
|
||||||
|
# 系统管理员初始配置
|
||||||
|
SYSTEM_ADMIN_CONFIG = {
|
||||||
|
'username': os.getenv('ADMIN_USERNAME', 'admin'),
|
||||||
|
'password': os.getenv('ADMIN_PASSWORD', 'password'),
|
||||||
|
'role': os.getenv('ADMIN_ROLE', 'system_admin'),
|
||||||
|
'description': os.getenv('ADMIN_DESCRIPTION', 'default system admin')
|
||||||
|
}
|
109
backend/devdoc.md
Normal file
109
backend/devdoc.md
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
## 项目概述
|
||||||
|
|
||||||
|
本项目是一个简单的用户管理系统,支持三种用户角色:系统管理员、管理员和普通用户。系统管理员在项目启动时自动生成,且唯一。所有账户必须由管理员手动创建。用户表包含密码、ID、角色、描述等信息,管理员可以对用户表进行增删改查,普通用户只能读取用户表。
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
* **后端技术栈**:FastAPI、SQLAlchemy、MySQL、bcrypt、PyJWT、logging
|
||||||
|
* **前端技术栈**:Vue、Vue Router、Pinia、Vite、Axios、Element Plus
|
||||||
|
|
||||||
|
## 数据库设计
|
||||||
|
|
||||||
|
### 用户表结构
|
||||||
|
|
||||||
|
| 字段名 | 数据类型 | 约束条件 | 默认值 | 说明 |
|
||||||
|
| -------- | ---------- | ---------- | -------- | ---------------------------------------------------- |
|
||||||
|
| `id` | `INT` | `PRIMARY KEY`, `AUTO_INCREMENT` | - | 用户ID,主键,自增 |
|
||||||
|
| `username` | `VARCHAR(50)` | `NOT NULL`, `UNIQUE` | - | 用户名,唯一 |
|
||||||
|
| `password` | `VARCHAR(255)` | `NOT NULL` | - | 用户密码,使用哈希加密存储 |
|
||||||
|
| `role` | `ENUM` | `NOT NULL` | `UserRole.USER` | 用户角色,枚举类型,可选值为 `UserRole.SYSTEM_ADMIN`, `UserRole.ADMIN`, `UserRole.USER` |
|
||||||
|
| `description` | `TEXT` | - | - | 用户描述,可选 |
|
||||||
|
| `created_at` | `DATETIME` | - | `当前UTC时间` | 用户创建时间,默认值为当前UTC时间 |
|
||||||
|
| `updated_at` | `DATETIME` | - | `当前UTC时间,更新时自动更新` | 用户信息更新时间,默认值为当前UTC时间,更新时自动更新 |
|
||||||
|
|
||||||
|
### 数据库初始化
|
||||||
|
|
||||||
|
* 系统初始化时,后端控制层会检查表是否存在,若不存在则自动创建表。
|
||||||
|
* 检查表中是否存在系统管理员,若不存在则创建一个名为`admin`、密码为`password`、描述为`default system admin`的系统管理员。
|
||||||
|
|
||||||
|
## 鉴权设计
|
||||||
|
|
||||||
|
### 角色权限
|
||||||
|
|
||||||
|
* **系统管理员、管理员**:可以访问所有 API,包括用户增删改查。
|
||||||
|
* **普通用户**:只能访问用户列表(只读)。
|
||||||
|
|
||||||
|
### JWT 鉴权
|
||||||
|
|
||||||
|
JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以下信息:
|
||||||
|
|
||||||
|
* **Payload**:用户 ID、用户名、角色、Token 过期时间、token_type等。
|
||||||
|
* **签名**:使用后端密钥对 Payload 进行签名,确保 Token 的完整性和安全性。
|
||||||
|
|
||||||
|
### JWT Token 数据结构
|
||||||
|
|
||||||
|
```python
|
||||||
|
class TokenPayload:
|
||||||
|
id: int # 用户ID
|
||||||
|
username: str # 用户名
|
||||||
|
role: str # 用户角色
|
||||||
|
exp: int # Token过期时间
|
||||||
|
token_type: str # Token类型(access或refresh)
|
||||||
|
```
|
||||||
|
|
||||||
|
### JWT 自动过期机制
|
||||||
|
|
||||||
|
* **Access Token**:用于常规 API 请求,有效期较短(如 30 分钟)。
|
||||||
|
* **Refresh Token**:用于刷新 Access Token,有效期较长(如 7 天)。
|
||||||
|
|
||||||
|
#### Token 刷新流程
|
||||||
|
|
||||||
|
1. 客户端使用 Refresh Token 请求 `/api/auth/refresh` 接口。
|
||||||
|
2. 服务端验证 Refresh Token 的有效性。
|
||||||
|
3. 服务端生成新的 Access Token 和 Refresh Token,并返回给客户端。
|
||||||
|
4. 客户端更新本地存储的 Token。
|
||||||
|
|
||||||
|
### Token 验证
|
||||||
|
|
||||||
|
* **verify_access_token**:验证access token有效性并返回payload,如果token无效或类型不匹配则返回None
|
||||||
|
* **verify_refresh_token**:验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None
|
||||||
|
|
||||||
|
## 单元测试
|
||||||
|
|
||||||
|
### 测试框架
|
||||||
|
|
||||||
|
* 使用 `pytest` 进行单元测试。
|
||||||
|
* 使用 `requests` 库模拟 API 请求。
|
||||||
|
|
||||||
|
### 测试用例
|
||||||
|
|
||||||
|
1. **用户登录**:
|
||||||
|
* 测试正确的用户名和密码。
|
||||||
|
* 测试错误的用户名和密码。
|
||||||
|
2. **用户登出**:
|
||||||
|
* 测试已登录用户登出。
|
||||||
|
* 测试未登录用户登出。
|
||||||
|
3. **刷新 JWT Token**:
|
||||||
|
* 测试有效的 Refresh Token。
|
||||||
|
* 测试无效的 Refresh Token。
|
||||||
|
4. **获取用户列表**:
|
||||||
|
* 测试不同角色的用户访问权限。
|
||||||
|
* 测试分页和过滤功能。
|
||||||
|
5. **创建用户**:
|
||||||
|
* 测试管理员创建用户。
|
||||||
|
* 测试普通用户尝试创建用户。
|
||||||
|
6. **更新用户信息**:
|
||||||
|
* 测试管理员更新用户信息。
|
||||||
|
* 测试普通用户尝试更新用户信息。
|
||||||
|
7. **删除用户**:
|
||||||
|
* 测试管理员删除用户。
|
||||||
|
* 测试普通用户尝试删除用户。
|
||||||
|
|
||||||
|
## 日志记录
|
||||||
|
|
||||||
|
* 使用 `logging` 模块记录系统日志。
|
||||||
|
* 日志级别包括 `INFO`、`WARNING`、`ERROR`。
|
||||||
|
* 日志内容包括用户操作、API 请求、错误信息等。
|
||||||
|
|
||||||
|
## 后续开发
|
||||||
|
* **部署文档**:包括 Docker 容器化、CI/CD 流程、环境变量配置等。
|
47
backend/main.py
Normal file
47
backend/main.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG
|
||||||
|
from services.db import create_db_engine, init_db, close_db_connection
|
||||||
|
from routes.auth import router as auth_router
|
||||||
|
from routes.users import router as users_router
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=LOGGING_CONFIG['level'],
|
||||||
|
format=LOGGING_CONFIG['format'],
|
||||||
|
filename=LOGGING_CONFIG['filename']
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="User Management System",
|
||||||
|
description="API for managing users with role-based access control",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据库初始化
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
logger.info("Initializing database...")
|
||||||
|
engine = create_db_engine()
|
||||||
|
await init_db(engine)
|
||||||
|
logger.info("Database initialized successfully")
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
logger.info("Closing database connections...")
|
||||||
|
await close_db_connection()
|
||||||
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||||
|
app.include_router(users_router, prefix="/api/users", tags=["users"])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
3
backend/models/__init__.py
Normal file
3
backend/models/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .user import User
|
||||||
|
|
||||||
|
__all__ = ["User"]
|
25
backend/models/user.py
Normal file
25
backend/models/user.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from sqlalchemy import Column, Integer, String, Text, DateTime, Enum as SQLEnum
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class UserRole(str, Enum):
|
||||||
|
SYSTEM_ADMIN = "system_admin"
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
username = Column(String(50), unique=True, nullable=False)
|
||||||
|
password = Column(String(255), nullable=False)
|
||||||
|
role = Column(SQLEnum(UserRole), nullable=False, default=UserRole.USER)
|
||||||
|
description = Column(Text)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<User(id={self.id}, username={self.username}, role={self.role})>"
|
9
backend/requirements.txt
Normal file
9
backend/requirements.txt
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
fastapi>=0.95.2
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
sqlalchemy>=2.0.15
|
||||||
|
passlib>=1.7.4
|
||||||
|
bcrypt>=4.0.1
|
||||||
|
pyjwt>=2.6.0
|
||||||
|
pytest>=7.3.1
|
||||||
|
requests>=2.28.2
|
||||||
|
uvicorn[standard]>=0.21.1
|
32
backend/routes/auth.py
Normal file
32
backend/routes/auth.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
|
||||||
|
from services.auth import create_tokens_response, refresh_tokens
|
||||||
|
from services.user import authenticate_user
|
||||||
|
from services.db import get_db_session_dep
|
||||||
|
|
||||||
|
router = APIRouter(tags=["auth"])
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(
|
||||||
|
login_data: LoginRequest,
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
user = await authenticate_user(session, login_data.username, login_data.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Invalid username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
return create_tokens_response(user.id, user.username, user.role)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def refresh_token(refresh_data: RefreshTokenRequest):
|
||||||
|
tokens = refresh_tokens(refresh_data.refresh_token)
|
||||||
|
if not tokens:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||||
|
return tokens
|
34
backend/routes/depends.py
Normal file
34
backend/routes/depends.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from typing import Optional
|
||||||
|
from schemas.auth import TokenPayload
|
||||||
|
from schemas.user import UserRole
|
||||||
|
from services.auth import verify_access_token
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||||
|
|
||||||
|
async def _get_token_data(token: str) -> TokenPayload:
|
||||||
|
"""验证并返回TokenData"""
|
||||||
|
token_data = verify_access_token(token)
|
||||||
|
if token_data is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired authentication credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
return token_data
|
||||||
|
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||||
|
"""获取当前用户"""
|
||||||
|
return await _get_token_data(token)
|
||||||
|
|
||||||
|
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||||
|
"""获取当前管理员用户"""
|
||||||
|
token_data = await _get_token_data(token)
|
||||||
|
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Access denied: Insufficient privileges for this operation",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
return token_data
|
69
backend/routes/users.py
Normal file
69
backend/routes/users.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List, Optional
|
||||||
|
from schemas.auth import TokenPayload
|
||||||
|
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
||||||
|
from routes.depends import get_current_user,get_current_admin
|
||||||
|
import services.user as user_service
|
||||||
|
|
||||||
|
from services.db import get_db_session_dep
|
||||||
|
|
||||||
|
router = APIRouter(tags=["users"])
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[UserResponse])
|
||||||
|
async def get_users_list(
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 100,
|
||||||
|
role: Optional[str] = None,
|
||||||
|
current_user_token: TokenPayload = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
skip = (page - 1) * limit
|
||||||
|
users = await user_service.get_users_list(session, skip=skip, limit=limit)
|
||||||
|
if role:
|
||||||
|
users = [user for user in users if user.role == role]
|
||||||
|
return users
|
||||||
|
|
||||||
|
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_user(
|
||||||
|
user_data: UserCreate,
|
||||||
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
return await user_service.create_user(session, user_data)
|
||||||
|
|
||||||
|
@router.put("/{user_id}", response_model=UserResponse)
|
||||||
|
async def update_user(
|
||||||
|
user_id: int,
|
||||||
|
user_data: UserUpdate,
|
||||||
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
return await user_service.update_user(session, user_id, user_data)
|
||||||
|
|
||||||
|
@router.get("/{user_id}", response_model=UserResponse)
|
||||||
|
async def get_user(
|
||||||
|
user_id: int,
|
||||||
|
current_user_token: TokenPayload = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
user = await user_service.get_user(session, user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_user(
|
||||||
|
user_id: int,
|
||||||
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
|
):
|
||||||
|
success = await user_service.delete_user(session, user_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found"
|
||||||
|
)
|
22
backend/schemas/auth.py
Normal file
22
backend/schemas/auth.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str
|
||||||
|
access_token_exp: int
|
||||||
|
refresh_token_exp: int
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
id: int
|
||||||
|
username: str
|
||||||
|
role: str
|
||||||
|
exp: int
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
refresh_token: str
|
43
backend/schemas/user.py
Normal file
43
backend/schemas/user.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# 用户角色枚举
|
||||||
|
class UserRole(str, Enum):
|
||||||
|
SYSTEM_ADMIN = "system_admin"
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
# 基础用户模型
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
username: str = Field(..., max_length=50, description="用户名")
|
||||||
|
role: UserRole = Field(default=UserRole.USER, description="用户角色")
|
||||||
|
description: Optional[str] = Field(None, max_length=255, description="用户描述")
|
||||||
|
|
||||||
|
# 用户创建模型
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str = Field(..., min_length=6, max_length=255, description="用户密码")
|
||||||
|
|
||||||
|
|
||||||
|
# 用户更新模型
|
||||||
|
class UserUpdate(BaseModel):
|
||||||
|
username: Optional[str] = Field(None, max_length=50, description="用户名")
|
||||||
|
role: Optional[UserRole] = Field(None, description="用户角色")
|
||||||
|
description: Optional[str] = Field(None, max_length=255, description="用户描述")
|
||||||
|
|
||||||
|
# 可选:确保至少更新一个字段
|
||||||
|
@root_validator
|
||||||
|
def validate_at_least_one_field(cls, values):
|
||||||
|
if not any(values.values()):
|
||||||
|
raise ValueError("至少需要更新一个字段")
|
||||||
|
return values
|
||||||
|
|
||||||
|
# 用户响应模型
|
||||||
|
class UserResponse(UserBase):
|
||||||
|
id: int = Field(..., description="用户ID")
|
||||||
|
created_at: datetime = Field(..., description="创建时间")
|
||||||
|
updated_at: datetime = Field(..., description="更新时间")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True # 允许从 ORM 对象加载数据
|
90
backend/services/auth.py
Normal file
90
backend/services/auth.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import jwt
|
||||||
|
import time
|
||||||
|
from config import JWT_CONFIG
|
||||||
|
from schemas.auth import TokenResponse, TokenPayload
|
||||||
|
|
||||||
|
SECRET_KEY = JWT_CONFIG['secret_key']
|
||||||
|
ALGORITHM = JWT_CONFIG['algorithm']
|
||||||
|
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
||||||
|
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_time() -> int:
|
||||||
|
"""获取当前UTC时间戳"""
|
||||||
|
return int(time.time())
|
||||||
|
|
||||||
|
def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str:
|
||||||
|
"""创建JWT token"""
|
||||||
|
expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE
|
||||||
|
expire = get_current_time() + int(expire_delta.total_seconds())
|
||||||
|
|
||||||
|
to_encode = {
|
||||||
|
"id": user_id,
|
||||||
|
"username": username,
|
||||||
|
"role": role,
|
||||||
|
"exp": expire,
|
||||||
|
"token_type": token_type
|
||||||
|
}
|
||||||
|
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
|
||||||
|
"""创建access token和refresh token"""
|
||||||
|
access_token = create_token(user_id, username, role, "access")
|
||||||
|
refresh_token = create_token(user_id, username, role, "refresh")
|
||||||
|
|
||||||
|
# 获取token的过期时间
|
||||||
|
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
||||||
|
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_type="bearer",
|
||||||
|
access_token_exp=access_token_exp,
|
||||||
|
refresh_token_exp=refresh_token_exp
|
||||||
|
)
|
||||||
|
|
||||||
|
def verify_access_token(token: str) -> Optional[TokenPayload]:
|
||||||
|
"""验证access token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
if payload.get("token_type") != "access":
|
||||||
|
return None
|
||||||
|
return TokenPayload(
|
||||||
|
id=payload.get("id"),
|
||||||
|
username=payload.get("username"),
|
||||||
|
role=payload.get("role"),
|
||||||
|
exp=payload.get("exp"),
|
||||||
|
token_type=payload.get("token_type")
|
||||||
|
)
|
||||||
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
|
||||||
|
"""验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
if payload.get("token_type") != "refresh":
|
||||||
|
return None
|
||||||
|
return TokenPayload(
|
||||||
|
id=payload.get("id"),
|
||||||
|
username=payload.get("username"),
|
||||||
|
role=payload.get("role"),
|
||||||
|
exp=payload.get("exp"),
|
||||||
|
token_type=payload.get("token_type")
|
||||||
|
)
|
||||||
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
|
||||||
|
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
||||||
|
token_data = verify_refresh_token(refresh_token)
|
||||||
|
if token_data is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return create_tokens_response(
|
||||||
|
user_id=token_data.id,
|
||||||
|
username=token_data.username,
|
||||||
|
role=token_data.role
|
||||||
|
)
|
77
backend/services/db.py
Normal file
77
backend/services/db.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from models.user import Base, User, UserRole
|
||||||
|
from sqlalchemy import select
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
|
||||||
|
from services.user import get_password_hash
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
# 全局数据库引擎实例
|
||||||
|
_engine: AsyncEngine | None = None
|
||||||
|
|
||||||
|
def create_db_engine() -> AsyncEngine:
|
||||||
|
"""创建数据库引擎"""
|
||||||
|
return create_async_engine(
|
||||||
|
f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@"
|
||||||
|
f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}",
|
||||||
|
echo=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_db_engine() -> AsyncEngine:
|
||||||
|
"""获取全局数据库引擎实例"""
|
||||||
|
if _engine is None:
|
||||||
|
raise RuntimeError("Database engine not initialized")
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
async def get_db_session() -> AsyncSession:
|
||||||
|
"""获取数据库会话"""
|
||||||
|
async with AsyncSession(get_db_engine()) as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_db_session_dep() -> AsyncSession:
|
||||||
|
"""FastAPI依赖注入使用的数据库会话获取函数"""
|
||||||
|
async with AsyncSession(get_db_engine()) as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def init_db(engine: AsyncEngine):
|
||||||
|
"""初始化数据库"""
|
||||||
|
global _engine
|
||||||
|
_engine = engine
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# 创建所有表
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
# 检查系统管理员是否存在
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.role == UserRole.SYSTEM_ADMIN)
|
||||||
|
)
|
||||||
|
if not result.scalars().first():
|
||||||
|
# 创建默认系统管理员
|
||||||
|
admin = User(
|
||||||
|
username=SYSTEM_ADMIN_CONFIG['username'],
|
||||||
|
password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']),
|
||||||
|
role=UserRole.SYSTEM_ADMIN,
|
||||||
|
description=SYSTEM_ADMIN_CONFIG['description']
|
||||||
|
)
|
||||||
|
session.add(admin)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def close_db_connection():
|
||||||
|
"""关闭数据库连接"""
|
||||||
|
global _engine
|
||||||
|
if _engine is not None:
|
||||||
|
await _engine.dispose()
|
||||||
|
_engine = None
|
74
backend/services/user.py
Normal file
74
backend/services/user.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from sqlalchemy import select, update, delete
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from models.user import User
|
||||||
|
from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||||
|
|
||||||
|
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
|
||||||
|
"""创建用户"""
|
||||||
|
hashed_password = pwd_context.hash(user_data.password)
|
||||||
|
user = User(
|
||||||
|
username=user_data.username,
|
||||||
|
password=hashed_password,
|
||||||
|
role=user_data.role,
|
||||||
|
description=user_data.description
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return UserResponse.from_orm(user)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
|
||||||
|
"""根据ID获取用户"""
|
||||||
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalars().first()
|
||||||
|
return UserResponse.from_orm(user) if user else None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||||
|
"""获取用户列表"""
|
||||||
|
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||||
|
users = result.scalars().all()
|
||||||
|
return [UserResponse.from_orm(user) for user in users]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
|
||||||
|
"""更新用户信息"""
|
||||||
|
await session.execute(
|
||||||
|
update(User)
|
||||||
|
.where(User.id == user_id)
|
||||||
|
.values(**user_data.dict(exclude_unset=True))
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return await get_user(session, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_user(session: AsyncSession, user_id: int) -> bool:
|
||||||
|
"""删除用户"""
|
||||||
|
result = await session.execute(delete(User).where(User.id == user_id))
|
||||||
|
await session.commit()
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""验证输入的明文密码是否与存储的哈希密码匹配"""
|
||||||
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
|
||||||
|
"""生成使用 bcrypt 的密码哈希"""
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
|
||||||
|
"""验证用户登录"""
|
||||||
|
result = await session.execute(select(User).where(User.username == username))
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.password):
|
||||||
|
return None
|
||||||
|
return UserResponse.from_orm(user)
|
Loading…
x
Reference in New Issue
Block a user