Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: change rewrite path middleware for queue #188

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions faster_sam/middlewares/queue_path_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from http import HTTPStatus
import logging

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send

logger = logging.getLogger(__name__)


class QueuePathRewriterMiddleware(BaseHTTPMiddleware):
class QueuePathRewriterMiddleware:
"""
Rewrites a specified part of the request path.

Expand All @@ -24,9 +23,9 @@ def __init__(self, app: ASGIApp) -> None:
"""
Initializes the QueuePathRewriterMiddleware.
"""
super().__init__(app, self.dispatch)
self.app = app

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Response:
"""
Rewrites a specified part of the request path.

Expand All @@ -42,8 +41,10 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
Response
The response generated by the middleware.
"""
request = Request(scope, receive=receive)

if request.method != "POST":
return await call_next(request)
return await self.app(scope, receive, send)

body = await request.body()

Expand All @@ -64,4 +65,4 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -

request.scope["path"] = "/" + queue

return await call_next(request)
return await self.app(scope, receive, send)
54 changes: 14 additions & 40 deletions tests/test_middleware_queue_path_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,35 @@
import unittest

from fastapi import FastAPI, Request, Response
from fastapi.testclient import TestClient

from faster_sam.middlewares import queue_path_rewriter


class TestQueuePathRewriterMiddleware(unittest.IsolatedAsyncioTestCase):
async def test_middleware_rewrite_path(self):
async def receive():
return {
"type": "http.request",
"body": b'{"message":{"attributes":{"endpoint":"aws/bar"}}}',
}
class TestQueuePathRewriterMiddleware(unittest.TestCase):
def setUp(self) -> None:
async def queue(request: Request) -> Response:
return Response(content=json.dumps({"path": request.scope["path"]}))

app = FastAPI()
app.add_middleware(queue_path_rewriter.QueuePathRewriterMiddleware)
app.add_route("/queue", queue)

middleware = queue_path_rewriter.QueuePathRewriterMiddleware(app)

async def call_next(request: Request) -> Response:
return Response(content=json.dumps({"path": request.scope["path"]}))

request = Request(scope={"type": "http", "method": "POST", "path": "/foo"}, receive=receive)
self.client = TestClient(app)

response = await middleware.dispatch(request, call_next)
async def test_middleware_rewrite_path(self):
response = self.client.post(
"/queue", json={"message": {"attributes": {"endpoint": "/foo/bar"}}}
)

self.assertEqual(json.loads(response.body), {"path": "/bar"})

async def test_middleware_rewrite_path_with_empty_body(self):
async def receive():
return {
"type": "http.request",
"body": b"",
}

app = FastAPI()

middleware = queue_path_rewriter.QueuePathRewriterMiddleware(app)

async def call_next(request: Request) -> Response:
return Response(content=json.dumps({"path": request.scope["path"]}))

request = Request(scope={"type": "http", "method": "POST", "path": "/foo"}, receive=receive)

response = await middleware.dispatch(request, call_next)
response = self.client.post("/queue", json={})

self.assertEqual(json.loads(response.body), {"message": "Invalid Request"})

async def test_middleware_rewrite_path_without_post(self):
app = FastAPI()

middleware = queue_path_rewriter.QueuePathRewriterMiddleware(app)

async def call_next(request: Request) -> Response:
return Response(content=json.dumps({"path": request.scope["path"]}))

request = Request(scope={"type": "http", "method": "GET", "path": "/foo"})

response = await middleware.dispatch(request, call_next)
response = self.client.get("/queue")

self.assertEqual(json.loads(response.body), {"path": "/foo"})