修正若干bug,后端基本完成
This commit is contained in:
parent
bf856af9f9
commit
1fd8af3be9
@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
|
from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest
|
||||||
from services.auth import create_tokens_response, verify_token, refresh_tokens
|
from services.auth import create_tokens_response, verify_token, refresh_tokens
|
||||||
from services.user import authenticate_user
|
from services.user import authenticate_user
|
||||||
from services.db import get_db_session
|
from services.db import get_db_session_dep
|
||||||
|
|
||||||
router = APIRouter(tags=["auth"])
|
router = APIRouter(tags=["auth"])
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
async def login(
|
async def login(
|
||||||
login_data: LoginRequest,
|
login_data: LoginRequest,
|
||||||
session: AsyncSession = Depends(get_db_session)
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
):
|
):
|
||||||
user = await authenticate_user(session, login_data.username, login_data.password)
|
user = await authenticate_user(session, login_data.username, login_data.password)
|
||||||
if not user:
|
if not user:
|
||||||
|
@ -1,59 +1,55 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from schemas.auth import TokenPayload
|
from schemas.auth import TokenPayload
|
||||||
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole
|
||||||
from routes.depends import get_current_user,get_current_admin
|
from routes.depends import get_current_user,get_current_admin
|
||||||
from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id
|
import services.user as user_service
|
||||||
from services.db import get_db_session
|
|
||||||
|
from services.db import get_db_session_dep
|
||||||
|
|
||||||
router = APIRouter(tags=["users"])
|
router = APIRouter(tags=["users"])
|
||||||
|
|
||||||
@router.get("/", response_model=List[UserResponse])
|
@router.get("/", response_model=List[UserResponse])
|
||||||
async def get_users(
|
async def get_users_list(
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
limit: int = 10,
|
limit: int = 100,
|
||||||
role: Optional[str] = None,
|
role: Optional[str] = None,
|
||||||
current_user_token: TokenPayload = Depends(get_current_user)
|
current_user_token: TokenPayload = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
):
|
):
|
||||||
# current_user = await get_user_by_id(current_user_token.id)
|
skip = (page - 1) * limit
|
||||||
# if current_user is None:
|
users = await user_service.get_users_list(session, skip=skip, limit=limit)
|
||||||
# raise HTTPException(
|
if role:
|
||||||
# status_code=status.HTTP_404_NOT_FOUND,
|
users = [user for user in users if user.role == role]
|
||||||
# detail="User not found"
|
return users
|
||||||
# )
|
|
||||||
async with get_db_session() as session:
|
|
||||||
skip = (page - 1) * limit
|
|
||||||
users = await get_users(session, skip=skip, limit=limit)
|
|
||||||
if role:
|
|
||||||
users = [user for user in users if user.role == role]
|
|
||||||
return users
|
|
||||||
|
|
||||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_user(
|
async def create_user(
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
current_user_token: TokenPayload = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
return await user_service.create_user(session, user_data)
|
||||||
return await create_user(session, user_data)
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=UserResponse)
|
@router.put("/{user_id}", response_model=UserResponse)
|
||||||
async def update_user(
|
async def update_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
user_data: UserUpdate,
|
user_data: UserUpdate,
|
||||||
current_user_token: TokenPayload = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
return await user_service.update_user(session, user_id, user_data)
|
||||||
return await update_user(session, user_id, user_data)
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_user(
|
async def delete_user(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user_token: TokenPayload = Depends(get_current_admin)
|
current_user_token: TokenPayload = Depends(get_current_admin),
|
||||||
|
session: AsyncSession = Depends(get_db_session_dep)
|
||||||
):
|
):
|
||||||
async with get_db_session() as session:
|
success = await user_service.delete_user(session, user_id)
|
||||||
success = await delete_user(session, user_id)
|
if not success:
|
||||||
if not success:
|
raise HTTPException(
|
||||||
raise HTTPException(
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
detail="User not found"
|
||||||
detail="User not found"
|
)
|
||||||
)
|
|
||||||
|
@ -14,9 +14,9 @@ def get_current_time() -> int:
|
|||||||
"""获取当前UTC时间戳"""
|
"""获取当前UTC时间戳"""
|
||||||
return int(time.time())
|
return int(time.time())
|
||||||
|
|
||||||
def create_token(user_id: int, username: str, role: str, expire_delta: int) -> str:
|
def create_token(user_id: int, username: str, role: str, expire_delta) -> str:
|
||||||
"""创建JWT token"""
|
"""创建JWT token"""
|
||||||
expire = get_current_time() + expire_delta
|
expire = get_current_time() + int(expire_delta.total_seconds())
|
||||||
to_encode = {
|
to_encode = {
|
||||||
"id": user_id,
|
"id": user_id,
|
||||||
"username": username,
|
"username": username,
|
||||||
|
@ -24,8 +24,7 @@ def get_db_engine() -> AsyncEngine:
|
|||||||
raise RuntimeError("Database engine not initialized")
|
raise RuntimeError("Database engine not initialized")
|
||||||
return _engine
|
return _engine
|
||||||
|
|
||||||
@asynccontextmanager
|
async def get_db_session() -> AsyncSession:
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
"""获取数据库会话"""
|
"""获取数据库会话"""
|
||||||
async with AsyncSession(get_db_engine()) as session:
|
async with AsyncSession(get_db_engine()) as session:
|
||||||
try:
|
try:
|
||||||
@ -35,6 +34,16 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_db_session_dep() -> AsyncSession:
|
||||||
|
"""FastAPI依赖注入使用的数据库会话获取函数"""
|
||||||
|
async with AsyncSession(get_db_engine()) as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
async def init_db(engine: AsyncEngine):
|
async def init_db(engine: AsyncEngine):
|
||||||
"""初始化数据库"""
|
"""初始化数据库"""
|
||||||
global _engine
|
global _engine
|
||||||
|
@ -30,7 +30,7 @@ async def get_user(session: AsyncSession, user_id: int) -> Optional[UserResponse
|
|||||||
return UserResponse.from_orm(user) if user else None
|
return UserResponse.from_orm(user) if user else None
|
||||||
|
|
||||||
|
|
||||||
async def get_users(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
async def get_users_list(session: AsyncSession, skip: int = 0, limit: int = 100) -> List[UserResponse]:
|
||||||
"""获取用户列表"""
|
"""获取用户列表"""
|
||||||
result = await session.execute(select(User).offset(skip).limit(limit))
|
result = await session.execute(select(User).offset(skip).limit(limit))
|
||||||
users = result.scalars().all()
|
users = result.scalars().all()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user