diff --git a/backend/controller/common.py b/backend/controller/common.py index 0c837db..f58f775 100644 --- a/backend/controller/common.py +++ b/backend/controller/common.py @@ -1,14 +1,19 @@ import asyncio +from fastapi import APIRouter from starlette.websockets import WebSocket +from websockets.exceptions import WebSocketException from core.security import generate_token, verify_password from core.utils import get_system_info from dbhelper.user import get_user from schemas import LoginForm, LoginResult, Response +router = APIRouter(tags=["公共"]) -async def login(auth_data: LoginForm) -> Response[LoginResult]: + +@router.post("/login", summary="登录", response_model=Response[LoginResult]) +async def login(auth_data: LoginForm): user_obj = await get_user({"username": auth_data.username, "status__not": 9}) if user_obj: if verify_password(auth_data.password, user_obj.password): @@ -20,12 +25,12 @@ async def login(auth_data: LoginForm) -> Response[LoginResult]: return Response(code=400, msg="账号或密码错误") +@router.websocket("/ws", name="系统信息") async def websocket(ws: WebSocket): await ws.accept() try: while True: await asyncio.sleep(1) await ws.send_json(get_system_info()) - except Exception as e: - print("断开了链接", e) + except WebSocketException: await ws.close() diff --git a/backend/controller/menu.py b/backend/controller/menu.py index bb9fc4a..714cbe9 100644 --- a/backend/controller/menu.py +++ b/backend/controller/menu.py @@ -1,13 +1,20 @@ +# router service db router+service db +from fastapi import APIRouter + from core.utils import list_to_tree from dbhelper.menu import del_menu, get_menu, get_tree_menu, insert_menu, put_menu from schemas import MenuIn, MenuRead, Response +router = APIRouter(prefix="/menu", tags=["菜单管理"]) -async def menu_add(data: MenuIn) -> Response[MenuRead]: + +@router.post("", summary="菜单新增", response_model=Response[MenuRead]) +async def menu_add(data: MenuIn): return Response(data=await insert_menu(data)) -async def menu_arr() -> Response: +@router.get("", summary="菜单列表", response_model=Response) +async def menu_arr(): menus = await get_tree_menu() try: data = list_to_tree(menus) @@ -16,7 +23,8 @@ async def menu_arr() -> Response: return Response(data=data) -async def menu_del(pk: int) -> Response: +@router.delete("/{pk}", summary="菜单删除", response_model=Response) +async def menu_del(pk: int): if await get_menu({"pid": pk}) is not None: return Response(code=400, msg="请先删除子节点") if await del_menu(pk) == 0: @@ -24,7 +32,8 @@ async def menu_del(pk: int) -> Response: return Response() -async def menu_put(pk: int, data: MenuIn) -> Response: +@router.put("/{pk}", summary="菜单更新", response_model=Response) +async def menu_put(pk: int, data: MenuIn): """更新菜单""" if await put_menu(pk, data) == 0: return Response(code=400, msg="菜单不存在") diff --git a/backend/controller/role.py b/backend/controller/role.py index dc68144..ec25f84 100644 --- a/backend/controller/role.py +++ b/backend/controller/role.py @@ -1,6 +1,4 @@ -import json - -from fastapi import Query +from fastapi import APIRouter, Query from core.utils import list_to_tree from dbhelper.menu import get_menu @@ -14,13 +12,17 @@ from dbhelper.role import ( ) from schemas import ListAll, Response, RoleIn, RoleInfo, RoleQuery, RoleRead +router = APIRouter(prefix="/role", tags=["角色管理"]) -async def role_add(data: RoleIn) -> Response[RoleInfo]: + +@router.post("", summary="角色新增", response_model=Response[RoleInfo]) +async def role_add(data: RoleIn): if result := await new_role(data): return Response(data=result) return Response(code=400, msg="菜单不存在") +@router.get("/{rid}/menu", summary="查询角色拥有权限", response_model=Response) async def role_has_menu(rid: int): """ rid: 角色ID @@ -34,24 +36,26 @@ async def role_has_menu(rid: int): return Response(data=result) +@router.get("", summary="角色列表", response_model=Response[ListAll[list[RoleRead]]]) async def role_arr( offset: int = Query(default=1, description="偏移量-页码"), limit: int = Query(default=10, description="数据量"), -) -> Response[ListAll[list[RoleRead]]]: +): skip = (offset - 1) * limit roles, count = await get_roles(skip, limit) return Response(data=ListAll(total=count, items=roles)) -async def role_del(pk: int) -> Response: +@router.delete("/{pk}", summary="角色删除", response_model=Response) +async def role_del(pk: int): if await del_role(pk) == 0: return Response(code=400, msg="角色不存在") return Response() -async def role_put(pk: int, data: RoleIn) -> Response: +@router.put("/{pk}", summary="角色更新", response_model=Response) +async def role_put(pk: int, data: RoleIn): """更新角色""" - print(await get_role({"id": pk})) if await get_role({"id": pk}) is None: return Response(code=400, msg="角色不存在") @@ -64,7 +68,8 @@ async def role_put(pk: int, data: RoleIn) -> Response: return Response() -async def role_query(query: RoleQuery) -> Response[ListAll[list[RoleRead]]]: +@router.post("/query", summary="角色查询", response_model=Response[ListAll[list[RoleRead]]]) +async def role_query(query: RoleQuery): """post条件查询角色表""" size = query.limit skip = (query.offset - 1) * size diff --git a/backend/controller/user.py b/backend/controller/user.py index 5f0ca22..f425f4e 100644 --- a/backend/controller/user.py +++ b/backend/controller/user.py @@ -1,7 +1,6 @@ -from fastapi import Depends, Query -from starlette.requests import Request +from fastapi import APIRouter, Depends, Query -from core.security import check_token, get_password_hash +from core.security import check_permissions, get_password_hash from dbhelper.user import ( del_user, get_user, @@ -14,8 +13,11 @@ from dbhelper.user import ( from schemas import Response, UserAdd, UserInfo, UserPut, UserQuery, UserRead from schemas.common import ListAll +router = APIRouter(prefix="/user", tags=["用户管理"]) -async def user_add(data: UserAdd) -> Response[UserRead]: + +@router.post("", summary="用户新增", response_model=Response[UserRead]) +async def user_add(data: UserAdd): """新增用户并分配角色 一步到位""" if await get_user({"username": data.username}) is not None: return Response(code=400, msg="用户名已存在") @@ -28,7 +30,8 @@ async def user_add(data: UserAdd) -> Response[UserRead]: return Response(data=result) -async def user_info(pk: int) -> Response[UserInfo]: +@router.get("/{pk}", summary="用户信息", response_model=Response[UserInfo]) +async def user_info(pk: int): """获取用户信息""" obj = await get_user({"id": pk}) if obj is None: @@ -36,17 +39,19 @@ async def user_info(pk: int) -> Response[UserInfo]: return Response(data=await get_user_info(obj)) +@router.get("", summary="用户列表", response_model=Response[ListAll[list[UserRead]]]) async def user_arr( offset: int = Query(default=1, description="偏移量-页码"), limit: int = Query(default=10, description="数据量"), -) -> Response[ListAll[list[UserRead]]]: +): """分页列表数据""" skip = (offset - 1) * limit users, count = await get_users(skip, limit) return Response(data=ListAll(total=count, items=users)) -async def user_list(query: UserQuery) -> Response[ListAll[list[UserRead]]]: +@router.post("/query", summary="用户查询", response_model=Response[ListAll[list[UserRead]]]) +async def user_list(query: UserQuery): """post查询用户列表""" size = query.limit skip = (query.offset - 1) * size @@ -55,14 +60,16 @@ async def user_list(query: UserQuery) -> Response[ListAll[list[UserRead]]]: return Response(data=ListAll(total=count, items=users)) -async def user_del(pk: int) -> Response: +@router.delete("/{pk}", summary="用户删除", response_model=Response) +async def user_del(pk: int): """删除用户""" if await del_user(pk) == 0: return Response(code=400, msg="用户不存在") return Response() -async def user_put(pk: int, data: UserPut) -> Response: +@router.put("/{pk}", summary="用户更新", response_model=Response) +async def user_put(pk: int, data: UserPut): """更新用户""" if await get_user({"id": pk}) is None: return Response(code=400, msg="用户不存在") @@ -73,7 +80,8 @@ async def user_put(pk: int, data: UserPut) -> Response: return Response() -async def user_select_role(rid: int, user=Depends(check_token)): +@router.put("/role/{rid}", summary="用户切换角色", response_model=Response) +async def user_select_role(rid: int, user=Depends(check_permissions)): """用户切换角色""" res = await select_role(user.id, rid) if res == 0: diff --git a/backend/core/dbhelper.py b/backend/core/dbhelper.py new file mode 100644 index 0000000..cd768f1 --- /dev/null +++ b/backend/core/dbhelper.py @@ -0,0 +1,103 @@ +"""数据库通用查询方法""" +from typing import Optional + +from tortoise import connections, models + + +class DbHelper: + def __init__(self, model: models.Model): + """ + 初始化 + :param model: 模型类 + """ + self.model = model + + def __filter(self, kwargs: dict): + """ + 过滤数据,默认过滤数据 + :param kwargs: + :return: + """ + return self.model.filter(**kwargs) + + async def select(self, kwargs: dict = None) -> Optional[models.Model]: + """ + 查询符合条件的第一个对象, 查无结果时返回None + :param kwargs: kwargs: {"name:"7y", "id": 1} + :return: select * from model where name = "7y" and id = 1 limit 1 + """ + if kwargs is None: + kwargs = {} + return await self.__filter(kwargs).first() + + async def update(self, filters: dict = None, updates: dict = None): + """ + 更新单条数据 + :param filters: 条件字典 {"id":1,"status__not": 9} + :param updates: 待更新数据 {"status": 5} + :return: 0 失败, 1 成功 + """ + return await self.__filter(filters).update(**updates) + + async def delete(self, pk: int) -> int: + """ + 逻辑删除单条数据, status -> 9 + :param pk: 数据id + :return: 0 是删除 失败, 1是删除成功 + """ + filters = {"id": pk} + updates = dict(status=9) + return await self.update(filters=filters, updates=updates) + + async def insert(self, data: dict): + """ + 新增一条数据 + :param data: 模型字典 + :return: 新增之后的对象 + """ + return await self.model.create(**data) + + async def selects( + self, offset: int, limit: int, kwargs: dict = None, order_by: str = None + ) -> dict: + """ + 条件分页查询数据列表, 支持排序 + Args: + offset: 偏移量 + limit: 数量 + kwargs: 条件 {} + order_by: 排序,默认为None, 传入 -字段名 降序 字段名升序 + SQL => select * from model where xx=xx ... order by xx limit offset, limit + Returns: + {"items": Model列表, "total": "数量"} + """ + if kwargs is None: + kwargs = {} + objs = self.__filter(kwargs).all() + if order_by is not None: + objs = objs.order_by(order_by) + + return dict( + items=await objs.offset(offset).limit(limit), total=await objs.count() + ) + + async def inserts(self, objs: list[models.Model]): + """ + 批量新增数据 + :param objs: 模型列表 + :return: + """ + await self.model.bulk_create(objs) + + @classmethod + async def raw_sql(cls, sql: str, args: list = None): + """ + 手动执行SQL + :param sql: + :param args: sql参数 + :return: + """ + db = connections.get("default") + if args is None: + args = [] + return await db.execute_query_dict(sql, args) diff --git a/backend/core/security.py b/backend/core/security.py index 0237c06..f6f0d1b 100644 --- a/backend/core/security.py +++ b/backend/core/security.py @@ -71,9 +71,11 @@ async def check_permissions(request: Request, user: UserModel = Depends(check_to active_rid = result["roles"][0]["id"] # 白名单 登录用户信息, 登录用户菜单信息 - whitelist = [f"/user/{user.id}", f"/role/{active_rid}/menu"] - flag = request.url.path in whitelist and request.method == "GET" - if flag: + whitelist = [(f"/user/{user.id}", "GET"), (f"/role/{active_rid}/menu", "GET")] + [ + (f"/user/role/{rid['id']}", "PUT") for rid in result["roles"] + ] + + if (request.url.path, request.method) in whitelist: return user api = request.url.path diff --git a/backend/core/utils.py b/backend/core/utils.py index 912b016..4e40b9e 100644 --- a/backend/core/utils.py +++ b/backend/core/utils.py @@ -1,5 +1,10 @@ +import importlib +import inspect +import os import random +from core.log import logger + def list_to_tree( menus, parent_flag: str = "pid", children_key: str = "children" @@ -43,3 +48,47 @@ def get_system_info(): "user": f"{random.randint(1, 50)}", }, } + + +def load_routers( + app, package_path: str = "router", router_name: str = "router", is_init=False +): + """ + 自动注册路由 + :param app: FastAPI 实例对象 或者 APIRouter对象 + :param package_path: 路由包所在路径,默认相对路径router包 + :param router_name: APIRouter实例名称,需所有实例统一,默认router + :param is_init: 是否在包中的__init__.py中导入了所有APIRouter实例,默认否 + :return: 默认None + """ + + def __register(module_obj): + """注册路由,module_obj: 模块对象""" + if hasattr(module_obj, router_name): + router_obj = getattr(module_obj, router_name) + app.include_router(router_obj) + + logger.info("开始扫描路由。") + if is_init: + # 1. init 导入了其他自文件包时 + for _, module in inspect.getmembers( + importlib.import_module(package_path), inspect.ismodule + ): + __register(module) + + else: + # 2. 排除init文件时 的情况 + for _, _, files in os.walk(package_path): + for file in files: + if file.endswith(".py") and file != "__init__.py": + module = importlib.import_module(f"{package_path}.{file[:-3]}") + __register(module) + + for route in app.routes: + try: + logger.debug( + f"{route.path}, {route.methods}, {route.__dict__.get('summary')}" + ) + except AttributeError as e: + logger.error(e) + logger.info("👌路由注册完成✅。") diff --git a/backend/main.py b/backend/main.py index 6137403..c6a6d0a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,25 +4,18 @@ from core.events import close_orm, init_orm from core.exceptions import exception_handlers from core.log import logger from core.middleware import middlewares -from router.url import routes +from core.utils import load_routers app = FastAPI( on_startup=[init_orm], on_shutdown=[close_orm], - routes=routes, middleware=middlewares, exception_handlers=exception_handlers, ) +load_routers(app, "controller") if __name__ == "__main__": import uvicorn - from fastapi.routing import APIWebSocketRoute - - for i in app.routes: - if not isinstance(i, APIWebSocketRoute): - logger.info( - f"{i.path}, {i.methods}, {i.__dict__.get('summary')}, {i.endpoint}" - ) uvicorn.run("main:app", reload=True) diff --git a/backend/requirements.txt b/backend/requirements.txt index bf8b905..8e3034a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,13 +1,9 @@ bcrypt==4.0.0 -fastapi==0.82.0 +fastapi==0.85.0 passlib==1.7.4 pytest==7.1.3 python-jose==3.3.0 requests==2.28.1 -uvicorn==0.18.3 tortoise-orm==0.19.2 +uvicorn==0.18.3 websockets==10.3 - - - - diff --git a/backend/router/__init__.py b/backend/router/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/router/url.py b/backend/router/url.py deleted file mode 100644 index 77d92a1..0000000 --- a/backend/router/url.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import Any, Callable, get_type_hints - -from fastapi import Depends, routing - -from controller.common import login, websocket -from controller.menu import menu_add, menu_arr, menu_del, menu_put -from controller.role import ( - role_add, - role_arr, - role_del, - role_has_menu, - role_put, - role_query, -) -from controller.user import ( - user_add, - user_arr, - user_del, - user_info, - user_list, - user_put, - user_select_role, -) -from core.security import check_permissions - - -class Route(routing.APIRoute): - """ - https://github.com/tiangolo/fastapi/issues/620 - Django挂载视图方法 - def index() -> User: - pass - Route("/", endpoint=index) - """ - - def __init__( - self, - path: str, - endpoint: Callable[..., Any], - tags: list[str], - summary: str, - **kwargs: Any - ): - if kwargs.get("response_model") is None: - kwargs["response_model"] = get_type_hints(endpoint).get("return") - super(Route, self).__init__( - path=path, endpoint=endpoint, tags=tags, summary=summary, **kwargs - ) - - @classmethod - def post( - cls, - path: str, - endpoint: Callable[..., Any], - tags: list[str], - summary: str, - **kwargs: Any - ): - return Route( - path=path, - endpoint=endpoint, - methods=["POST"], - tags=tags, - summary=summary, - **kwargs - ) - - @classmethod - def get( - cls, - path: str, - endpoint: Callable[..., Any], - tags: list[str], - summary: str, - **kwargs: Any - ): - return Route( - path=path, - endpoint=endpoint, - methods=["GET"], - tags=tags, - summary=summary, - **kwargs - ) - - @classmethod - def delete( - cls, - path: str, - endpoint: Callable[..., Any], - tags: list[str], - summary: str, - **kwargs: Any - ): - return Route( - path=path, - endpoint=endpoint, - methods=["DELETE"], - tags=tags, - summary=summary, - **kwargs - ) - - @classmethod - def put( - cls, - path: str, - endpoint: Callable[..., Any], - tags: list[str], - summary: str, - **kwargs: Any - ): - return Route( - path=path, - endpoint=endpoint, - methods=["PUT"], - tags=tags, - summary=summary, - **kwargs - ) - - -has_perm = {"dependencies": [Depends(check_permissions)]} - -routes = [ - Route.post("/login", endpoint=login, tags=["公共"], summary="登录"), - # 用户管理 - Route.get("/user", endpoint=user_arr, tags=["用户管理"], summary="用户列表", **has_perm), - Route.post("/user", endpoint=user_add, tags=["用户管理"], summary="用户新增", **has_perm), - Route.delete( - "/user/{pk}", endpoint=user_del, tags=["用户管理"], summary="用户删除", **has_perm - ), - Route.put( - "/user/{pk}", endpoint=user_put, tags=["用户管理"], summary="用户更新", **has_perm - ), - Route.get( - "/user/{pk}", endpoint=user_info, tags=["用户管理"], summary="用户信息", **has_perm - ), - Route.post( - "/user/query", endpoint=user_list, tags=["用户管理"], summary="用户列表查询", **has_perm - ), - Route.put( - "/user/role/{rid}", endpoint=user_select_role, tags=["用户管理"], summary="用户切换角色" - ), - # 角色管理, - Route.get("/role", endpoint=role_arr, tags=["角色管理"], summary="角色列表", **has_perm), - Route.post("/role", endpoint=role_add, tags=["角色管理"], summary="角色新增", **has_perm), - Route.delete( - "/role/{pk}", endpoint=role_del, tags=["角色管理"], summary="角色删除", **has_perm - ), - Route.get( - "/role/{rid}/menu", - endpoint=role_has_menu, - tags=["角色管理"], - summary="查询角色拥有权限", - **has_perm - ), - Route.put( - "/role/{pk}", endpoint=role_put, tags=["角色管理"], summary="角色更新", **has_perm - ), - Route.post( - "/role/query", endpoint=role_query, tags=["角色管理"], summary="角色条件查询", **has_perm - ), - # 菜单新增 - Route.get("/menu", endpoint=menu_arr, tags=["菜单管理"], summary="菜单列表", **has_perm), - Route.post("/menu", endpoint=menu_add, tags=["菜单管理"], summary="菜单新增", **has_perm), - Route.delete( - "/menu/{pk}", endpoint=menu_del, tags=["菜单管理"], summary="菜单删除", **has_perm - ), - Route.put( - "/menu/{pk}", endpoint=menu_put, tags=["菜单管理"], summary="菜单更新", **has_perm - ), - routing.APIWebSocketRoute("/ws", endpoint=websocket), -] - -__all__ = [routes]