Add 'backend/' from commit '48a644fb354d6c6efcbd12bc1b4a2cb83137b68e'
git-subtree-dir: backend git-subtree-mainline:545699d16f
git-subtree-split:48a644fb35
This commit is contained in:
90
backend/services/auth.py
Normal file
90
backend/services/auth.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Optional
|
||||
import jwt
|
||||
import time
|
||||
from config import JWT_CONFIG
|
||||
from schemas.auth import TokenResponse, TokenPayload
|
||||
|
||||
SECRET_KEY = JWT_CONFIG['secret_key']
|
||||
ALGORITHM = JWT_CONFIG['algorithm']
|
||||
ACCESS_TOKEN_EXPIRE = JWT_CONFIG['access_token_expire']
|
||||
REFRESH_TOKEN_EXPIRE = JWT_CONFIG['refresh_token_expire']
|
||||
|
||||
|
||||
def get_current_time() -> int:
|
||||
"""获取当前UTC时间戳"""
|
||||
return int(time.time())
|
||||
|
||||
def create_token(user_id: int, username: str, role: str, token_type: str = "access") -> str:
|
||||
"""创建JWT token"""
|
||||
expire_delta = ACCESS_TOKEN_EXPIRE if token_type == "access" else REFRESH_TOKEN_EXPIRE
|
||||
expire = get_current_time() + int(expire_delta.total_seconds())
|
||||
|
||||
to_encode = {
|
||||
"id": user_id,
|
||||
"username": username,
|
||||
"role": role,
|
||||
"exp": expire,
|
||||
"token_type": token_type
|
||||
}
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
def create_tokens_response(user_id: int, username: str, role: str) -> TokenResponse:
|
||||
"""创建access token和refresh token"""
|
||||
access_token = create_token(user_id, username, role, "access")
|
||||
refresh_token = create_token(user_id, username, role, "refresh")
|
||||
|
||||
# 获取token的过期时间
|
||||
access_token_exp = get_current_time() + int(ACCESS_TOKEN_EXPIRE.total_seconds())
|
||||
refresh_token_exp = get_current_time() + int(REFRESH_TOKEN_EXPIRE.total_seconds())
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
access_token_exp=access_token_exp,
|
||||
refresh_token_exp=refresh_token_exp
|
||||
)
|
||||
|
||||
def verify_access_token(token: str) -> Optional[TokenPayload]:
|
||||
"""验证access token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
if payload.get("token_type") != "access":
|
||||
return None
|
||||
return TokenPayload(
|
||||
id=payload.get("id"),
|
||||
username=payload.get("username"),
|
||||
role=payload.get("role"),
|
||||
exp=payload.get("exp"),
|
||||
token_type=payload.get("token_type")
|
||||
)
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||
return None
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
|
||||
"""验证refresh token有效性并返回payload,如果token无效或类型不匹配则返回None"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
if payload.get("token_type") != "refresh":
|
||||
return None
|
||||
return TokenPayload(
|
||||
id=payload.get("id"),
|
||||
username=payload.get("username"),
|
||||
role=payload.get("role"),
|
||||
exp=payload.get("exp"),
|
||||
token_type=payload.get("token_type")
|
||||
)
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||
return None
|
||||
|
||||
def refresh_tokens(refresh_token: str) -> Optional[TokenResponse]:
|
||||
"""使用refresh token刷新access token,如果refresh token无效则返回None"""
|
||||
token_data = verify_refresh_token(refresh_token)
|
||||
if token_data is None:
|
||||
return None
|
||||
else:
|
||||
return create_tokens_response(
|
||||
user_id=token_data.id,
|
||||
username=token_data.username,
|
||||
role=token_data.role
|
||||
)
|
77
backend/services/db.py
Normal file
77
backend/services/db.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from models.user import Base, User, UserRole
|
||||
from sqlalchemy import select
|
||||
from contextlib import asynccontextmanager
|
||||
from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG
|
||||
from services.user import get_password_hash
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# 全局数据库引擎实例
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
def create_db_engine() -> AsyncEngine:
|
||||
"""创建数据库引擎"""
|
||||
return create_async_engine(
|
||||
f"mysql+asyncmy://{DATABASE_CONFIG['user']}:{DATABASE_CONFIG['password']}@"
|
||||
f"{DATABASE_CONFIG['host']}:{DATABASE_CONFIG['port']}/{DATABASE_CONFIG['database']}",
|
||||
echo=True
|
||||
)
|
||||
|
||||
def get_db_engine() -> AsyncEngine:
|
||||
"""获取全局数据库引擎实例"""
|
||||
if _engine is None:
|
||||
raise RuntimeError("Database engine not initialized")
|
||||
return _engine
|
||||
|
||||
async def get_db_session() -> AsyncSession:
|
||||
"""获取数据库会话"""
|
||||
async with AsyncSession(get_db_engine()) as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
async def get_db_session_dep() -> AsyncSession:
|
||||
"""FastAPI依赖注入使用的数据库会话获取函数"""
|
||||
async with AsyncSession(get_db_engine()) as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
async def init_db(engine: AsyncEngine):
|
||||
"""初始化数据库"""
|
||||
global _engine
|
||||
_engine = engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# 创建所有表
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
# 检查系统管理员是否存在
|
||||
result = await session.execute(
|
||||
select(User).where(User.role == UserRole.SYSTEM_ADMIN)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
# 创建默认系统管理员
|
||||
admin = User(
|
||||
username=SYSTEM_ADMIN_CONFIG['username'],
|
||||
password=get_password_hash(SYSTEM_ADMIN_CONFIG['password']),
|
||||
role=UserRole.SYSTEM_ADMIN,
|
||||
description=SYSTEM_ADMIN_CONFIG['description']
|
||||
)
|
||||
session.add(admin)
|
||||
await session.commit()
|
||||
|
||||
async def close_db_connection():
|
||||
"""关闭数据库连接"""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
await _engine.dispose()
|
||||
_engine = None
|
74
backend/services/user.py
Normal file
74
backend/services/user.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from passlib.context import CryptContext
|
||||
from models.user import User
|
||||
from schemas.user import UserCreate, UserUpdate, UserResponse
|
||||
|
||||
# 创建一个密码上下文对象,指定使用 bcrypt 加密算法
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse:
|
||||
"""创建用户"""
|
||||
hashed_password = pwd_context.hash(user_data.password)
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
password=hashed_password,
|
||||
role=user_data.role,
|
||||
description=user_data.description
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
|
||||
async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse]:
|
||||
"""根据ID获取用户"""
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
return UserResponse.from_orm(user) if user else None
|
||||
|
||||
|
||||
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||
"""获取用户列表"""
|
||||
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||
users = result.scalars().all()
|
||||
return [UserResponse.from_orm(user) for user in users]
|
||||
|
||||
|
||||
async def update_user(session: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserResponse]:
|
||||
"""更新用户信息"""
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(**user_data.dict(exclude_unset=True))
|
||||
)
|
||||
await session.commit()
|
||||
return await get_user(session, user_id)
|
||||
|
||||
|
||||
async def delete_user(session: AsyncSession, user_id: int) -> bool:
|
||||
"""删除用户"""
|
||||
result = await session.execute(delete(User).where(User.id == user_id))
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证输入的明文密码是否与存储的哈希密码匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成使用 bcrypt 的密码哈希"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]:
|
||||
"""验证用户登录"""
|
||||
result = await session.execute(select(User).where(User.username == username))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.password):
|
||||
return None
|
||||
return UserResponse.from_orm(user)
|
Reference in New Issue
Block a user