修正若干bug,后端基本完成

This commit is contained in:
carry 2025-01-22 15:02:07 +08:00
parent bf856af9f9
commit 1fd8af3be9
5 changed files with 43 additions and 38 deletions

View File

@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
from services.auth import create_tokens_response, verify_token, refresh_tokens
from services.user import authenticate_user
from services.db import get_db_session
from services.db import get_db_session_dep
router = APIRouter(tags=["auth"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@router.post("/login", response_model=TokenResponse)
async def login(
login_data: LoginRequest,
session: AsyncSession = Depends(get_db_session)
session: AsyncSession = Depends(get_db_session_dep)
):
user = await authenticate_user(session, login_data.username, login_data.password)
if not user:

View File

@ -1,59 +1,55 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from schemas.auth import TokenPayload
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
from routes.depends import get_current_user,get_current_admin
from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id
from services.db import get_db_session
import services.user as user_service
from services.db import get_db_session_dep
router = APIRouter(tags=["users"])
@router.get("/", response_model=List[UserResponse])
async def get_users(
async def get_users_list(
page: int = 1,
limit: int = 10,
limit: int = 100,
role: Optional[str] = None,
current_user_token: TokenPayload = Depends(get_current_user)
current_user_token: TokenPayload = Depends(get_current_user),
session: AsyncSession = Depends(get_db_session_dep)
):
# current_user = await get_user_by_id(current_user_token.id)
# if current_user is None:
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND,
# detail="User not found"
# )
async with get_db_session() as session:
skip = (page - 1) * limit
users = await get_users(session, skip=skip, limit=limit)
if role:
users = [user for user in users if user.role == role]
return users
skip = (page - 1) * limit
users = await user_service.get_users_list(session, skip=skip, limit=limit)
if role:
users = [user for user in users if user.role == role]
return users
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
user_data: UserCreate,
current_user_token: TokenPayload = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin),
session: AsyncSession = Depends(get_db_session_dep)
):
async with get_db_session() as session:
return await create_user(session, user_data)
return await user_service.create_user(session, user_data)
@router.put("/{user_id}", response_model=UserResponse)
async def update_user(
user_id: int,
user_data: UserUpdate,
current_user_token: TokenPayload = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin),
session: AsyncSession = Depends(get_db_session_dep)
):
async with get_db_session() as session:
return await update_user(session, user_id, user_data)
return await user_service.update_user(session, user_id, user_data)
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
user_id: int,
current_user_token: TokenPayload = Depends(get_current_admin)
current_user_token: TokenPayload = Depends(get_current_admin),
session: AsyncSession = Depends(get_db_session_dep)
):
async with get_db_session() as session:
success = await delete_user(session, user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
success = await user_service.delete_user(session, user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)

View File

@ -14,9 +14,9 @@ def get_current_time() -> int:
"""获取当前UTC时间戳"""
return int(time.time())
def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str:
def create_token(user_id: int, username: str, role: str, expire_delta) -> str:
"""创建JWT token"""
expire = get_current_time() + expire_delta
expire = get_current_time() + int(expire_delta.total_seconds())
to_encode = {
"id": user_id,
"username": username,

View File

@ -24,8 +24,7 @@ def get_db_engine() -> AsyncEngine:
raise RuntimeError("Database engine not initialized")
return _engine
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
async def get_db_session() -> AsyncSession:
"""获取数据库会话"""
async with AsyncSession(get_db_engine()) as session:
try:
@ -35,6 +34,16 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
await session.rollback()
raise
async def get_db_session_dep() -> AsyncSession:
"""FastAPI依赖注入使用的数据库会话获取函数"""
async with AsyncSession(get_db_engine()) as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
async def init_db(engine: AsyncEngine):
"""初始化数据库"""
global _engine

View File

@ -30,7 +30,7 @@ async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse
return UserResponse.from_orm(user) if user else None
async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
"""获取用户列表"""
result = await session.execute(select(User).offset(skip).limit(limit))
users = result.scalars().all()