From a7aa9e056213f85ffcb2c6ffb4c2ea3d02147f36 Mon Sep 17 00:00:00 2001 From: pyoor Date: Fri, 21 Apr 2023 09:15:04 -0400 Subject: [PATCH] wip: initial attempt at adding webtransport to grizzly --- .codecov.yml | 2 + grizzly/conftest.py | 1 + grizzly/main.py | 8 + grizzly/reduce/core.py | 14 + grizzly/replay/replay.py | 13 + grizzly/replay/test_main.py | 7 + grizzly/services/__init__.py | 11 + grizzly/services/base.py | 30 ++ grizzly/services/core.py | 96 ++++ grizzly/services/test_services.py | 39 ++ grizzly/services/webtransport/__init__.py | 0 grizzly/services/webtransport/core.py | 117 +++++ .../webtransport/test_webtransport.py | 37 ++ .../webtransport/wpt_h3_server/LICENSE.md | 11 + .../webtransport/wpt_h3_server/__init__.py | 0 .../webtransport/wpt_h3_server/capsule.py | 114 ++++ .../wpt_h3_server/handlers/__init__.py | 0 .../handlers/abort_stream_from_server.py | 12 + .../wpt_h3_server/handlers/custom_response.py | 14 + .../wpt_h3_server/handlers/echo.py | 32 ++ .../handlers/echo_request_headers.py | 12 + .../wpt_h3_server/handlers/server_close.py | 17 + .../handlers/server_connection_close.py | 9 + .../wpt_h3_server/webtransport_h3_server.py | 493 ++++++++++++++++++ grizzly/session.py | 4 +- grizzly/test_session.py | 3 + pyproject.toml | 1 + sapphire/__init__.py | 3 +- setup.cfg | 2 + tox.ini | 1 + 30 files changed, 1101 insertions(+), 2 deletions(-) create mode 100644 grizzly/services/__init__.py create mode 100644 grizzly/services/base.py create mode 100644 grizzly/services/core.py create mode 100644 grizzly/services/test_services.py create mode 100644 grizzly/services/webtransport/__init__.py create mode 100644 grizzly/services/webtransport/core.py create mode 100644 grizzly/services/webtransport/test_webtransport.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/LICENSE.md create mode 100644 grizzly/services/webtransport/wpt_h3_server/__init__.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/capsule.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/__init__.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/abort_stream_from_server.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/custom_response.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/echo.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/echo_request_headers.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/server_close.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/handlers/server_connection_close.py create mode 100644 grizzly/services/webtransport/wpt_h3_server/webtransport_h3_server.py diff --git a/.codecov.yml b/.codecov.yml index 9e3d0461..f2ab8d37 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,3 +1,5 @@ +ignore: + - "grizzly/services/webtransport/wpt_h3_server" codecov: ci: - community-tc.services.mozilla.com diff --git a/grizzly/conftest.py b/grizzly/conftest.py index b699d5fd..30642f13 100644 --- a/grizzly/conftest.py +++ b/grizzly/conftest.py @@ -14,6 +14,7 @@ def session_setup(mocker): mocker.patch("grizzly.main.FuzzManagerReporter", autospec=True) mocker.patch("grizzly.main.Sapphire", autospec_set=True) + mocker.patch("grizzly.main.WebServices", autospec_set=True) adapter_cls = mocker.Mock(spec_set=Adapter) adapter_cls.return_value.RELAUNCH = Adapter.RELAUNCH adapter_cls.return_value.TIME_LIMIT = Adapter.TIME_LIMIT diff --git a/grizzly/main.py b/grizzly/main.py index dc88ccd7..7e547e1c 100644 --- a/grizzly/main.py +++ b/grizzly/main.py @@ -23,6 +23,7 @@ package_version, time_limits, ) +from .services import WebServices from .session import LogRate, Session from .target import Target, TargetLaunchError, TargetLaunchTimeout @@ -52,6 +53,7 @@ def main(args: Namespace) -> int: adapter: Optional[Adapter] = None certs: Optional[CertificateBundle] = None complete_with_results = False + ext_services = None target: Optional[Target] = None try: LOG.debug("initializing Adapter %r", args.adapter) @@ -121,6 +123,9 @@ def main(args: Namespace) -> int: # launch http server used to serve test cases LOG.debug("starting Sapphire server") with Sapphire(auto_close=1, timeout=timeout, certs=certs) as server: + if certs is not None: + ext_services = WebServices.start_services(certs.host, certs.key) + target.reverse(server.port, server.port) LOG.debug("initializing the Session") with Session( @@ -149,6 +154,7 @@ def main(args: Namespace) -> int: log_rate=log_rate, launch_attempts=args.launch_attempts, post_launch_delay=args.post_launch_delay, + services=ext_services, ) complete_with_results = session.status.results.total > 0 @@ -169,6 +175,8 @@ def main(args: Namespace) -> int: if adapter is not None: LOG.debug("calling adapter.cleanup()") adapter.cleanup() + if ext_services is not None: + ext_services.cleanup() if certs is not None: certs.cleanup() LOG.info("Done.") diff --git a/grizzly/reduce/core.py b/grizzly/reduce/core.py index 50b74235..320ba0a0 100644 --- a/grizzly/reduce/core.py +++ b/grizzly/reduce/core.py @@ -37,6 +37,7 @@ time_limits, ) from ..replay import ReplayManager, ReplayResult +from ..services import WebServices from ..target import AssetManager, Target, TargetLaunchError, TargetLaunchTimeout from .exceptions import GrizzlyReduceBaseException, NotReproducible from .strategies import STRATEGIES @@ -90,6 +91,7 @@ def __init__( relaunch: int = 1, report_period: Optional[int] = None, report_to_fuzzmanager: bool = False, + services=None, signature: Optional[CrashSignature] = None, signature_desc: Optional[str] = None, static_timeout: bool = False, @@ -115,6 +117,7 @@ def __init__( Target should be relaunched. report_period: Periodically report best results for long-running strategies. report_to_fuzzmanager: Report to FuzzManager rather than filesystem. + services (WebServices): WebServices instance. signature: Signature for accepting crashes. signature_desc: Short description of the given signature. static_timeout: Use only specified timeouts (`--timeout` and @@ -153,6 +156,7 @@ def __init__( ) self._use_analysis = use_analysis self._use_harness = use_harness + self._services = services def __enter__(self) -> "ReduceManager": return self @@ -322,6 +326,7 @@ def run_reliability_analysis(self) -> Tuple[int, int]: idle_delay=self._idle_delay, idle_threshold=self._idle_threshold, on_iteration_cb=self._on_replay_iteration, + services=self._services, ) try: crashes = sum(x.count for x in results if x.expected) @@ -525,6 +530,7 @@ def run( repeat=repeat, on_iteration_cb=self._on_replay_iteration, post_launch_delay=post_launch_delay, + services=self._services, ) self._status.attempts += 1 self.update_timeout(results) @@ -777,6 +783,7 @@ def main(cls, args: Namespace) -> int: asset_mgr: Optional[AssetManager] = None certs = None + ext_services = None signature = None signature_desc = None target: Optional[Target] = None @@ -846,6 +853,10 @@ def main(cls, args: Namespace) -> int: LOG.debug("starting sapphire server") # launch HTTP server used to serve test cases with Sapphire(auto_close=1, timeout=timeout, certs=certs) as server: + if certs is not None: + LOG.debug("starting additional web services") + ext_services = WebServices.start_services(certs.host, certs.key) + target.reverse(server.port, server.port) with ReduceManager( set(args.ignore), @@ -868,6 +879,7 @@ def main(cls, args: Namespace) -> int: tool=args.tool, use_analysis=not args.no_analysis, use_harness=not args.no_harness, + services=ext_services, ) as mgr: return_code = mgr.run( repeat=args.repeat, @@ -910,4 +922,6 @@ def main(cls, args: Namespace) -> int: asset_mgr.cleanup() if certs is not None: certs.cleanup() + if ext_services is not None: + ext_services.cleanup() LOG.info("Done.") diff --git a/grizzly/replay/replay.py b/grizzly/replay/replay.py index 6cbb9d51..7bfca5f1 100644 --- a/grizzly/replay/replay.py +++ b/grizzly/replay/replay.py @@ -33,6 +33,7 @@ package_version, time_limits, ) +from ..services import WebServices from ..target import ( AssetManager, Result, @@ -294,6 +295,7 @@ def run( launch_attempts: int = 3, on_iteration_cb: Optional[Callable[[], None]] = None, post_launch_delay: int = -1, + services=None, ) -> List[ReplayResult]: """Run testcase replay. @@ -349,6 +351,9 @@ def harness_fn(_: str) -> bytes: # pragma: no cover ) server_map.set_redirect("grz_start", "grz_harness", required=False) + if services: + services.map_locations(server_map) + # track unprocessed results reports: Dict[str, ReplayResult] = {} try: @@ -629,6 +634,7 @@ def main(cls, args: Namespace) -> int: certs = None results: Optional[List[ReplayResult]] = None target: Optional[Target] = None + ext_services = None try: # check if hangs are expected expect_hang = cls.expect_hang(args.ignore, signature, testcases) @@ -682,6 +688,10 @@ def main(cls, args: Namespace) -> int: LOG.debug("starting sapphire server") # launch HTTP server used to serve test cases with Sapphire(auto_close=1, timeout=timeout, certs=certs) as server: + if certs is not None: + LOG.debug("starting additional web services") + ext_services = WebServices.start_services(certs.host, certs.key) + target.reverse(server.port, server.port) with cls( set(args.ignore), @@ -702,6 +712,7 @@ def main(cls, args: Namespace) -> int: min_results=args.min_crashes, post_launch_delay=args.post_launch_delay, repeat=repeat, + services=ext_services, ) # handle results success = any(x.expected for x in results) @@ -754,4 +765,6 @@ def main(cls, args: Namespace) -> int: asset_mgr.cleanup() if certs is not None: certs.cleanup() + if ext_services is not None: + ext_services.cleanup() LOG.info("Done.") diff --git a/grizzly/replay/test_main.py b/grizzly/replay/test_main.py index 1fe894b2..dc2c2d9a 100644 --- a/grizzly/replay/test_main.py +++ b/grizzly/replay/test_main.py @@ -36,6 +36,7 @@ def test_main_01(mocker, server, tmp_path): # Of the four attempts only the first and third will 'reproduce' the result # and the forth attempt should be skipped. mocker.patch("grizzly.common.runner.sleep", autospec=True) + mocker.patch("grizzly.replay.replay.WebServices", autospec=True) server.serve_path.return_value = (Served.ALL, {"test.html": "/fake/path"}) # setup Target load_target = mocker.patch("grizzly.replay.replay.load_plugin", autospec=True) @@ -151,6 +152,7 @@ def test_main_02(mocker, server, tmp_path, repro_results): test_index=[], time_limit=10, timeout=None, + use_https=False, valgrind=False, ) assert ReplayManager.main(args) == Exit.FAILURE @@ -215,6 +217,7 @@ def test_main_03(mocker, load_plugin, load_testcases, signature, result): test_index=[], time_limit=10, timeout=None, + use_https=False, valgrind=False, ) asset_mgr = load_testcases[1] if isinstance(load_testcases, tuple) else None @@ -256,6 +259,7 @@ def test_main_04(mocker, tmp_path): test_index=[], time_limit=10, timeout=None, + use_https=False, valgrind=False, ) # target launch error @@ -317,6 +321,7 @@ def test_main_05(mocker, server, tmp_path): test_index=[], time_limit=1, timeout=None, + use_https=False, valgrind=False, ) # build a test case @@ -385,6 +390,7 @@ def test_main_06( test_index=[], time_limit=10, timeout=None, + use_https=False, valgrind=valgrind, ) # maximum one debugger allowed at a time @@ -438,6 +444,7 @@ def test_main_07(mocker, server, tmp_path): time_limit=10, timeout=None, tool=None, + use_https=False, valgrind=False, ) assert ReplayManager.main(args) == Exit.SUCCESS diff --git a/grizzly/services/__init__.py b/grizzly/services/__init__.py new file mode 100644 index 00000000..94c7b1e5 --- /dev/null +++ b/grizzly/services/__init__.py @@ -0,0 +1,11 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +__all__ = ( + "ServiceName", + "WebServices", + "WebTransportServer", +) + +from .core import ServiceName, WebServices +from .webtransport.core import WebTransportServer diff --git a/grizzly/services/base.py b/grizzly/services/base.py new file mode 100644 index 00000000..dbfe4428 --- /dev/null +++ b/grizzly/services/base.py @@ -0,0 +1,30 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from abc import ABC, abstractmethod + + +class BaseService(ABC): + """Base service class""" + + @property + @abstractmethod + def location(self): + """Location to use with Sapphire.set_dynamic_response""" + + @property + @abstractmethod + def port(self): + """The port on which the server is listening""" + + @abstractmethod + def url(self, _query): + """Returns the URL of the server.""" + + @abstractmethod + async def is_ready(self): + """Wait until the service is ready""" + + @abstractmethod + def cleanup(self): + """Stop the server.""" diff --git a/grizzly/services/core.py b/grizzly/services/core.py new file mode 100644 index 00000000..0f778a3e --- /dev/null +++ b/grizzly/services/core.py @@ -0,0 +1,96 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import asyncio +from enum import Enum +from logging import getLogger +from typing import Dict + +from sapphire import create_listening_socket + +from .base import BaseService +from .webtransport.core import WebTransportServer + +LOG = getLogger(__name__) + + +class ServiceName(Enum): + """Enum for listing available services""" + + WEB_TRANSPORT = 1 + + +class WebServices: + """Class for running additional web services""" + + def __init__(self, services: Dict[ServiceName, BaseService]): + """Initialize new WebServices instance + + Args: + services (dict of ServiceName: BaseService): Collection of services. + """ + self.services = services + + @staticmethod + def get_free_port(): + """Returns an open port""" + sock = create_listening_socket() + port = sock.getsockname()[1] + sock.close() + + return port + + async def is_running(self, timeout=20): + """Polls all available services to ensure they are running and accessible. + + Args: + timeout (int): Total time to wait. + + Returns: + bool: Indicates if all services started successfully. + """ + tasks = {} + for name, service in self.services.items(): + task = asyncio.create_task(service.is_ready()) + tasks[name] = task + + try: + await asyncio.wait_for(asyncio.gather(*tasks.values()), timeout) + except asyncio.TimeoutError: + for name, task in tasks.items(): + if not task.done(): + LOG.warning("Failed to start service (%s)", ServiceName(name).name) + return False + + return True + + def cleanup(self): + """Stops all running services and join's the service thread""" + for service in self.services.values(): + service.cleanup() + + def map_locations(self, server_map): + """Configure server map""" + for service in self.services.values(): + server_map.set_dynamic_response( + service.location, service.url, mime_type="text/plain", required=False + ) + + @classmethod + def start_services(cls, cert, key): + """Start all available services + + Args: + cert (Path): Path to the certificate file + key (Path): Path to the certificate's private key + """ + services = {} + # Start WebTransport service + wt_port = cls.get_free_port() + services[ServiceName.WEB_TRANSPORT] = WebTransportServer(wt_port, cert, key) + services[ServiceName.WEB_TRANSPORT].start() + + ext_services = cls(services) + assert asyncio.run(ext_services.is_running()) + + return ext_services diff --git a/grizzly/services/test_services.py b/grizzly/services/test_services.py new file mode 100644 index 00000000..754d885f --- /dev/null +++ b/grizzly/services/test_services.py @@ -0,0 +1,39 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# pylint: disable=protected-access +from asyncio import get_event_loop +from concurrent.futures import ThreadPoolExecutor + +from pytest import mark + +from sapphire import ServerMap + +from ..common.utils import CertificateBundle +from .core import WebServices + + +@mark.asyncio +async def test_service_01(): + """Verify that services are started and shutdown gracefully""" + cert = CertificateBundle.create() + try: + loop = get_event_loop() + with ThreadPoolExecutor() as executor: + ext_services = await loop.run_in_executor( + executor, WebServices.start_services, cert.host, cert.key + ) + + # Check that all services are running + assert len(ext_services.services) == 1 + assert await ext_services.is_running() + + server_map = ServerMap() + ext_services.map_locations(server_map) + assert len(server_map.dynamic) == 1 + + # Check that all services have stopped + ext_services.cleanup() + assert not await ext_services.is_running(timeout=0.1) + finally: + cert.cleanup() diff --git a/grizzly/services/webtransport/__init__.py b/grizzly/services/webtransport/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/grizzly/services/webtransport/core.py b/grizzly/services/webtransport/core.py new file mode 100644 index 00000000..d63ffaa0 --- /dev/null +++ b/grizzly/services/webtransport/core.py @@ -0,0 +1,117 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import asyncio +from logging import CRITICAL, getLogger +from os import environ +from pathlib import Path +from platform import system +from threading import Thread + +from aioquic.asyncio import serve +from aioquic.h3.connection import H3_ALPN +from aioquic.quic.configuration import QuicConfiguration + +from ..base import BaseService +from .wpt_h3_server.webtransport_h3_server import ( + SessionTicketStore, + WebTransportH3Protocol, + _connect_to_server, +) + +if "GRZ_QUIC_LOGGING" not in environ: + getLogger("quic").setLevel(CRITICAL) + +LOG = getLogger(__name__) + + +class WebTransportServer(BaseService): + def __init__(self, port: int, cert: Path, key: Path) -> None: + """A WebTransport service. + + Args: + port: The port on which to listen on. + cert: The path to the certificate file. + key: The path to the certificate's private key. + """ + self._port = port + self._cert = cert + self._key = key + + self._loop = None + self._server_thread = None + self._started = False + + @property + def location(self): + return "grz_webtransport_server" + + @property + def port(self): + """The port on which the service is listening""" + return self._port + + def url(self, _query): + """URL for Sapphire.set_dynamic_response + + Args: + _query (str): Unused query string. + + Returns: + bytes: Server URL. + """ + return b"https://127.0.0.1:%d" % (self._port,) + + async def is_ready(self): + """Wait until the service is ready""" + await _connect_to_server("127.0.0.1", self.port) + + def start(self) -> None: + """Start the server.""" + + def _start_service() -> None: + configuration = QuicConfiguration( + alpn_protocols=H3_ALPN, + is_client=False, + max_datagram_frame_size=65536, + ) + + LOG.info("Starting WebTransport service on port %s", self.port) + configuration.load_cert_chain(self._cert, self._key) + ticket_store = SessionTicketStore() + + # On Windows, the default event loop is ProactorEventLoop, but it + # doesn't seem to work when aioquic detects a connection loss. + # Use SelectorEventLoop to work around the problem. + if system() == "Windows": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._loop.run_until_complete( + serve( + "127.0.0.1", + self.port, + configuration=configuration, + create_protocol=WebTransportH3Protocol, + session_ticket_fetcher=ticket_store.pop, + session_ticket_handler=ticket_store.add, + ) + ) + self._loop.run_forever() + + self._server_thread = Thread(target=_start_service, daemon=True) + self._server_thread.start() + self._started = True + + def cleanup(self) -> None: + """Stop the server.""" + + async def _stop_loop() -> None: + self._loop.stop() + + if self._started: + asyncio.run_coroutine_threadsafe(_stop_loop(), self._loop) + self._server_thread.join() + LOG.info("Stopped WebTransport service on port %s", self._port) + self._started = False diff --git a/grizzly/services/webtransport/test_webtransport.py b/grizzly/services/webtransport/test_webtransport.py new file mode 100644 index 00000000..a3d3f411 --- /dev/null +++ b/grizzly/services/webtransport/test_webtransport.py @@ -0,0 +1,37 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# pylint: disable=protected-access +import asyncio + +import pytest + +from ...common.utils import CertificateBundle +from ..core import WebServices +from .core import WebTransportServer + + +def test_webtransport_01(): + """Verify that the WebTransport service started and shutdown gracefully""" + cert = CertificateBundle.create() + try: + port = WebServices.get_free_port() + web_transport = WebTransportServer(port, cert.host, cert.key) + assert not web_transport._started + + web_transport.start() + + # Check that all services are running + assert web_transport._started + asyncio.run(asyncio.wait_for(web_transport.is_ready(), timeout=3.0)) + + assert web_transport.location == "grz_webtransport_server" + assert isinstance(web_transport.url(None), bytes) + + web_transport.cleanup() + + assert not web_transport._started + with pytest.raises(asyncio.TimeoutError): + asyncio.run(asyncio.wait_for(web_transport.is_ready(), timeout=1.0)) + finally: + cert.cleanup() diff --git a/grizzly/services/webtransport/wpt_h3_server/LICENSE.md b/grizzly/services/webtransport/wpt_h3_server/LICENSE.md new file mode 100644 index 00000000..39c46d03 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/LICENSE.md @@ -0,0 +1,11 @@ +# The 3-Clause BSD License + +Copyright © web-platform-tests contributors + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/grizzly/services/webtransport/wpt_h3_server/__init__.py b/grizzly/services/webtransport/wpt_h3_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/grizzly/services/webtransport/wpt_h3_server/capsule.py b/grizzly/services/webtransport/wpt_h3_server/capsule.py new file mode 100644 index 00000000..2f142b94 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/capsule.py @@ -0,0 +1,114 @@ +# pylint: skip-file +# mypy: no-warn-return-any + +from enum import IntEnum +from typing import Iterator, Optional + +# TODO(bashi): Remove import check suppressions once aioquic dependency is +# resolved. +from aioquic.buffer import UINT_VAR_MAX_SIZE, Buffer, BufferReadError # type: ignore + + +class CapsuleType(IntEnum): + # Defined in + # https://www.ietf.org/archive/id/draft-ietf-masque-h3-datagram-03.html. + DATAGRAM = 0xFF37A0 + REGISTER_DATAGRAM_CONTEXT = 0xFF37A1 + REGISTER_DATAGRAM_NO_CONTEXT = 0xFF37A2 + CLOSE_DATAGRAM_CONTEXT = 0xFF37A3 + # Defined in + # https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-01.html. + CLOSE_WEBTRANSPORT_SESSION = 0x2843 + + +class H3Capsule: + """ + Represents the Capsule concept defined in + https://ietf-wg-masque.github.io/draft-ietf-masque-h3-datagram/draft-ietf-masque-h3-datagram.html#name-capsules. + """ + + def __init__(self, type: int, data: bytes) -> None: + """ + :param type the type of this Capsule. We don't use CapsuleType here + because this may be a capsule of an unknown type. + :param data the payload + """ + self.type = type + self.data = data + + def encode(self) -> bytes: + """ + Encodes this H3Capsule and return the bytes. + """ + buffer = Buffer(capacity=len(self.data) + 2 * UINT_VAR_MAX_SIZE) + buffer.push_uint_var(self.type) + buffer.push_uint_var(len(self.data)) + buffer.push_bytes(self.data) + return buffer.data + + +class H3CapsuleDecoder: + """ + A decoder of H3Capsule. This is a streaming decoder and can handle multiple + decoders. + """ + + def __init__(self) -> None: + self._buffer: Optional[Buffer] = None + self._type: Optional[int] = None + self._length: Optional[int] = None + self._final: bool = False + + def append(self, data: bytes) -> None: + """ + Appends the given bytes to this decoder. + """ + assert not self._final + + if len(data) == 0: + return + if self._buffer: + remaining = self._buffer.pull_bytes( + self._buffer.capacity - self._buffer.tell() + ) + self._buffer = Buffer(data=(remaining + data)) + else: + self._buffer = Buffer(data=data) + + def final(self) -> None: + """ + Pushes the end-of-stream mark to this decoder. After calling this, + calling append() will be invalid. + """ + self._final = True + + def __iter__(self) -> Iterator[H3Capsule]: + """ + Yields decoded capsules. + """ + try: + while self._buffer is not None: + if self._type is None: + self._type = self._buffer.pull_uint_var() + if self._length is None: + self._length = self._buffer.pull_uint_var() + if self._buffer.capacity - self._buffer.tell() < self._length: + if self._final: + raise ValueError("insufficient buffer") + return + capsule = H3Capsule(self._type, self._buffer.pull_bytes(self._length)) + self._type = None + self._length = None + if self._buffer.tell() == self._buffer.capacity: + self._buffer = None + yield capsule + except BufferReadError as e: + if self._final: + raise e + if not self._buffer: + return 0 + size = self._buffer.capacity - self._buffer.tell() + if size >= UINT_VAR_MAX_SIZE: + raise e + # Ignore the error because there may not be sufficient input. + return diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/__init__.py b/grizzly/services/webtransport/wpt_h3_server/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/abort_stream_from_server.py b/grizzly/services/webtransport/wpt_h3_server/handlers/abort_stream_from_server.py new file mode 100644 index 00000000..264c47e6 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/abort_stream_from_server.py @@ -0,0 +1,12 @@ +# pylint: skip-file +def session_established(session): + session.dict_for_handlers["code"] = 400 + + +def stream_data_received(session, stream_id: int, data: bytes, stream_ended: bool): + code: int = session.dict_for_handlers["code"] + if session.stream_is_unidirectional(stream_id): + session.stop_stream(stream_id, code) + else: + session.stop_stream(stream_id, code) + session.reset_stream(stream_id, code) diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/custom_response.py b/grizzly/services/webtransport/wpt_h3_server/handlers/custom_response.py new file mode 100644 index 00000000..d5d980e4 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/custom_response.py @@ -0,0 +1,14 @@ +# pylint: skip-file +from urllib.parse import parse_qsl, urlsplit + + +def connect_received(request_headers, response_headers): + for data in request_headers: + if data[0] == b":path": + path = data[1].decode("utf-8") + + qs = dict(parse_qsl(urlsplit(path).query)) + for key, value in qs.items(): + response_headers.append((key.encode("utf-8"), value.encode("utf-8"))) + + break diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/echo.py b/grizzly/services/webtransport/wpt_h3_server/handlers/echo.py new file mode 100644 index 00000000..29636f1a --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/echo.py @@ -0,0 +1,32 @@ +# pylint: skip-file +streams_dict = {} + + +def session_established(session): + # When a WebTransport session is established, a bidirectional stream is + # created by the server, which is used to echo back stream data from the + # client. + session.create_bidirectional_stream() + + +def stream_data_received(session, stream_id: int, data: bytes, stream_ended: bool): + # If a stream is unidirectional, create a new unidirectional stream and echo + # back the data on that stream. + if session.stream_is_unidirectional(stream_id): + # pylint: disable=consider-iterating-dictionary + if (session.session_id, stream_id) not in streams_dict.keys(): + new_stream_id = session.create_unidirectional_stream() + streams_dict[(session.session_id, stream_id)] = new_stream_id + session.send_stream_data( + streams_dict[(session.session_id, stream_id)], data, end_stream=stream_ended + ) + if stream_ended: + del streams_dict[(session.session_id, stream_id)] + return + # Otherwise (e.g. if the stream is bidirectional), echo back the data on the + # same stream. + session.send_stream_data(stream_id, data, end_stream=stream_ended) + + +def datagram_received(session, data: bytes): + session.send_datagram(data) diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/echo_request_headers.py b/grizzly/services/webtransport/wpt_h3_server/handlers/echo_request_headers.py new file mode 100644 index 00000000..d9429600 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/echo_request_headers.py @@ -0,0 +1,12 @@ +# pylint: skip-file +import json + + +def session_established(session): + headers = {} + for name, value in session.request_headers: + headers[name.decode("utf-8")] = value.decode("utf-8") + + stream_id = session.create_unidirectional_stream() + data = json.dumps(headers).encode("utf-8") + session.send_stream_data(stream_id, data, end_stream=True) diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/server_close.py b/grizzly/services/webtransport/wpt_h3_server/handlers/server_close.py new file mode 100644 index 00000000..b0dbafc9 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/server_close.py @@ -0,0 +1,17 @@ +# pylint: skip-file +from typing import Optional +from urllib.parse import parse_qsl, urlsplit + + +def session_established(session): + path: Optional[bytes] = None + for key, value in session.request_headers: + if key == b":path": + path = value + assert path is not None + qs = dict(parse_qsl(urlsplit(path).query)) + code = qs[b"code"] if b"code" in qs else None + reason = qs[b"reason"] if b"reason" in qs else b"" + close_info = None if code is None else (int(code), reason) + + session.close(close_info) diff --git a/grizzly/services/webtransport/wpt_h3_server/handlers/server_connection_close.py b/grizzly/services/webtransport/wpt_h3_server/handlers/server_connection_close.py new file mode 100644 index 00000000..b41ef67a --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/handlers/server_connection_close.py @@ -0,0 +1,9 @@ +# pylint: skip-file + + +def session_established(session): + session.create_bidirectional_stream() + + +def stream_data_received(session, stream_id: int, data: bytes, stream_ended: bool): + session._http._quic.close() diff --git a/grizzly/services/webtransport/wpt_h3_server/webtransport_h3_server.py b/grizzly/services/webtransport/wpt_h3_server/webtransport_h3_server.py new file mode 100644 index 00000000..6f455d74 --- /dev/null +++ b/grizzly/services/webtransport/wpt_h3_server/webtransport_h3_server.py @@ -0,0 +1,493 @@ +""" +A WebTransport over HTTP/3 server for testing. + +The server interprets the underlying protocols (WebTransport, HTTP/3 and QUIC) +and passes events to a particular webtransport handler. From the standpoint of +test authors, a webtransport handler is a Python script which contains some +callback functions. See handler.py for available callbacks. +""" +# pylint: skip-file +import asyncio +import logging +import ssl +import traceback +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +from aioquic.asyncio import QuicConnectionProtocol +from aioquic.asyncio.client import connect +from aioquic.buffer import Buffer +from aioquic.h3.connection import ( + H3_ALPN, + FrameType, + H3Connection, + ProtocolError, + Setting, +) +from aioquic.h3.events import ( + DatagramReceived, + DataReceived, + H3Event, + HeadersReceived, + WebTransportStreamDataReceived, +) +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import stream_is_unidirectional +from aioquic.quic.events import ( + ConnectionTerminated, + ProtocolNegotiated, + QuicEvent, + StreamReset, +) +from aioquic.tls import SessionTicket + +from .capsule import CapsuleType, H3Capsule, H3CapsuleDecoder + +SERVER_NAME = "webtransport-h3-server" + +LOG: logging.Logger = logging.getLogger(__name__) +DOC_ROOT: Path = Path(__file__).resolve().parent / "handlers" + + +class H3ConnectionWithDatagram04(H3Connection): + """ + A H3Connection subclass, to make it work with the latest + HTTP Datagram protocol. + """ + + H3_DATAGRAM_04 = 0xFFD277 + # https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-h3-websockets-00#section-5 + ENABLE_CONNECT_PROTOCOL = 0x08 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._supports_h3_datagram_04 = False + + def _validate_settings(self, settings: Dict[int, int]) -> None: + H3_DATAGRAM_04 = H3ConnectionWithDatagram04.H3_DATAGRAM_04 + if H3_DATAGRAM_04 in settings and settings[H3_DATAGRAM_04] == 1: + settings[Setting.H3_DATAGRAM] = 1 + self._supports_h3_datagram_04 = True + return super()._validate_settings(settings) + + def _get_local_settings(self) -> Dict[int, int]: + H3_DATAGRAM_04 = H3ConnectionWithDatagram04.H3_DATAGRAM_04 + settings = super()._get_local_settings() + settings[H3_DATAGRAM_04] = 1 + settings[H3ConnectionWithDatagram04.ENABLE_CONNECT_PROTOCOL] = 1 + return settings + + @property + def supports_h3_datagram_04(self) -> bool: + """ + True if the client supports the latest HTTP Datagram protocol. + """ + return self._supports_h3_datagram_04 + + +class WebTransportH3Protocol(QuicConnectionProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._handler: Optional[Any] = None + self._http: Optional[H3ConnectionWithDatagram04] = None + self._session_stream_id: Optional[int] = None + self._close_info: Optional[Tuple[int, bytes]] = None + self._capsule_decoder_for_session_stream: H3CapsuleDecoder = H3CapsuleDecoder() + self._allow_calling_session_closed = True + self._allow_datagrams = False + + def quic_event_received(self, event: QuicEvent) -> None: + if isinstance(event, ProtocolNegotiated): + self._http = H3ConnectionWithDatagram04( + self._quic, enable_webtransport=True + ) + if not self._http.supports_h3_datagram_04: + self._allow_datagrams = True + + if self._http is not None: + for http_event in self._http.handle_event(event): + self._h3_event_received(http_event) + + if isinstance(event, ConnectionTerminated): + self._call_session_closed(close_info=None, abruptly=True) + if isinstance(event, StreamReset): + if self._handler: + self._handler.stream_reset(event.stream_id, event.error_code) + + def _h3_event_received(self, event: H3Event) -> None: + if isinstance(event, HeadersReceived): + # Convert from List[Tuple[bytes, bytes]] to Dict[bytes, bytes]. + # Only the last header will be kept when there are duplicate + # headers. + headers = {} + for header, value in event.headers: + headers[header] = value + + method = headers.get(b":method") + protocol = headers.get(b":protocol") + origin = headers.get(b"origin") + # Accept any Origin but the client must send it. + if method == b"CONNECT" and protocol == b"webtransport" and origin: + self._session_stream_id = event.stream_id + self._handshake_webtransport(event, headers) + else: + status_code = 404 if origin else 403 + self._send_error_response(event.stream_id, status_code) + + if ( + isinstance(event, DataReceived) + and self._session_stream_id == event.stream_id + ): + if ( + self._http + and not self._http.supports_h3_datagram_04 + and len(event.data) > 0 + ): + raise ProtocolError("Unexpected data on the session stream") + self._receive_data_on_session_stream(event.data, event.stream_ended) + elif self._handler is not None: + if isinstance(event, WebTransportStreamDataReceived): + self._handler.stream_data_received( + stream_id=event.stream_id, + data=event.data, + stream_ended=event.stream_ended, + ) + elif isinstance(event, DatagramReceived): + if self._allow_datagrams: + self._handler.datagram_received(data=event.data) + + def _receive_data_on_session_stream(self, data: bytes, fin: bool) -> None: + self._capsule_decoder_for_session_stream.append(data) + if fin: + self._capsule_decoder_for_session_stream.final() + for capsule in self._capsule_decoder_for_session_stream: + if capsule.type in { + CapsuleType.DATAGRAM, + CapsuleType.REGISTER_DATAGRAM_CONTEXT, + CapsuleType.CLOSE_DATAGRAM_CONTEXT, + }: + raise ProtocolError(f"Unimplemented capsule type: {capsule.type}") + if capsule.type in { + CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT, + CapsuleType.CLOSE_WEBTRANSPORT_SESSION, + }: + # We'll handle this case below. + pass + else: + # We should ignore unknown capsules. + continue + + if self._close_info is not None: + raise ProtocolError( + ( + "Receiving a capsule with type = {} after receiving " + + "CLOSE_WEBTRANSPORT_SESSION" + ).format(capsule.type) + ) + + if capsule.type == CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT: + buffer = Buffer(data=capsule.data) + format_type = buffer.pull_uint_var() + # https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#name-datagram-format-type + WEBTRANPORT_FORMAT_TYPE = 0xFF7C00 + if format_type != WEBTRANPORT_FORMAT_TYPE: + raise ProtocolError( + f"Unexpected datagram format type: {format_type}" + ) + self._allow_datagrams = True + elif capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION: + buffer = Buffer(data=capsule.data) + code = buffer.pull_uint32() + # 4 bytes for the uint32. + reason = buffer.pull_bytes(len(capsule.data) - 4) + # TODO(yutakahirano): Make sure `reason` is a UTF-8 text. + self._close_info = (code, reason) + if fin: + self._call_session_closed(self._close_info, abruptly=False) + + def _send_error_response(self, stream_id: int, status_code: int) -> None: + assert self._http is not None + headers = [ + (b":status", str(status_code).encode()), + (b"server", SERVER_NAME.encode()), + ] + self._http.send_headers(stream_id=stream_id, headers=headers, end_stream=True) + + def _handshake_webtransport( + self, event: HeadersReceived, request_headers: Dict[bytes, bytes] + ) -> None: + assert self._http is not None + path = request_headers.get(b":path") + if path is None: + # `:path` must be provided. + self._send_error_response(event.stream_id, 400) + return + + # Create a handler using `:path`. + try: + self._handler = self._create_event_handler( + session_id=event.stream_id, path=path, request_headers=event.headers + ) + except OSError: + self._send_error_response(event.stream_id, 404) + return + + response_headers = [ + (b"server", SERVER_NAME.encode()), + (b"sec-webtransport-http3-draft", b"draft02"), + ] + self._handler.connect_received(response_headers=response_headers) + + status_code = None + for name, value in response_headers: + if name == b":status": + status_code = value + response_headers.remove((b":status", status_code)) + response_headers.insert(0, (b":status", status_code)) + break + if not status_code: + response_headers.insert(0, (b":status", b"200")) + self._http.send_headers(stream_id=event.stream_id, headers=response_headers) + + if status_code is None or status_code == b"200": + self._handler.session_established() + + def _create_event_handler( + self, + session_id: int, + path: bytes, + request_headers: List[Tuple[bytes, bytes]], + ) -> Any: + parsed = urlparse(path.decode()) + handler = (DOC_ROOT / parsed.path.lstrip("/")).with_suffix(".py") + callbacks = {"__file__": handler} + exec(compile(handler.read_text(), path, "exec"), callbacks) + session = WebTransportSession(self, session_id, request_headers) + return WebTransportEventHandler(session, callbacks) + + def _call_session_closed( + self, close_info: Optional[Tuple[int, bytes]], abruptly: bool + ) -> None: + allow_calling_session_closed = self._allow_calling_session_closed + self._allow_calling_session_closed = False + if self._handler and allow_calling_session_closed: + self._handler.session_closed(close_info, abruptly) + + +class WebTransportSession: + """ + A WebTransport session. + """ + + def __init__( + self, + protocol: WebTransportH3Protocol, + session_id: int, + request_headers: List[Tuple[bytes, bytes]], + ) -> None: + self.session_id = session_id + self.request_headers = request_headers + + self._protocol: WebTransportH3Protocol = protocol + self._http: H3Connection = protocol._http + + self._dict_for_handlers: Dict[str, Any] = {} + + @property + def dict_for_handlers(self) -> Dict[str, Any]: + """A dictionary that handlers can attach arbitrary data.""" + return self._dict_for_handlers + + def stream_is_unidirectional(self, stream_id: int) -> bool: + """Return True if the stream is unidirectional.""" + return stream_is_unidirectional(stream_id) + + def close(self, close_info: Optional[Tuple[int, bytes]]) -> None: + """ + Close the session. + + :param close_info The close information to send. + """ + self._protocol._allow_calling_session_closed = False + assert self._protocol._session_stream_id is not None + session_stream_id = self._protocol._session_stream_id + if close_info is not None: + code = close_info[0] + reason = close_info[1] + buffer = Buffer(capacity=len(reason) + 4) + buffer.push_uint32(code) + buffer.push_bytes(reason) + capsule = H3Capsule(CapsuleType.CLOSE_WEBTRANSPORT_SESSION, buffer.data) + self._http.send_data(session_stream_id, capsule.encode(), end_stream=False) + + self._http.send_data(session_stream_id, b"", end_stream=True) + # TODO(yutakahirano): Reset all other streams. + # TODO(yutakahirano): Reject future stream open requests + # We need to wait for the stream data to arrive at the client, and then + # we need to close the connection. At this moment we're relying on the + # client's behavior. + # TODO(yutakahirano): Implement the above. + + def create_unidirectional_stream(self) -> int: + """ + Create a unidirectional WebTransport stream and return the stream ID. + """ + return self._http.create_webtransport_stream( + session_id=self.session_id, is_unidirectional=True + ) + + def create_bidirectional_stream(self) -> int: + """ + Create a bidirectional WebTransport stream and return the stream ID. + """ + stream_id = self._http.create_webtransport_stream( + session_id=self.session_id, is_unidirectional=False + ) + # TODO(bashi): Remove this workaround when aioquic supports receiving + # data on server-initiated bidirectional streams. + stream = self._http._get_or_create_stream(stream_id) + assert stream.frame_type is None + assert stream.session_id is None + stream.frame_type = FrameType.WEBTRANSPORT_STREAM + stream.session_id = self.session_id + return stream_id + + def send_stream_data( + self, stream_id: int, data: bytes, end_stream: bool = False + ) -> None: + """ + Send data on the specific stream. + + :param stream_id: The stream ID on which to send the data. + :param data: The data to send. + :param end_stream: If set to True, the stream will be closed. + """ + self._http._quic.send_stream_data( + stream_id=stream_id, data=data, end_stream=end_stream + ) + + def send_datagram(self, data: bytes) -> None: + """ + Send data using a datagram frame. + + :param data: The data to send. + """ + if not self._protocol._allow_datagrams: + LOG.warning("Sending a datagram while that's now allowed - discarding it") + return + flow_id = self.session_id + if self._http.supports_h3_datagram_04: + # The REGISTER_DATAGRAM_NO_CONTEXT capsule was on the session + # stream, so we must have the ID of the stream. + assert self._protocol._session_stream_id is not None + # TODO(yutakahirano): Make sure if this is the correct logic. + # Chrome always use 0 for the initial stream and the initial flow + # ID, we cannot check the correctness with it. + flow_id = self._protocol._session_stream_id // 4 + self._http.send_datagram(flow_id=flow_id, data=data) + + def stop_stream(self, stream_id: int, code: int) -> None: + """ + Send a STOP_SENDING frame to the given stream. + :param code: the reason of the error. + """ + self._http._quic.stop_stream(stream_id, code) + + def reset_stream(self, stream_id: int, code: int) -> None: + """ + Send a RESET_STREAM frame to the given stream. + :param code: the reason of the error. + """ + self._http._quic.reset_stream(stream_id, code) + + +class WebTransportEventHandler: + def __init__(self, session: WebTransportSession, callbacks: Dict[str, Any]) -> None: + self._session = session + self._callbacks = callbacks + + def _run_callback(self, callback_name: str, *args: Any, **kwargs: Any) -> None: + if callback_name not in self._callbacks: + return + try: + self._callbacks[callback_name](*args, **kwargs) + except Exception as e: + LOG.warning(str(e)) + traceback.print_exc() + + def connect_received(self, response_headers: List[Tuple[bytes, bytes]]) -> None: + self._run_callback( + "connect_received", + self._session.request_headers, + response_headers, + ) + + def session_established(self) -> None: + self._run_callback("session_established", self._session) + + def stream_data_received( + self, stream_id: int, data: bytes, stream_ended: bool + ) -> None: + self._run_callback( + "stream_data_received", self._session, stream_id, data, stream_ended + ) + + def datagram_received(self, data: bytes) -> None: + self._run_callback("datagram_received", self._session, data) + + def session_closed( + self, + close_info: Optional[Tuple[int, bytes]], + abruptly: bool, + ) -> None: + self._run_callback( + "session_closed", self._session, close_info, abruptly=abruptly + ) + + def stream_reset(self, stream_id: int, error_code: int) -> None: + self._run_callback("stream_reset", self._session, stream_id, error_code) + + +class SessionTicketStore: + """ + Simple in-memory store for session tickets. + """ + + def __init__(self) -> None: + self.tickets: Dict[bytes, SessionTicket] = {} + + def add(self, ticket: SessionTicket) -> None: + self.tickets[ticket.ticket] = ticket + + def pop(self, label: bytes) -> Optional[SessionTicket]: + return self.tickets.pop(label, None) + + +def server_is_running(host: str, port: int, timeout: float) -> bool: + """ + Check the WebTransport over HTTP/3 server is running at the given `host` and + `port`. + """ + loop = asyncio.get_event_loop() + return loop.run_until_complete(_connect_server_with_timeout(host, port, timeout)) + + +async def _connect_server_with_timeout(host: str, port: int, timeout: float) -> bool: + try: + await asyncio.wait_for(_connect_to_server(host, port), timeout=timeout) + except asyncio.TimeoutError: + LOG.warning("Failed to connect WebTransport over HTTP/3 server") + return False + return True + + +async def _connect_to_server(host: str, port: int) -> None: + configuration = QuicConfiguration( + alpn_protocols=H3_ALPN, + is_client=True, + verify_mode=ssl.CERT_NONE, + ) + + async with connect(host, port, configuration=configuration) as protocol: + await protocol.ping() diff --git a/grizzly/session.py b/grizzly/session.py index 30709441..98b8a8d1 100644 --- a/grizzly/session.py +++ b/grizzly/session.py @@ -20,7 +20,6 @@ __author__ = "Tyson Smith" __credits__ = ["Tyson Smith", "Jesse Schwartzentruber"] - LOG = getLogger(__name__) @@ -165,6 +164,7 @@ def run( log_rate: LogRate = LogRate.NORMAL, launch_attempts: int = 3, post_launch_delay: int = 0, + services=None, ) -> None: assert iteration_limit >= 0 assert launch_attempts > 0 @@ -188,6 +188,8 @@ def run( self.iomanager.server_map.set_redirect( "grz_start", "grz_harness", required=False ) + if services: + services.map_locations(self.iomanager.server_map) log_limiter = LogOutputLimiter(rate=log_rate) # limit relaunch to max iterations if needed diff --git a/grizzly/test_session.py b/grizzly/test_session.py index c7f4ab5b..240ebeeb 100644 --- a/grizzly/test_session.py +++ b/grizzly/test_session.py @@ -15,6 +15,7 @@ from .adapter import Adapter from .common.reporter import Report, Reporter from .common.runner import RunResult +from .services.core import WebServices from .session import LogOutputLimiter, LogRate, Session, SessionError from .target import Result, Target @@ -93,12 +94,14 @@ def test_session_01(mocker, harness, profiling, coverage, relaunch, iters, runti Served.ALL, {session.iomanager.page_name(offset=-1): "/fake/path"}, ) + services = mocker.Mock(spec_set=WebServices) session.run( [], 10, input_path="file.bin", iteration_limit=iters, runtime_limit=runtime, + services=services, ) assert session.status.iteration == max_iters assert session.status.test_name == "file.bin" diff --git a/pyproject.toml b/pyproject.toml index 4f5a48f0..143e070a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ omit = [ "*/resources/*", "*/.tox/*", "*/.egg/*", + "grizzly/services/webtransport/wpt_h3_server/*", ] [tool.coverage.report] diff --git a/sapphire/__init__.py b/sapphire/__init__.py index de8b6469..387ca1da 100644 --- a/sapphire/__init__.py +++ b/sapphire/__init__.py @@ -7,7 +7,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from .certificate_bundle import CertificateBundle -from .core import Sapphire +from .core import Sapphire, create_listening_socket from .job import Served from .server_map import ServerMap @@ -16,6 +16,7 @@ "Sapphire", "Served", "ServerMap", + "create_listening_socket", ) __author__ = "Tyson Smith" __credits__ = ["Tyson Smith"] diff --git a/setup.cfg b/setup.cfg index e947fc1e..89cab00f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ url = https://github.com/MozillaSecurity/grizzly [options] include_package_data = True install_requires = + aioquic == 0.9.20 bugsy cryptography cssbeautifier @@ -38,6 +39,7 @@ packages = grizzly.reduce.strategies grizzly.replay grizzly.target + grizzly.services loki sapphire python_requires = >=3.8 diff --git a/tox.ini b/tox.ini index 999208e6..f05aabf4 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,7 @@ tox_pip_extensions_ext_venv_update = true commands = pytest -v --cache-clear --cov={toxinidir} --cov-config={toxinidir}/pyproject.toml --cov-report=term-missing {posargs} deps = pytest + pytest-asyncio pytest-cov pytest-mock passenv =