-
Notifications
You must be signed in to change notification settings - Fork 236
Expand file tree
/
Copy pathalgorithm_base.cuh
More file actions
178 lines (146 loc) · 6.22 KB
/
algorithm_base.cuh
File metadata and controls
178 lines (146 loc) · 6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
/*
* Copyright (c) 2022 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../stdexec/execution.hpp"
#include "../../stdexec/__detail/__ranges.hpp"
#include <type_traits>
#include <cuda/std/type_traits>
#include <cub/device/device_reduce.cuh>
#include "common.cuh"
#include "../detail/throw_on_cuda_error.cuh"
namespace nvexec::STDEXEC_STREAM_DETAIL_NS::__algo_range_init_fun {
template <class Range, class InitT, class Fun>
using binary_invoke_result_t = ::cuda::std::decay_t<
::cuda::std::invoke_result_t<Fun, stdexec::ranges::range_reference_t<Range>, InitT>>;
template <class SenderId, class ReceiverId, class InitT, class Fun, class DerivedReceiver>
struct receiver_t {
struct __t : public stream_receiver_base {
using Sender = stdexec::__t<SenderId>;
using Receiver = stdexec::__t<ReceiverId>;
template <class... Range>
struct result_size_for {
using __t = __msize_t< sizeof(typename DerivedReceiver::template result_t<Range...>)>;
};
template <class... Sizes>
struct max_in_pack {
static constexpr ::std::size_t value = ::std::max({::std::size_t{}, __v<Sizes>...});
};
struct max_result_size {
template <class... _As>
using result_size_for_t = stdexec::__t<result_size_for<_As...>>;
static constexpr ::std::size_t value = //
__v< __gather_completions_for<
set_value_t,
Sender,
env_of_t<Receiver>,
__q<result_size_for_t>,
__q<max_in_pack>>>;
};
operation_state_base_t<ReceiverId>& op_state_;
STDEXEC_NO_UNIQUE_ADDRESS InitT init_;
STDEXEC_NO_UNIQUE_ADDRESS Fun fun_;
public:
using __id = receiver_t;
constexpr static ::std::size_t memory_allocation_size = max_result_size::value;
template <same_as<set_value_t> _Tag, class Range>
friend void tag_invoke(_Tag, __t&& self, Range&& range) noexcept {
DerivedReceiver::set_value_impl((__t&&) self, (Range&&) range);
}
template <__one_of<set_error_t, set_stopped_t> Tag, class... As>
friend void tag_invoke(Tag, __t&& self, As&&... as) noexcept {
self.op_state_.propagate_completion_signal(Tag(), (As&&) as...);
}
friend env_of_t<Receiver> tag_invoke(get_env_t, const __t& self) noexcept {
return get_env(self.op_state_.rcvr_);
}
__t(InitT init, Fun fun, operation_state_base_t<ReceiverId>& op_state)
: op_state_(op_state)
, init_((InitT&&) init)
, fun_((Fun&&) fun) {
}
};
};
// Sender base used by the algorithms in this namespace.
//
// Template parameters:
//   Tag           - algorithm tag type; `plscompile` passes a default-constructed
//                   instance of it as the first argument to the supplied callable.
//   SenderId      - stdexec id of the wrapped upstream sender.
//   InitT         - type of the value stored in `init_`.
//   Fun           - type of the callable stored in `fun_`.
//   DerivedSender - algorithm-specific type that must provide the templates
//                   `receiver_t<Receiver>` and `_set_value_t<Range>`.
template <class Tag, class SenderId, class InitT, class Fun, class DerivedSender>
struct sender_t {
  struct __t : stream_sender_base {
    using Sender = stdexec::__t<SenderId>;
    using __id = sender_t;
    using is_sender = void;

    // Receiver type the derived algorithm uses for a given downstream receiver.
    template <class Receiver>
    using receiver_t = typename DerivedSender::template receiver_t<Receiver>;

    // Value-completion transformation supplied by the derived algorithm.
    template <class Range>
    using _set_value_t = typename DerivedSender::template _set_value_t<Range>;

    Sender sndr_;
    // TODO: reconsider the name `InitT`; if other algorithms reuse this base in the
    // future, "init" may not describe what this member holds.
    STDEXEC_NO_UNIQUE_ADDRESS InitT init_;
    STDEXEC_NO_UNIQUE_ADDRESS Fun fun_;

    // Debugging aid: the Index-th type of a parameter pack.
    template <std::size_t Index, typename... Types>
    using nth_type_of = std::tuple_element_t<Index, std::tuple<Types...>>;

    // Debugging aid.
    // NOTE(review): `std::declval` is only valid in unevaluated operands, so this
    // body is ill-formed as soon as the function is instantiated. The intended
    // argument cannot be determined from this file, so the body is left as written.
    template <typename T>
    void print_type_name() const {
      stdexec::print(std::declval<stdexec::__detail::__name_of<T>>());
    }

    // Debugging aid: declared but intentionally left undefined.
    template <typename T>
    struct print_the_type;

    // TODO: this does not belong here. algorithm_base should rely on a `__data`
    // struct that each inheriting algorithm is responsible for providing; it was
    // added here only to get things to compile.
    template <class _InitT, class _Fun>
    struct __data {
      _InitT __initT_;
      STDEXEC_NO_UNIQUE_ADDRESS _Fun __fun_;
      static constexpr auto __mbrs_ = __mliterals<&__data::__initT_, &__data::__fun_>();
    };
    template <class _InitT, class _Fun>
    __data(_InitT, _Fun) -> __data<_InitT, _Fun>;

    // Essentially the "apply" function that sender_apply is looking for.
    //
    //   s - object exposing `sndr_`, `init_` and `fun_` (taken by value).
    //   f - callable invoked as f(Tag, data, child), where `data` is a `__data`
    //       built from copies of s.init_ / s.fun_ and `child` is a copy of s.sndr_.
    //
    // Returns whatever `f` returns (decayed by `auto`).
    template <class S, class Fn>
    auto plscompile(S s, Fn f) {
      auto inside = s.sndr_;
      auto data = __data(s.init_, s.fun_);
      // The tag is default-constructed here: std::declval<Tag>() must not be used in
      // an evaluated expression (doing so makes the program ill-formed).
      return f(Tag(), data, inside);
    }

    // Completion signatures of this sender: those derived from the wrapped sender
    // via `_set_value_t`, plus `set_error_t(cudaError_t)`.
    // The inner template is qualified with `stdexec::` because this member alias has
    // the same name; an unqualified use would change meaning once the alias is
    // declared in class scope.
    template <class Self, class Env>
    using completion_signatures = //
      __try_make_completion_signatures<
        __copy_cvref_t<Self, Sender>,
        Env,
        stdexec::completion_signatures<set_error_t(cudaError_t)>,
        __q<_set_value_t>>;

    // connect: builds the stream operation state for the wrapped sender and `rcvr`.
    // The factory lambda creates the algorithm receiver from copies of `init_` and
    // `fun_` and the operation-state base it is given.
    // NOTE(review): the lambda captures `self` by reference, so it presumably must
    // be invoked before this function returns - confirm against `stream_op_state`.
    template <__decays_to<__t> Self, receiver Receiver>
      requires receiver_of<Receiver, completion_signatures<Self, env_of_t<Receiver>>>
    friend auto tag_invoke(connect_t, Self&& self, Receiver rcvr)
      -> stream_op_state_t< __copy_cvref_t<Self, Sender>, receiver_t<Receiver>, Receiver> {
      return stream_op_state<__copy_cvref_t<Self, Sender>>(
        ((Self&&) self).sndr_,
        (Receiver&&) rcvr,
        [&](operation_state_base_t<stdexec::__id<Receiver>>& stream_provider)
          -> receiver_t<Receiver> {
          return receiver_t<Receiver>(self.init_, self.fun_, stream_provider);
        });
    }

    // get_completion_signatures: purely a type-level query; returns an empty object.
    template <__decays_to<__t> Self, class Env>
    friend auto tag_invoke(get_completion_signatures_t, Self&&, Env&&)
      -> completion_signatures<Self, Env> {
      return {};
    }

    // get_env: forwards to the environment of the wrapped sender.
    friend auto tag_invoke(get_env_t, const __t& self) noexcept -> env_of_t<const Sender&> {
      return get_env(self.sndr_);
    }
  };
};
}