Add 'backend/' from commit '48a644fb354d6c6efcbd12bc1b4a2cb83137b68e'

git-subtree-dir: backend
git-subtree-mainline: 545699d16f
git-subtree-split: 48a644fb35
This commit is contained in:
carry
2025-02-17 17:44:42 +08:00
15 changed files with 699 additions and 0 deletions

90
backend/services/auth.py Normal file
View File

@@ -0,0 +1,90 @@
from typing import Optional
import jwt
import time
from config import JWT_CONFIG
from schemas.auth import TokenResponse, TokenPayload
SECRET_KEY = JWT_CONFIG['secret_key']
ALGORITHM = JWT_CONFIG['algorithm']
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
def get_current_time() -> int:
"""获取当前UTC时间戳"""
return int(time.time())
def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str:
"""创建JWT token"""
expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE
expire = get_current_time() + int(expire_delta.total_seconds())
to_encode = {
"id": user_id,
"username": username,
"role": role,
"exp": expire,
"token_type": token_type
}
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
"""创建access token和refresh token"""
access_token = create_token(user_id, username, role, "access")
refresh_token = create_token(user_id, username, role, "refresh")
# 获取token的过期时间
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
access_token_exp=access_token_exp,
refresh_token_exp=refresh_token_exp
)
def verify_access_token(token: str) -> Optional[TokenPayload]:
"""验证access token有效性并返回payload如果token无效或类型不匹配则返回None"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
if payload.get("token_type") != "access":
return None
return TokenPayload(
id=payload.get("id"),
username=payload.get("username"),
role=payload.get("role"),
exp=payload.get("exp"),
token_type=payload.get("token_type")
)
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
"""验证refresh token有效性并返回payload如果token无效或类型不匹配则返回None"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
if payload.get("token_type") != "refresh":
return None
return TokenPayload(
id=payload.get("id"),
username=payload.get("username"),
role=payload.get("role"),
exp=payload.get("exp"),
token_type=payload.get("token_type")
)
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
"""使用refresh token刷新access token如果refresh token无效则返回None"""
token_data = verify_refresh_token(refresh_token)
if token_data is None:
return None
else:
return create_tokens_response(
user_id=token_data.id,
username=token_data.username,
role=token_data.role
)

77
backend/services/db.py Normal file
View File

@@ -0,0 +1,77 @@
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from models.user import Base, User, UserRole
from sqlalchemy import select
from contextlib import asynccontextmanager
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
from services.user import get_password_hash
from typing import AsyncGenerator
# 全局数据库引擎实例
_engine: AsyncEngine | None = None
def create_db_engine() -> AsyncEngine:
"""创建数据库引擎"""
return create_async_engine(
f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@"
f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}",
echo=True
)
def get_db_engine() -> AsyncEngine:
"""获取全局数据库引擎实例"""
if _engine is None:
raise RuntimeError("Database engine not initialized")
return _engine
async def get_db_session() -> AsyncSession:
"""获取数据库会话"""
async with AsyncSession(get_db_engine()) as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def get_db_session_dep() -> AsyncSession:
"""FastAPI依赖注入使用的数据库会话获取函数"""
async with AsyncSession(get_db_engine()) as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def init_db(engine: AsyncEngine):
"""初始化数据库"""
global _engine
_engine = engine
async with engine.begin() as conn:
# 创建所有表
await conn.run_sync(Base.metadata.create_all)
async with AsyncSession(engine) as session:
# 检查系统管理员是否存在
result = await session.execute(
select(User).where(User.role == UserRole.SYSTEM_ADMIN)
)
if not result.scalars().first():
# 创建默认系统管理员
admin = User(
username=SYSTEM_ADMIN_CONFIG['username'],
password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']),
role=UserRole.SYSTEM_ADMIN,
description=SYSTEM_ADMIN_CONFIG['description']
)
session.add(admin)
await session.commit()
async def close_db_connection():
"""关闭数据库连接"""
global _engine
if _engine is not None:
await _engine.dispose()
_engine = None

74
backend/services/user.py Normal file
View File

@@ -0,0 +1,74 @@
from typing import List, Optional
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from passlib.context import CryptContext
from models.user import User
from schemas.user import UserCreate, UserUpdate, UserResponse
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
"""创建用户"""
hashed_password = pwd_context.hash(user_data.password)
user = User(
username=user_data.username,
password=hashed_password,
role=user_data.role,
description=user_data.description
)
session.add(user)
await session.commit()
await session.refresh(user)
return UserResponse.from_orm(user)
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
"""根据ID获取用户"""
result = await session.execute(select(User).where(User.id == user_id))
user = result.scalars().first()
return UserResponse.from_orm(user) if user else None
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
"""获取用户列表"""
result = await session.execute(select(User).offset(skip).limit(limit))
users = result.scalars().all()
return [UserResponse.from_orm(user) for user in users]
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
"""更新用户信息"""
await session.execute(
update(User)
.where(User.id == user_id)
.values(**user_data.dict(exclude_unset=True))
)
await session.commit()
return await get_user(session, user_id)
async def delete_user(session: AsyncSession, user_id: int) -> bool:
"""删除用户"""
result = await session.execute(delete(User).where(User.id == user_id))
await session.commit()
return result.rowcount > 0
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证输入的明文密码是否与存储的哈希密码匹配"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""生成使用 bcrypt 的密码哈希"""
return pwd_context.hash(password)
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
"""验证用户登录"""
result = await session.execute(select(User).where(User.username == username))
user = result.scalars().first()
if not user:
return None
if not verify_password(password, user.password):
return None
return UserResponse.from_orm(user)