Skip to content

Commit

Permalink
fix: cors middleware mirrors origin in case no initial cookie is present
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianBr committed Sep 2, 2024
1 parent 8e1fc9b commit 4a49a51
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions starlette/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ async def send(self, message: Message, send: Send, request_headers: Headers) ->
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
origin = request_headers["Origin"]
has_cookie = "cookie" in request_headers
has_cookie = "cookie" in request_headers or "set-cookie" in headers

# If request includes any cookie headers, then we must respond
# If request or response includes any cookie headers, then we must respond
# with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie:
self.allow_explicit_origin(headers, origin)
Expand Down
23 changes: 23 additions & 0 deletions tests/middleware/test_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,29 @@ def homepage(request: Request) -> PlainTextResponse:
assert "access-control-allow-credentials" not in response.headers


def test_cors_credentialed_requests_return_specific_origin_without_initial_cookie(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
response = PlainTextResponse("Homepage", status_code=200)
response.set_cookie("mycookie", "myvalue", path=None)
return response

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=["*"])],
)
client = test_client_factory(app)

# Test credentialed request
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" in response.headers


def test_cors_vary_header_defaults_to_origin(
test_client_factory: TestClientFactory,
) -> None:
Expand Down

0 comments on commit 4a49a51

Please sign in to comment.