From b4d799c7904d0eb8b036ff4675bb3c78d0d073ce Mon Sep 17 00:00:00 2001 From: Kefu Chai Date: Fri, 3 May 2024 22:33:39 +0800 Subject: [PATCH] coroutine/async_generator: reimplement async generator this generator implementation is inspired by https://wg21.link/P2502R2. Refs #2190 Refs #1913 Refs #1677 Signed-off-by: Kefu Chai --- include/seastar/coroutine/async_generator.hh | 363 +++++++++++++++++++ tests/unit/CMakeLists.txt | 3 + tests/unit/generator_test.cc | 152 ++++++++ 3 files changed, 518 insertions(+) create mode 100644 include/seastar/coroutine/async_generator.hh create mode 100644 tests/unit/generator_test.cc diff --git a/include/seastar/coroutine/async_generator.hh b/include/seastar/coroutine/async_generator.hh new file mode 100644 index 00000000000..f1dc79afa94 --- /dev/null +++ b/include/seastar/coroutine/async_generator.hh @@ -0,0 +1,363 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// async_generator concept is heavily inspired by P2502R2 +// (https://wg21.link/P2502R2), a proposal accepted into C++23. P2502R2 +// introduced std::generator, which provides a synchronous coroutine +// mechanism for generating ranges. in contrast, async_generator offers +// asynchronous generation of element sequences. +namespace seastar::coroutine::experimental { + +template +class async_generator; + +namespace internal { + +template class next_awaiter; + +template +class async_generator_promise_base : public seastar::task { +protected: + std::add_pointer_t _value = nullptr; + +protected: + std::exception_ptr _exception; + std::coroutine_handle<> _consumer; + task* _waiting_task = nullptr; + + class yield_awaiter final { + async_generator_promise_base* _promise; + std::coroutine_handle<> _consumer; + public: + yield_awaiter(async_generator_promise_base* promise, + std::coroutine_handle<> consumer) noexcept + : _promise{promise} + , _consumer{consumer} + {} + bool await_ready() const noexcept { + return false; + } + template + std::coroutine_handle<> await_suspend(std::coroutine_handle producer) noexcept { + _promise->_waiting_task = &producer.promise(); + return _consumer; + } + void await_resume() noexcept {} + }; + + yield_awaiter do_yield() noexcept { + return yield_awaiter{this, _consumer}; + } + +public: + async_generator_promise_base() noexcept = default; + async_generator_promise_base(const async_generator_promise_base &) = delete; + async_generator_promise_base& operator=(const async_generator_promise_base &) = delete; + async_generator_promise_base(async_generator_promise_base &&) noexcept = default; + async_generator_promise_base& operator=(async_generator_promise_base &&) noexcept = default; + + // lazily-started coroutine, do not execute the coroutine until + // the coroutine is awaited. + std::suspend_always initial_suspend() const noexcept { + return {}; + } + + yield_awaiter final_suspend() noexcept { + _value = nullptr; + return do_yield(); + } + + void unhandled_exception() noexcept { + _exception = std::current_exception(); + } + + void return_void() noexcept {} + + // @return if the generator has reached the end of the sequence + bool finished() const noexcept { + return _value == nullptr; + } + + void rethrow_if_unhandled_exception() { + if (_exception) { + std::rethrow_exception(std::move(_exception)); + } + } + + void run_and_dispose() noexcept final { + using handle_type = std::coroutine_handle; + handle_type::from_promise(*this).resume(); + } + + seastar::task* waiting_task() noexcept final { + return _waiting_task; + } + + class element_awaiter { + std::remove_cvref_t _value; + constexpr bool await_ready() const noexcept { + return false; + } + template + constexpr void await_suspend(std::coroutine_handle producer) noexcept { + auto& current = producer.promise(); + producer._value = std::addressof(_value); + } + constexpr void await_resume() const noexcept {} + }; + +private: + friend class next_awaiter; +}; + +template +class next_awaiter { +protected: + async_generator_promise_base* _promise = nullptr; + std::coroutine_handle<> _producer = nullptr; + + explicit next_awaiter(std::nullptr_t) noexcept {} + next_awaiter(async_generator_promise_base& promise, + std::coroutine_handle<> producer) noexcept + : _promise{std::addressof(promise)} + , _producer{producer} {} + +public: + bool await_ready() const noexcept { + return false; + } + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle consumer) noexcept { + _promise->_consumer = consumer; + return _producer; + } +}; + +} // namespace internal + +template +class [[nodiscard]] async_generator { + using value_type = std::conditional_t, + std::remove_cvref_t, + Value>; + using reference_type = std::conditional_t, + Ref&&, + Ref>; + using yielded_type = std::conditional_t, + reference_type, + const reference_type&>; + +public: + class promise_type; + +private: + using handle_type = std::coroutine_handle; + handle_type _coro = {}; + +public: + class iterator; + + async_generator() noexcept = default; + explicit async_generator(promise_type& promise) noexcept + : _coro(std::coroutine_handle::from_promise(promise)) + {} + async_generator(async_generator&& other) noexcept + : _coro{std::exchange(other._coro, {})} + {} + async_generator(const async_generator&) = delete; + async_generator& operator=(const async_generator &) = delete; + + ~async_generator() { + if (_coro) { + _coro.destroy(); + } + } + + friend void swap(async_generator& lhs, async_generator& rhs) noexcept { + std::swap(lhs._coro, rhs._coro); + } + + async_generator& operator=(async_generator &&other) noexcept { + if (_coro) { + _coro.destroy(); + } + _coro = std::exchange(other._coro, nullptr); + return *this; + } + + [[nodiscard]] auto begin() noexcept { + using base_awaiter = internal::next_awaiter; + class begin_awaiter final : public base_awaiter { + using base_awaiter::_promise; + + public: + explicit begin_awaiter(std::nullptr_t) noexcept + : base_awaiter{nullptr} + {} + explicit begin_awaiter(handle_type producer_coro) noexcept + : base_awaiter{producer_coro.promise(), producer_coro} + {} + bool await_ready() const noexcept { + return _promise == nullptr || base_awaiter::await_ready(); + } + + iterator await_resume() { + if (_promise == nullptr) { + return iterator{nullptr}; + } + if (_promise->finished()) { + _promise->rethrow_if_unhandled_exception(); + return iterator{nullptr}; + } + return iterator{ + handle_type::from_promise(*static_cast(_promise)) + }; + } + }; + + if (_coro) { + return begin_awaiter{_coro}; + } else { + return begin_awaiter{nullptr}; + } + } + + [[nodiscard]] std::default_sentinel_t end() const noexcept { + return {}; + } +}; + +template +class async_generator::promise_type final : public internal::async_generator_promise_base { + using yield_awaiter = internal::async_generator_promise_base::yield_awaiter; + using element_awaiter = internal::async_generator_promise_base::element_awaiter; + using internal::async_generator_promise_base::_value; + using internal::async_generator_promise_base::_exception; + +public: + async_generator get_return_object() noexcept { + return async_generator{*this}; + } + + // lazily-started coroutine, do not execute the coroutine until + // the coroutine is awaited. + std::suspend_always initial_suspend() const noexcept { + return {}; + } + yield_awaiter final_suspend() noexcept { + _value = nullptr; + return this->do_yield(); + } + + yield_awaiter yield_value(yielded_type value) noexcept { + _value = std::addressof(value); + return this->do_yield(); + } + + element_awaiter yield_value(const std::remove_reference_t& value) + requires (std::is_rvalue_reference_v && + std::constructible_from< + std::remove_cvref_t, + const std::remove_reference_t&>) { + return element_awaiter{value}; + } + + yielded_type value() const noexcept { + return static_cast(*_value); + } + + void unhandled_exception() noexcept { + _exception = std::current_exception(); + } + + void return_void() noexcept {} + + // @return if the generator has reached the end of the sequence + bool finished() const noexcept { + return _value == nullptr; + } +}; + +template +class async_generator::iterator final { +private: + using handle_type = async_generator::handle_type; + handle_type _coro = nullptr; + +public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = async_generator::value_type; + using reference = async_generator::reference_type; + using pointer = std::add_pointer_t; + + explicit iterator(handle_type coroutine) noexcept + : _coro{coroutine} + {} + + explicit operator bool() const noexcept { + return _coro && !_coro.done(); + } + + [[nodiscard]] auto operator++() noexcept { + using base_awaiter = internal::next_awaiter; + class increment_awaiter final : public base_awaiter { + iterator& _iterator; + using base_awaiter::_promise; + + public: + explicit increment_awaiter(iterator& iterator) noexcept + : base_awaiter{iterator._coro.promise(), iterator._coro} + , _iterator{iterator} + {} + iterator& await_resume() { + if (_promise->finished()) { + // update iterator to end() + _iterator = iterator{nullptr}; + _promise->rethrow_if_unhandled_exception(); + } + return _iterator; + } + }; + + assert(bool(*this) && "cannot increment end iterator"); + return increment_awaiter{*this}; + } + + reference operator*() const noexcept { + return _coro.promise().value(); + } + + bool operator==(std::default_sentinel_t) const noexcept { + return !bool(*this); + } +}; + +} // namespace seastar::coroutine::experimental diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 181524901eb..cc0363518d6 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -282,6 +282,9 @@ seastar_add_test (content_source seastar_add_test (coroutines SOURCES coroutines_test.cc) +seastar_add_test (generator + SOURCES generator_test.cc) + seastar_add_test (defer KIND BOOST SOURCES defer_test.cc) diff --git a/tests/unit/generator_test.cc b/tests/unit/generator_test.cc new file mode 100644 index 00000000000..35ac00082f9 --- /dev/null +++ b/tests/unit/generator_test.cc @@ -0,0 +1,152 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright (C) 2024 ScyllaDB Ltd. + */ + +#include +#include +#include +#include +#include +#include + +using namespace seastar; + +using do_suspend = bool_class; + +coroutine::experimental::async_generator +fibonacci_sequence(unsigned count, + do_suspend suspend) { + auto a = 0, b = 1; + for (unsigned i = 0; i < count; ++i) { + if (std::numeric_limits::max() - a < b) { + throw std::out_of_range( + fmt::format("fibonacci[{}] is greater than the largest value of int", i)); + } + if (suspend) { + co_await coroutine::maybe_yield(); + } + co_yield std::exchange(a, std::exchange(b, a + b)); + } +} + +seastar::future<> test_async_generator_drained(do_suspend suspend) { + auto expected_fibs = {0, 1, 1, 2}; + auto expected_fib = std::begin(expected_fibs); + + auto actual_fibs = fibonacci_sequence(std::size(expected_fibs), suspend); + auto actual_fib = co_await actual_fibs.begin(); + + for (; actual_fib != actual_fibs.end(); co_await ++actual_fib) { + BOOST_REQUIRE(expected_fib != std::end(expected_fibs)); + BOOST_REQUIRE_EQUAL(*actual_fib, *expected_fib); + ++expected_fib; + } + BOOST_REQUIRE(actual_fib == actual_fibs.end()); +} + +SEASTAR_TEST_CASE(test_async_generator_drained_with_suspend) { + return test_async_generator_drained(do_suspend::yes); +} + +SEASTAR_TEST_CASE(test_async_generator_drained_without_suspend) { + return test_async_generator_drained(do_suspend::no); +} + +seastar::future<> test_async_generator_not_drained(do_suspend suspend) { + auto fib = fibonacci_sequence(42, suspend); + auto actual_fib = co_await fib.begin(); + BOOST_REQUIRE_EQUAL(*actual_fib, 0); +} + +SEASTAR_TEST_CASE(test_async_generator_not_drained_with_suspend) { + return test_async_generator_not_drained(do_suspend::yes); +} + +SEASTAR_TEST_CASE(test_async_generator_not_drained_without_suspend) { + return test_async_generator_not_drained(do_suspend::no); +} + +struct counter_t { + int n; + int* count; + counter_t(counter_t&& other) noexcept + : n{std::exchange(other.n, -1)}, + count{std::exchange(other.count, nullptr)} + {} + counter_t(int n, int* count) noexcept + : n{n}, count{count} { + ++(*count); + } + ~counter_t() noexcept { + if (count) { + --(*count); + } + } +}; + +std::ostream& operator<<(std::ostream& os, const counter_t& c) { + return os << c.n; +} + +coroutine::experimental::async_generator +fiddle(int n, int* total) { + int i = 0; + while (true) { + if (i++ == n) { + throw std::invalid_argument("Eureka from generator!"); + } + co_yield counter_t{i, total}; + } +} + +SEASTAR_TEST_CASE(test_async_generator_throws_from_generator) { + int total = 0; + auto count_to = [total=&total](unsigned n) -> seastar::future<> { + auto count = fiddle(n, total); + auto it = co_await count.begin(); + for (unsigned i = 0; i < 2 * n; i++) { + co_await ++it; + } + }; + co_await count_to(42).then_wrapped([&total] (auto f) { + BOOST_REQUIRE(f.failed()); + BOOST_REQUIRE_THROW(std::rethrow_exception(f.get_exception()), std::invalid_argument); + BOOST_REQUIRE_EQUAL(total, 0); + }); +} + +SEASTAR_TEST_CASE(test_async_generator_throws_from_consumer) { + int total = 0; + auto count_to = [total=&total](unsigned n) -> seastar::future<> { + auto count = fiddle(n, total); + auto it = co_await count.begin(); + for (unsigned i = 0; i < n; i++) { + if (i == n / 2) { + throw std::invalid_argument("Eureka from consumer!"); + } + co_await ++it; + } + }; + co_await count_to(42).then_wrapped([&total] (auto f) { + BOOST_REQUIRE(f.failed()); + BOOST_REQUIRE_THROW(std::rethrow_exception(f.get_exception()), std::invalid_argument); + BOOST_REQUIRE_EQUAL(total, 0); + }); +}