From bf856af9f9b685792273502b68999aaf1769a093 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 22 Jan 2025 13:58:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BA=86=E7=A8=8B=E5=BA=8F?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BA=86=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 7 +++++-- routes/auth.py | 8 ++++---- routes/depends.py | 2 +- routes/users.py | 14 +++++++------- services/{auth_service.py => auth.py} | 0 services/db.py | 11 ++++++----- services/{user_services.py => user.py} | 4 +++- 7 files changed, 26 insertions(+), 20 deletions(-) rename services/{auth_service.py => auth.py} (100%) rename services/{user_services.py => user.py} (95%) diff --git a/main.py b/main.py index 0346ec8..6a9e6f6 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,13 @@ +from dotenv import load_dotenv +load_dotenv() + import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from config import JWT_CONFIG, LOGGING_CONFIG, SYSTEM_ADMIN_CONFIG from services.db import create_db_engine, init_db, close_db_connection -from routes.auth import auth_router -from routes.users import users_router +from routes.auth import router as auth_router +from routes.users import router as users_router # 配置日志 logging.basicConfig( diff --git a/routes/auth.py b/routes/auth.py index 52bd654..c4bb51a 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from schemas.auth import TokenResponse, LoginRequest, RefreshTokenRequest -from services.auth_service import create_tokens_response, verify_token, refresh_tokens -from services.user_services import authenticate_user -from services.db import get_db +from services.auth import create_tokens_response, verify_token, refresh_tokens +from services.user import authenticate_user +from services.db import get_db_session router = APIRouter(tags=["auth"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -12,7 +12,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @router.post("/login", response_model=TokenResponse) async def login( login_data: LoginRequest, - session: AsyncSession = Depends(get_db) + session: AsyncSession = Depends(get_db_session) ): user = await authenticate_user(session, login_data.username, login_data.password) if not user: diff --git a/routes/depends.py b/routes/depends.py index bdcb70e..34e0e18 100644 --- a/routes/depends.py +++ b/routes/depends.py @@ -3,7 +3,7 @@ from fastapi.security import OAuth2PasswordBearer from typing import Optional from schemas.auth import TokenPayload from schemas.user import UserRole -from services.auth_service import verify_token +from services.auth import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") diff --git a/routes/users.py b/routes/users.py index 2e863a2..cbed7c9 100644 --- a/routes/users.py +++ b/routes/users.py @@ -3,7 +3,7 @@ from typing import List, Optional from schemas.auth import TokenPayload from schemas.user import UserCreate, UserUpdate, UserResponse, UserRole from routes.depends import get_current_user,get_current_admin -from services.user_services import get_user_by_id,get_users,create_user,update_user,delete_user +from services.user import get_users,create_user,update_user,delete_user#,get_user_by_id from services.db import get_db_session router = APIRouter(tags=["users"]) @@ -15,12 +15,12 @@ async def get_users( role: Optional[str] = None, current_user_token: TokenPayload = Depends(get_current_user) ): - current_user = await get_user_by_id(current_user_token.id) - if current_user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) + # current_user = await get_user_by_id(current_user_token.id) + # if current_user is None: + # raise HTTPException( + # status_code=status.HTTP_404_NOT_FOUND, + # detail="User not found" + # ) async with get_db_session() as session: skip = (page - 1) * limit users = await get_users(session, skip=skip, limit=limit) diff --git a/services/auth_service.py b/services/auth.py similarity index 100% rename from services/auth_service.py rename to services/auth.py diff --git a/services/db.py b/services/db.py index 74777bf..6e3f9f6 100644 --- a/services/db.py +++ b/services/db.py @@ -1,10 +1,10 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker -from ..models.user import Base, User, UserRole +from models.user import Base, User, UserRole from sqlalchemy import select from contextlib import asynccontextmanager from config import SYSTEM_ADMIN_CONFIG, DATABASE_CONFIG -from user_services import get_password_hash +from services.user import get_password_hash from typing import AsyncGenerator # 全局数据库引擎实例 @@ -44,8 +44,9 @@ async def init_db(engine: AsyncEngine): # 创建所有表 await conn.run_sync(Base.metadata.create_all) + async with AsyncSession(engine) as session: # 检查系统管理员是否存在 - result = await conn.execute( + result = await session.execute( select(User).where(User.role == UserRole.SYSTEM_ADMIN) ) if not result.scalars().first(): @@ -56,8 +57,8 @@ async def init_db(engine: AsyncEngine): role=UserRole.SYSTEM_ADMIN, description=SYSTEM_ADMIN_CONFIG['description'] ) - conn.add(admin) - await conn.commit() + session.add(admin) + await session.commit() async def close_db_connection(): """关闭数据库连接""" diff --git a/services/user_services.py b/services/user.py similarity index 95% rename from services/user_services.py rename to services/user.py index 79e1ce5..90c6828 100644 --- a/services/user_services.py +++ b/services/user.py @@ -7,7 +7,6 @@ from schemas.user import UserCreate, UserUpdate, UserResponse # 创建一个密码上下文对象,指定使用 bcrypt 加密算法 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - async def create_user(session: AsyncSession, user_data: UserCreate) -> UserResponse: """创建用户""" @@ -60,6 +59,9 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """验证输入的明文密码是否与存储的哈希密码匹配""" return pwd_context.verify(plain_password, hashed_password) +def get_password_hash(password: str) -> str: + """生成使用 bcrypt 的密码哈希""" + return pwd_context.hash(password) async def authenticate_user(session: AsyncSession, username: str, password: str) -> Optional[UserResponse]: """验证用户登录"""