Add 'backend/' from commit '48a644fb354d6c6efcbd12bc1b4a2cb83137b68e'
git-subtree-dir: backend git-subtree-mainline: 545699d16fda1029201c9bfadbfb8d5c7ffe2464 git-subtree-split: 48a644fb354d6c6efcbd12bc1b4a2cb83137b68e
This commit is contained in:
commit
3f114b2cc3
30
backend/.gitignore
vendored
Normal file
30
backend/.gitignore
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
.venv/
|
||||
|
||||
# IDE specific files
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Logs and databases
|
||||
*.log
|
||||
*.sqlite3
|
||||
|
||||
# Testing
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# Python packaging
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
35
backend/config.py
Normal file
35
backend/config.py
Normal file
@ -0,0 +1,35 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_CONFIG = {
|
||||
'host': os.getenv('DB_HOST', 'localhost'),
|
||||
'port': int(os.getenv('DB_PORT', '3306')),
|
||||
'user': os.getenv('DB_USER', 'root'),
|
||||
'password': os.getenv('DB_PASSWORD', 'password'),
|
||||
'database': os.getenv('DB_NAME', 'user_manage'),
|
||||
'charset': 'utf8mb4'
|
||||
}
|
||||
|
||||
# JWT配置
|
||||
JWT_CONFIG = {
|
||||
'secret_key': os.getenv('JWT_SECRET_KEY', 'your-secret-key'),
|
||||
'algorithm': 'HS256',
|
||||
'access_token_expire': timedelta(minutes=int(os.getenv('JWT_ACCESS_EXPIRE_MINUTES', '10'))),
|
||||
'refresh_token_expire': timedelta(days=int(os.getenv('JWT_REFRESH_EXPIRE_DAYS', '3')))
|
||||
}
|
||||
|
||||
# 日志配置
|
||||
LOGGING_CONFIG = {
|
||||
'level': os.getenv('LOG_LEVEL', 'INFO'),
|
||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
'filename': os.getenv('LOG_FILE', 'app.log')
|
||||
}
|
||||
|
||||
# 系统管理员初始配置
|
||||
SYSTEM_ADMIN_CONFIG = {
|
||||
'username': os.getenv('ADMIN_USERNAME', 'admin'),
|
||||
'password': os.getenv('ADMIN_PASSWORD', 'password'),
|
||||
'role': os.getenv('ADMIN_ROLE', 'system_admin'),
|
||||
'description': os.getenv('ADMIN_DESCRIPTION', 'default system admin')
|
||||
}
|
109
backend/devdoc.md
Normal file
109
backend/devdoc.md
Normal file
@ -0,0 +1,109 @@
|
||||
## 项目概述
|
||||
|
||||
本项目是一个简单的用户管理系统,支持三种用户角色:系统管理员、管理员和普通用户。系统管理员在项目启动时自动生成,且唯一。所有账户必须由管理员手动创建。用户表包含密码、ID、角色、描述等信息,管理员可以对用户表进行增删改查,普通用户只能读取用户表。
|
||||
|
||||
## 技术栈
|
||||
|
||||
* **后端技术栈**:FastAPI、SQLAlchemy、MySQL、bcrypt、PyJWT、logging
|
||||
* **前端技术栈**:Vue、Vue Router、Pinia、Vite、Axios、Element Plus
|
||||
|
||||
## 数据库设计
|
||||
|
||||
### 用户表结构
|
||||
|
||||
| 字段名 | 数据类型 | 约束条件 | 默认值 | 说明 |
|
||||
| -------- | ---------- | ---------- | -------- | ---------------------------------------------------- |
|
||||
| `id` | `INT` | `PRIMARY KEY`, `AUTO_INCREMENT` | - | 用户ID,主键,自增 |
|
||||
| `username` | `VARCHAR(50)` | `NOT NULL`, `UNIQUE` | - | 用户名,唯一 |
|
||||
| `password` | `VARCHAR(255)` | `NOT NULL` | - | 用户密码,使用哈希加密存储 |
|
||||
| `role` | `ENUM` | `NOT NULL` | `UserRole.USER` | 用户角色,枚举类型,可选值为 `UserRole.SYSTEM_ADMIN`, `UserRole.ADMIN`, `UserRole.USER` |
|
||||
| `description` | `TEXT` | - | - | 用户描述,可选 |
|
||||
| `created_at` | `DATETIME` | - | `当前UTC时间` | 用户创建时间,默认值为当前UTC时间 |
|
||||
| `updated_at` | `DATETIME` | - | `当前UTC时间,更新时自动更新` | 用户信息更新时间,默认值为当前UTC时间,更新时自动更新 |
|
||||
|
||||
### 数据库初始化
|
||||
|
||||
* 系统初始化时,后端控制层会检查表是否存在,若不存在则自动创建表。
|
||||
* 检查表中是否存在系统管理员,若不存在则创建一个名为`admin`、密码为`password`、描述为`default system admin`的系统管理员。
|
||||
|
||||
## 鉴权设计
|
||||
|
||||
### 角色权限
|
||||
|
||||
* **系统管理员、管理员**:可以访问所有 API,包括用户增删改查。
|
||||
* **普通用户**:只能访问用户列表(只读)。
|
||||
|
||||
### JWT 鉴权
|
||||
|
||||
JWT(JSON Web Token)用于用户身份验证和权限控制。JWT 包含以下信息:
|
||||
|
||||
* **Payload**:用户 ID、用户名、角色、Token 过期时间、token_type等。
|
||||
* **签名**:使用后端密钥对 Payload 进行签名,确保 Token 的完整性和安全性。
|
||||
|
||||
### JWT Token 数据结构
|
||||
|
||||
```python
|
||||
class TokenPayload:
|
||||
id: int # 用户ID
|
||||
username: str # 用户名
|
||||
role: str # 用户角色
|
||||
exp: int # Token过期时间
|
||||
token_type: str # Token类型(access或refresh)
|
||||
```
|
||||
|
||||
### JWT 自动过期机制
|
||||
|
||||
* **Access Token**:用于常规 API 请求,有效期较短(如 30 分钟)。
|
||||
* **Refresh Token**:用于刷新 Access Token,有效期较长(如 7 天)。
|
||||
|
||||
#### Token 刷新流程
|
||||
|
||||
1. 客户端使用 Refresh Token 请求 `/api/auth/refresh` 接口。
|
||||
2. 服务端验证 Refresh Token 的有效性。
|
||||
3. 服务端生成新的 Access Token 和 Refresh Token,并返回给客户端。
|
||||
4. 客户端更新本地存储的 Token。
|
||||
|
||||
### Token 验证
|
||||
|
||||
* **verify_access_token**:验证access token有效性并返回payload,如果token无效或类型不匹配则返回None
|
||||
* **verify_refresh_token**:验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None
|
||||
|
||||
## 单元测试
|
||||
|
||||
### 测试框架
|
||||
|
||||
* 使用 `pytest` 进行单元测试。
|
||||
* 使用 `requests` 库模拟 API 请求。
|
||||
|
||||
### 测试用例
|
||||
|
||||
1. **用户登录**:
|
||||
* 测试正确的用户名和密码。
|
||||
* 测试错误的用户名和密码。
|
||||
2. **用户登出**:
|
||||
* 测试已登录用户登出。
|
||||
* 测试未登录用户登出。
|
||||
3. **刷新 JWT Token**:
|
||||
* 测试有效的 Refresh Token。
|
||||
* 测试无效的 Refresh Token。
|
||||
4. **获取用户列表**:
|
||||
* 测试不同角色的用户访问权限。
|
||||
* 测试分页和过滤功能。
|
||||
5. **创建用户**:
|
||||
* 测试管理员创建用户。
|
||||
* 测试普通用户尝试创建用户。
|
||||
6. **更新用户信息**:
|
||||
* 测试管理员更新用户信息。
|
||||
* 测试普通用户尝试更新用户信息。
|
||||
7. **删除用户**:
|
||||
* 测试管理员删除用户。
|
||||
* 测试普通用户尝试删除用户。
|
||||
|
||||
## 日志记录
|
||||
|
||||
* 使用 `logging` 模块记录系统日志。
|
||||
* 日志级别包括 `INFO`、`WARNING`、`ERROR`。
|
||||
* 日志内容包括用户操作、API 请求、错误信息等。
|
||||
|
||||
## 后续开发
|
||||
* **部署文档**:包括 Docker 容器化、CI/CD 流程、环境变量配置等。
|
47
backend/main.py
Normal file
47
backend/main.py
Normal file
@ -0,0 +1,47 @@
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG
|
||||
from services.db import create_db_engine, init_db, close_db_connection
|
||||
from routes.auth import router as auth_router
|
||||
from routes.users import router as users_router
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=LOGGING_CONFIG['level'],
|
||||
format=LOGGING_CONFIG['format'],
|
||||
filename=LOGGING_CONFIG['filename']
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化FastAPI应用
|
||||
app = FastAPI(
|
||||
title="User Management System",
|
||||
description="API for managing users with role-based access control",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 数据库初始化
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Initializing database...")
|
||||
engine = create_db_engine()
|
||||
await init_db(engine)
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Closing database connections...")
|
||||
await close_db_connection()
|
||||
logger.info("Database connections closed")
|
||||
|
||||
# 注册路由
|
||||
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(users_router, prefix="/api/users", tags=["users"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
3
backend/models/__init__.py
Normal file
3
backend/models/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .user import User
|
||||
|
||||
__all__ = ["User"]
|
25
backend/models/user.py
Normal file
25
backend/models/user.py
Normal file
@ -0,0 +1,25 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Enum as SQLEnum
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class UserRole(str, Enum):
|
||||
SYSTEM_ADMIN = "system_admin"
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
username = Column(String(50), unique=True, nullable=False)
|
||||
password = Column(String(255), nullable=False)
|
||||
role = Column(SQLEnum(UserRole), nullable=False, default=UserRole.USER)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, username={self.username}, role={self.role})>"
|
9
backend/requirements.txt
Normal file
9
backend/requirements.txt
Normal file
@ -0,0 +1,9 @@
|
||||
fastapi>=0.95.2
|
||||
python-dotenv>=1.0.0
|
||||
sqlalchemy>=2.0.15
|
||||
passlib>=1.7.4
|
||||
bcrypt>=4.0.1
|
||||
pyjwt>=2.6.0
|
||||
pytest>=7.3.1
|
||||
requests>=2.28.2
|
||||
uvicorn[standard]>=0.21.1
|
32
backend/routes/auth.py
Normal file
32
backend/routes/auth.py
Normal file
@ -0,0 +1,32 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
|
||||
from services.auth import create_tokens_response, refresh_tokens
|
||||
from services.user import authenticate_user
|
||||
from services.db import get_db_session_dep
|
||||
|
||||
router = APIRouter(tags=["auth"])
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
user = await authenticate_user(session, login_data.username, login_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return create_tokens_response(user.id, user.username, user.role)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_token(refresh_data: RefreshTokenRequest):
|
||||
tokens = refresh_tokens(refresh_data.refresh_token)
|
||||
if not tokens:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
return tokens
|
34
backend/routes/depends.py
Normal file
34
backend/routes/depends.py
Normal file
@ -0,0 +1,34 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing import Optional
|
||||
from schemas.auth import TokenPayload
|
||||
from schemas.user import UserRole
|
||||
from services.auth import verify_access_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
|
||||
|
||||
async def _get_token_data(token: str) -> TokenPayload:
|
||||
"""验证并返回TokenData"""
|
||||
token_data = verify_access_token(token)
|
||||
if token_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return token_data
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||
"""获取当前用户"""
|
||||
return await _get_token_data(token)
|
||||
|
||||
async def get_current_admin(token: str = Depends(oauth2_scheme)) -> TokenPayload:
|
||||
"""获取当前管理员用户"""
|
||||
token_data = await _get_token_data(token)
|
||||
if token_data.role not in [UserRole.SYSTEM_ADMIN, UserRole.ADMIN]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied: Insufficient privileges for this operation",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return token_data
|
69
backend/routes/users.py
Normal file
69
backend/routes/users.py
Normal file
@ -0,0 +1,69 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
from schemas.auth import TokenPayload
|
||||
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
||||
from routes.depends import get_current_user,get_current_admin
|
||||
import services.user as user_service
|
||||
|
||||
from services.db import get_db_session_dep
|
||||
|
||||
router = APIRouter(tags=["users"])
|
||||
|
||||
@router.get("/", response_model=List[UserResponse])
|
||||
async def get_users_list(
|
||||
page: int = 1,
|
||||
limit: int = 100,
|
||||
role: Optional[str] = None,
|
||||
current_user_token: TokenPayload = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
skip = (page - 1) * limit
|
||||
users = await user_service.get_users_list(session, skip=skip, limit=limit)
|
||||
if role:
|
||||
users = [user for user in users if user.role == role]
|
||||
return users
|
||||
|
||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
user_data: UserCreate,
|
||||
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
return await user_service.create_user(session, user_data)
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_data: UserUpdate,
|
||||
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
return await user_service.update_user(session, user_id, user_data)
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
current_user_token: TokenPayload = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
user = await user_service.get_user(session, user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||
session: AsyncSession = Depends(get_db_session_dep)
|
||||
):
|
||||
success = await user_service.delete_user(session, user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
22
backend/schemas/auth.py
Normal file
22
backend/schemas/auth.py
Normal file
@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
access_token_exp: int
|
||||
refresh_token_exp: int
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
role: str
|
||||
exp: int
|
||||
token_type: str
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
43
backend/schemas/user.py
Normal file
43
backend/schemas/user.py
Normal file
@ -0,0 +1,43 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from typing import Optional
|
||||
|
||||
# 用户角色枚举
|
||||
class UserRole(str, Enum):
|
||||
SYSTEM_ADMIN = "system_admin"
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
# 基础用户模型
|
||||
class UserBase(BaseModel):
|
||||
username: str = Field(..., max_length=50, description="用户名")
|
||||
role: UserRole = Field(default=UserRole.USER, description="用户角色")
|
||||
description: Optional[str] = Field(None, max_length=255, description="用户描述")
|
||||
|
||||
# 用户创建模型
|
||||
class UserCreate(UserBase):
|
||||
password: str = Field(..., min_length=6, max_length=255, description="用户密码")
|
||||
|
||||
|
||||
# 用户更新模型
|
||||
class UserUpdate(BaseModel):
|
||||
username: Optional[str] = Field(None, max_length=50, description="用户名")
|
||||
role: Optional[UserRole] = Field(None, description="用户角色")
|
||||
description: Optional[str] = Field(None, max_length=255, description="用户描述")
|
||||
|
||||
# 可选:确保至少更新一个字段
|
||||
@root_validator
|
||||
def validate_at_least_one_field(cls, values):
|
||||
if not any(values.values()):
|
||||
raise ValueError("至少需要更新一个字段")
|
||||
return values
|
||||
|
||||
# 用户响应模型
|
||||
class UserResponse(UserBase):
|
||||
id: int = Field(..., description="用户ID")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
class Config:
|
||||
orm_mode = True # 允许从 ORM 对象加载数据
|
90
backend/services/auth.py
Normal file
90
backend/services/auth.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import Optional
|
||||
import jwt
|
||||
import time
|
||||
from config import JWT_CONFIG
|
||||
from schemas.auth import TokenResponse, TokenPayload
|
||||
|
||||
SECRET_KEY = JWT_CONFIG['secret_key']
|
||||
ALGORITHM = JWT_CONFIG['algorithm']
|
||||
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
||||
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
||||
|
||||
|
||||
def get_current_time() -> int:
|
||||
"""获取当前UTC时间戳"""
|
||||
return int(time.time())
|
||||
|
||||
def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str:
|
||||
"""创建JWT token"""
|
||||
expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE
|
||||
expire = get_current_time() + int(expire_delta.total_seconds())
|
||||
|
||||
to_encode = {
|
||||
"id": user_id,
|
||||
"username": username,
|
||||
"role": role,
|
||||
"exp": expire,
|
||||
"token_type": token_type
|
||||
}
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
|
||||
"""创建access token和refresh token"""
|
||||
access_token = create_token(user_id, username, role, "access")
|
||||
refresh_token = create_token(user_id, username, role, "refresh")
|
||||
|
||||
# 获取token的过期时间
|
||||
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
||||
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
access_token_exp=access_token_exp,
|
||||
refresh_token_exp=refresh_token_exp
|
||||
)
|
||||
|
||||
def verify_access_token(token: str) -> Optional[TokenPayload]:
|
||||
"""验证access token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
if payload.get("token_type") != "access":
|
||||
return None
|
||||
return TokenPayload(
|
||||
id=payload.get("id"),
|
||||
username=payload.get("username"),
|
||||
role=payload.get("role"),
|
||||
exp=payload.get("exp"),
|
||||
token_type=payload.get("token_type")
|
||||
)
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||
return None
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
|
||||
"""验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
if payload.get("token_type") != "refresh":
|
||||
return None
|
||||
return TokenPayload(
|
||||
id=payload.get("id"),
|
||||
username=payload.get("username"),
|
||||
role=payload.get("role"),
|
||||
exp=payload.get("exp"),
|
||||
token_type=payload.get("token_type")
|
||||
)
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||
return None
|
||||
|
||||
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
|
||||
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
||||
token_data = verify_refresh_token(refresh_token)
|
||||
if token_data is None:
|
||||
return None
|
||||
else:
|
||||
return create_tokens_response(
|
||||
user_id=token_data.id,
|
||||
username=token_data.username,
|
||||
role=token_data.role
|
||||
)
|
77
backend/services/db.py
Normal file
77
backend/services/db.py
Normal file
@ -0,0 +1,77 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from models.user import Base, User, UserRole
|
||||
from sqlalchemy import select
|
||||
from contextlib import asynccontextmanager
|
||||
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
|
||||
from services.user import get_password_hash
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# 全局数据库引擎实例
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
def create_db_engine() -> AsyncEngine:
|
||||
"""创建数据库引擎"""
|
||||
return create_async_engine(
|
||||
f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@"
|
||||
f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}",
|
||||
echo=True
|
||||
)
|
||||
|
||||
def get_db_engine() -> AsyncEngine:
|
||||
"""获取全局数据库引擎实例"""
|
||||
if _engine is None:
|
||||
raise RuntimeError("Database engine not initialized")
|
||||
return _engine
|
||||
|
||||
async def get_db_session() -> AsyncSession:
|
||||
"""获取数据库会话"""
|
||||
async with AsyncSession(get_db_engine()) as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
async def get_db_session_dep() -> AsyncSession:
|
||||
"""FastAPI依赖注入使用的数据库会话获取函数"""
|
||||
async with AsyncSession(get_db_engine()) as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
async def init_db(engine: AsyncEngine):
|
||||
"""初始化数据库"""
|
||||
global _engine
|
||||
_engine = engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# 创建所有表
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
# 检查系统管理员是否存在
|
||||
result = await session.execute(
|
||||
select(User).where(User.role == UserRole.SYSTEM_ADMIN)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
# 创建默认系统管理员
|
||||
admin = User(
|
||||
username=SYSTEM_ADMIN_CONFIG['username'],
|
||||
password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']),
|
||||
role=UserRole.SYSTEM_ADMIN,
|
||||
description=SYSTEM_ADMIN_CONFIG['description']
|
||||
)
|
||||
session.add(admin)
|
||||
await session.commit()
|
||||
|
||||
async def close_db_connection():
|
||||
"""关闭数据库连接"""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
await _engine.dispose()
|
||||
_engine = None
|
74
backend/services/user.py
Normal file
74
backend/services/user.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from passlib.context import CryptContext
|
||||
from models.user import User
|
||||
from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
|
||||
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
|
||||
"""创建用户"""
|
||||
hashed_password = pwd_context.hash(user_data.password)
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
password=hashed_password,
|
||||
role=user_data.role,
|
||||
description=user_data.description
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
|
||||
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
|
||||
"""根据ID获取用户"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
return UserResponse.from_orm(user) if user else None
|
||||
|
||||
|
||||
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||
"""获取用户列表"""
|
||||
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||
users = result.scalars().all()
|
||||
return [UserResponse.from_orm(user) for user in users]
|
||||
|
||||
|
||||
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
|
||||
"""更新用户信息"""
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(**user_data.dict(exclude_unset=True))
|
||||
)
|
||||
await session.commit()
|
||||
return await get_user(session, user_id)
|
||||
|
||||
|
||||
async def delete_user(session: AsyncSession, user_id: int) -> bool:
|
||||
"""删除用户"""
|
||||
result = await session.execute(delete(User).where(User.id == user_id))
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证输入的明文密码是否与存储的哈希密码匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成使用 bcrypt 的密码哈希"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
|
||||
"""验证用户登录"""
|
||||
result = await session.execute(select(User).where(User.username == username))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.password):
|
||||
return None
|
||||
return UserResponse.from_orm(user)
|
Loading…
x
Reference in New Issue
Block a user