Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix result source was never awaited if clients disconnect fast #3687

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release fixes the issue that some coroutines in the WebSocket protocol handlers were never awaited if clients disconnected shortly after starting an operation.
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
62 changes: 39 additions & 23 deletions strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,41 +245,42 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = self.connection_params

operation = Operation(
self,
message.id,
operation_type,
message.payload.query,
message.payload.variables,
message.payload.operationName,
)

operation.task = asyncio.create_task(self.run_operation(operation))
self.operations[message.id] = operation

async def run_operation(self, operation: Operation) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()

result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]

# Get an AsyncGenerator yielding the results
if operation_type == OperationType.SUBSCRIPTION:
if operation.operation_type == OperationType.SUBSCRIPTION:
result_source = self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
query=operation.query,
variable_values=operation.variables,
operation_name=operation.operation_name,
context_value=self.context,
root_value=self.root_value,
)
else:
result_source = self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
query=operation.query,
variable_values=operation.variables,
context_value=self.context,
root_value=self.root_value,
operation_name=message.payload.operationName,
operation_name=operation.operation_name,
)

operation = Operation(self, message.id, operation_type)

# Create task to handle this subscription, reserve the operation ID
operation.task = asyncio.create_task(
self.operation_task(result_source, operation)
)
self.operations[message.id] = operation

async def operation_task(
self,
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult],
operation: Operation,
) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()
try:
first_res_or_agen = await result_source
# that's an immediate error we should end the operation
Expand Down Expand Up @@ -340,17 +341,32 @@ async def reap_completed_tasks(self) -> None:
class Operation:
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = ["handler", "id", "operation_type", "completed", "task"]
__slots__ = [
"handler",
"id",
"operation_type",
"query",
"variables",
"operation_name",
"completed",
"task",
]

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
id: str,
operation_type: OperationType,
query: str,
variables: Optional[Dict[str, Any]],
operation_name: Optional[str],
) -> None:
self.handler = handler
self.id = id
self.operation_type = operation_type
self.query = query
self.variables = variables
self.operation_name = operation_name
self.completed = False
self.task: Optional[asyncio.Task] = None

Expand Down
24 changes: 12 additions & 12 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Awaitable,
Dict,
Optional,
cast,
Expand Down Expand Up @@ -37,7 +36,6 @@
if TYPE_CHECKING:
from strawberry.http.async_base_view import AsyncWebSocketAdapter
from strawberry.schema import BaseSchema
from strawberry.schema.subscribe import SubscriptionResult


class BaseGraphQLWSHandler:
Expand Down Expand Up @@ -136,15 +134,9 @@ async def handle_start(self, message: OperationMessage) -> None:
if self.debug:
pretty_print_graphql_operation(operation_name, query, variables)

result_source = self.schema.subscribe(
query=query,
variable_values=variables,
operation_name=operation_name,
context_value=self.context,
root_value=self.root_value,
result_handler = self.handle_async_results(
operation_id, query, operation_name, variables
)

result_handler = self.handle_async_results(result_source, operation_id)
self.tasks[operation_id] = asyncio.create_task(result_handler)

async def handle_stop(self, message: OperationMessage) -> None:
Expand All @@ -160,11 +152,19 @@ async def handle_keep_alive(self) -> None:

async def handle_async_results(
self,
result_source: Awaitable[SubscriptionResult],
operation_id: str,
query: str,
operation_name: Optional[str],
variables: Optional[Dict[str, object]],
) -> None:
try:
agen_or_err = await result_source
agen_or_err = await self.schema.subscribe(
query=query,
variable_values=variables,
operation_name=operation_name,
context_value=self.context,
root_value=self.root_value,
)
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
Expand Down
Loading