Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/concept/executor.hpp>
15 : #include <boost/capy/io_awaitable.hpp>
16 : #include <boost/capy/coro.hpp>
17 : #include <boost/capy/ex/executor_ref.hpp>
18 : #include <boost/capy/ex/frame_allocator.hpp>
19 : #include <boost/capy/task.hpp>
20 :
21 : #include <array>
22 : #include <atomic>
23 : #include <exception>
24 : #include <optional>
25 : #include <stop_token>
26 : #include <tuple>
27 : #include <type_traits>
28 : #include <utility>
29 :
30 : namespace boost {
31 : namespace capy {
32 :
33 : namespace detail {
34 :
/** Per-type contribution to a void-filtered tuple.

    Maps `void` to `std::tuple<>` and any other `T` to `std::tuple<T>`,
    so that concatenating the contributions of a pack drops every void.
*/
template<class T>
using wrap_non_void_t = std::conditional_t<
    std::is_void_v<T>,
    std::tuple<>,
    std::tuple<T>>;

/** Tuple of the non-void types of a pack, in their original order.

    Void-returning tasks do not contribute a value to the when_all result,
    so this alias computes the filtered result tuple type.

    Example: `filter_void_tuple_t<int, void, std::string>` is
    `std::tuple<int, std::string>`; an all-void or empty pack yields
    `std::tuple<>`.
*/
template<class... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47 :
/** Holds the result of a single task within when_all.

    The slot is filled once by set() when the child task completes and is
    read once by get() after all children have finished.

    @tparam T The child task's (non-void) result type.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store the child's result.

        Uses emplace so that T only has to be move-constructible. Assigning
        into the optional would additionally require T to be move-assignable,
        which rejects types with const or reference members.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out.

        @pre set() has been called. The optional is not checked here.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72 :
73 : /** Shared state for when_all operation.
74 :
75 : @tparam Ts The result types of the tasks.
76 : */
77 : template<typename... Ts>
78 : struct when_all_state
79 : {
80 : static constexpr std::size_t task_count = sizeof...(Ts);
81 :
82 : // Completion tracking - when_all waits for all children
83 : std::atomic<std::size_t> remaining_count_;
84 :
85 : // Result storage in input order
86 : std::tuple<result_holder<Ts>...> results_;
87 :
88 : // Runner handles - destroyed in await_resume while allocator is valid
89 : std::array<coro, task_count> runner_handles_{};
90 :
91 : // Exception storage - first error wins, others discarded
92 : std::atomic<bool> has_exception_{false};
93 : std::exception_ptr first_exception_;
94 :
95 : // Stop propagation - on error, request stop for siblings
96 : std::stop_source stop_source_;
97 :
98 : // Connects parent's stop_token to our stop_source
99 : struct stop_callback_fn
100 : {
101 : std::stop_source* source_;
102 1 : void operator()() const { source_->request_stop(); }
103 : };
104 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 : std::optional<stop_callback_t> parent_stop_callback_;
106 :
107 : // Parent resumption
108 : coro continuation_;
109 : executor_ref caller_ex_;
110 :
111 5 : when_all_state()
112 5 : : remaining_count_(task_count)
113 : {
114 5 : }
115 :
116 5 : ~when_all_state()
117 : {
118 14 : for(auto h : runner_handles_)
119 9 : if(h)
120 9 : h.destroy();
121 5 : }
122 :
123 : /** Capture an exception (first one wins).
124 : */
125 0 : void capture_exception(std::exception_ptr ep)
126 : {
127 0 : bool expected = false;
128 0 : if(has_exception_.compare_exchange_strong(
129 : expected, true, std::memory_order_relaxed))
130 0 : first_exception_ = ep;
131 0 : }
132 :
133 : /** Signal that a task has completed.
134 :
135 : The last child to complete triggers resumption of the parent.
136 : */
137 9 : coro signal_completion()
138 : {
139 9 : auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
140 9 : if(remaining == 1)
141 5 : return caller_ex_.dispatch(continuation_);
142 4 : return std::noop_coroutine();
143 : }
144 :
145 : };
146 :
147 : /** Wrapper coroutine that intercepts task completion.
148 :
149 : This runner awaits its assigned task and stores the result in
150 : the shared state, or captures the exception and requests stop.
151 : */
152 : template<typename T, typename... Ts>
153 : struct when_all_runner
154 : {
155 : struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
156 : {
157 : when_all_state<Ts...>* state_ = nullptr;
158 : executor_ref ex_;
159 : std::stop_token stop_token_;
160 :
161 9 : when_all_runner get_return_object()
162 : {
163 9 : return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
164 : }
165 :
166 9 : std::suspend_always initial_suspend() noexcept
167 : {
168 9 : return {};
169 : }
170 :
171 9 : auto final_suspend() noexcept
172 : {
173 : struct awaiter
174 : {
175 : promise_type* p_;
176 :
177 9 : bool await_ready() const noexcept
178 : {
179 9 : return false;
180 : }
181 :
182 9 : coro await_suspend(coro) noexcept
183 : {
184 : // Signal completion; last task resumes parent
185 9 : return p_->state_->signal_completion();
186 : }
187 :
188 0 : void await_resume() const noexcept
189 : {
190 0 : }
191 : };
192 9 : return awaiter{this};
193 : }
194 :
195 9 : void return_void()
196 : {
197 9 : }
198 :
199 0 : void unhandled_exception()
200 : {
201 0 : state_->capture_exception(std::current_exception());
202 : // Request stop for sibling tasks
203 0 : state_->stop_source_.request_stop();
204 0 : }
205 :
206 : template<class Awaitable>
207 : struct transform_awaiter
208 : {
209 : std::decay_t<Awaitable> a_;
210 : promise_type* p_;
211 :
212 9 : bool await_ready()
213 : {
214 9 : return a_.await_ready();
215 : }
216 :
217 9 : auto await_resume()
218 : {
219 9 : return a_.await_resume();
220 : }
221 :
222 : template<class Promise>
223 9 : auto await_suspend(std::coroutine_handle<Promise> h)
224 : {
225 9 : return a_.await_suspend(h, p_->ex_, p_->stop_token_);
226 : }
227 : };
228 :
229 : template<class Awaitable>
230 9 : auto await_transform(Awaitable&& a)
231 : {
232 : using A = std::decay_t<Awaitable>;
233 : if constexpr (IoAwaitable<A>)
234 : {
235 : return transform_awaiter<Awaitable>{
236 18 : std::forward<Awaitable>(a), this};
237 : }
238 : else
239 : {
240 : static_assert(sizeof(A) == 0, "requires IoAwaitable");
241 : }
242 9 : }
243 : };
244 :
245 : std::coroutine_handle<promise_type> h_;
246 :
247 9 : explicit when_all_runner(std::coroutine_handle<promise_type> h)
248 9 : : h_(h)
249 : {
250 9 : }
251 :
252 : // Enable move for all clang versions - some versions need it
253 : when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
254 :
255 : // Non-copyable
256 : when_all_runner(when_all_runner const&) = delete;
257 : when_all_runner& operator=(when_all_runner const&) = delete;
258 : when_all_runner& operator=(when_all_runner&&) = delete;
259 :
260 9 : auto release() noexcept
261 : {
262 9 : return std::exchange(h_, nullptr);
263 : }
264 : };
265 :
266 : /** Create a runner coroutine for a single task.
267 :
268 : Task is passed directly to ensure proper coroutine frame storage.
269 : */
270 : template<std::size_t Index, typename T, typename... Ts>
271 : when_all_runner<T, Ts...>
272 9 : make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
273 : {
274 : if constexpr (std::is_void_v<T>)
275 : {
276 : co_await std::move(inner);
277 : }
278 : else
279 : {
280 : std::get<Index>(state->results_).set(co_await std::move(inner));
281 : }
282 18 : }
283 :
284 : /** Internal awaitable that launches all runner coroutines and waits.
285 :
286 : This awaitable is used inside the when_all coroutine to handle
287 : the concurrent execution of child tasks.
288 : */
289 : template<typename... Ts>
290 : class when_all_launcher
291 : {
292 : std::tuple<task<Ts>...>* tasks_;
293 : when_all_state<Ts...>* state_;
294 :
295 : public:
296 5 : when_all_launcher(
297 : std::tuple<task<Ts>...>* tasks,
298 : when_all_state<Ts...>* state)
299 5 : : tasks_(tasks)
300 5 : , state_(state)
301 : {
302 5 : }
303 :
304 5 : bool await_ready() const noexcept
305 : {
306 5 : return sizeof...(Ts) == 0;
307 : }
308 :
309 5 : coro await_suspend(coro continuation, executor_ref caller_ex, std::stop_token parent_token = {})
310 : {
311 5 : state_->continuation_ = continuation;
312 5 : state_->caller_ex_ = caller_ex;
313 :
314 : // Forward parent's stop requests to children
315 5 : if(parent_token.stop_possible())
316 : {
317 4 : state_->parent_stop_callback_.emplace(
318 : parent_token,
319 2 : typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
320 :
321 2 : if(parent_token.stop_requested())
322 1 : state_->stop_source_.request_stop();
323 : }
324 :
325 : // CRITICAL: If the last task finishes synchronously then the parent
326 : // coroutine resumes, destroying its frame, and destroying this object
327 : // prior to the completion of await_suspend. Therefore, await_suspend
328 : // must ensure `this` cannot be referenced after calling `launch_one`
329 : // for the last time.
330 5 : auto token = state_->stop_source_.get_token();
331 2 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
332 5 : (..., launch_one<Is>(caller_ex, token));
333 5 : }(std::index_sequence_for<Ts...>{});
334 :
335 : // Let signal_completion() handle resumption
336 10 : return std::noop_coroutine();
337 5 : }
338 :
339 5 : void await_resume() const noexcept
340 : {
341 : // Results are extracted by the when_all coroutine from state
342 5 : }
343 :
344 : private:
345 : template<std::size_t I>
346 9 : void launch_one(executor_ref caller_ex, std::stop_token token)
347 : {
348 9 : auto runner = make_when_all_runner<I>(
349 9 : std::move(std::get<I>(*tasks_)), state_);
350 :
351 9 : auto h = runner.release();
352 9 : h.promise().state_ = state_;
353 9 : h.promise().ex_ = caller_ex;
354 9 : h.promise().stop_token_ = token;
355 :
356 9 : coro ch{h};
357 9 : state_->runner_handles_[I] = ch;
358 9 : state_->caller_ex_.dispatch(ch).resume();
359 9 : }
360 : };
361 :
362 : /** Compute the result type for when_all.
363 :
364 : Returns void when all tasks are void (P2300 aligned),
365 : otherwise returns a tuple with void types filtered out.
366 : */
367 : template<typename... Ts>
368 : using when_all_result_t = std::conditional_t<
369 : std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
370 : void,
371 : filter_void_tuple_t<Ts...>>;
372 :
373 : /** Helper to extract a single result, returning empty tuple for void.
374 : This is a separate function to work around a GCC-11 ICE that occurs
375 : when using nested immediately-invoked lambdas with pack expansion.
376 : */
377 : template<std::size_t I, typename... Ts>
378 9 : auto extract_single_result(when_all_state<Ts...>& state)
379 : {
380 : using T = std::tuple_element_t<I, std::tuple<Ts...>>;
381 : if constexpr (std::is_void_v<T>)
382 0 : return std::tuple<>();
383 : else
384 9 : return std::make_tuple(std::move(std::get<I>(state.results_)).get());
385 : }
386 :
387 : /** Extract results from state, filtering void types.
388 : */
389 : template<typename... Ts>
390 5 : auto extract_results(when_all_state<Ts...>& state)
391 : {
392 6 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
393 5 : return std::tuple_cat(extract_single_result<Is>(state)...);
394 10 : }(std::index_sequence_for<Ts...>{});
395 : }
396 :
397 : } // namespace detail
398 :
399 : /** Wait for all tasks to complete concurrently.
400 :
401 : @par Example
402 : @code
403 : task<void> example() {
404 : auto [a, b] = co_await when_all(
405 : fetch_int(), // task<int>
406 : fetch_string() // task<std::string>
407 : );
408 : }
409 : @endcode
410 :
411 : @param tasks The tasks to execute concurrently.
412 : @return A task yielding a tuple of results (void types filtered out).
413 :
414 : Key features:
415 : @li All child tasks are launched concurrently
416 : @li Results are collected in input order
417 : @li First error is captured; subsequent errors are discarded
418 : @li On error, stop is requested for all siblings
419 : @li Completes only after all children have completed
420 : @li Void tasks do not contribute to the result tuple
421 : @li Properly propagates frame allocators to all child coroutines
422 : */
423 : template<typename... Ts>
424 : [[nodiscard]] task<detail::when_all_result_t<Ts...>>
425 5 : when_all(task<Ts>... tasks)
426 : {
427 : using result_type = detail::when_all_result_t<Ts...>;
428 :
429 : // State is stored in the coroutine frame, using the frame allocator
430 : detail::when_all_state<Ts...> state;
431 :
432 : // Store tasks in the frame
433 : std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
434 :
435 : // Launch all tasks and wait for completion
436 : co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
437 :
438 : // Propagate first exception if any.
439 : // Safe without explicit acquire: capture_exception() is sequenced-before
440 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
441 : // last task's decrement that resumes this coroutine.
442 : if(state.first_exception_)
443 : std::rethrow_exception(state.first_exception_);
444 :
445 : // Extract and return results
446 : if constexpr (std::is_void_v<result_type>)
447 : co_return;
448 : else
449 : co_return detail::extract_results(state);
450 10 : }
451 :
452 : // For backwards compatibility and type queries, expose result type computation
453 : template<typename... Ts>
454 : using when_all_result_type = detail::when_all_result_t<Ts...>;
455 :
456 : } // namespace capy
457 : } // namespace boost
458 :
459 : #endif
|