1 #pragma once
2
3 #include <petsc/private/veccupmimpl.h>
4 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
5
6 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
7 #include <../src/vec/vec/impls/dvecimpl.h> // Vec_Seq
8
9 namespace Petsc
10 {
11
12 namespace vec
13 {
14
15 namespace cupm
16 {
17
18 namespace impl
19 {
20
21 // ==========================================================================================
22 // VecSeq_CUPM
23 // ==========================================================================================
24
25 template <device::cupm::DeviceType T>
26 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
27 public:
28 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
29
30 private:
31 PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
32 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
33 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
34
35 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
36 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
37 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
38 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
39
40 static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
41
42 // common core for min and max
43 template <typename TupleFuncT, typename UnaryFuncT>
44 static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
45 // common core for pointwise binary and pointwise unary thrust functions
46 template <typename BinaryFuncT>
47 static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
48 template <typename BinaryFuncT>
49 static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
50 template <typename UnaryFuncT>
51 static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
52 // mdot dispatchers
53 static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54 static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55 template <std::size_t... Idx>
56 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
57 template <int>
58 static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
59 template <std::size_t... Idx>
60 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
61 template <int>
62 static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
63 // common core for the various create routines
64 static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
65
66 public:
67 // callable directly via a bespoke function
68 static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
69 static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
70
71 static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
72 static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
73
74 // callable indirectly via function pointers
75 static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
76 static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
77 static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
78 static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
79 static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
80 static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
81 static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
82 static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
83 static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
84 static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
85 static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
86 static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
87 static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
88 static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
89 static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
90 static PetscErrorCode Reciprocal(Vec) noexcept;
91 static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
92 static PetscErrorCode Abs(Vec) noexcept;
93 static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
94 static PetscErrorCode SqrtAbs(Vec) noexcept;
95 static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
96 static PetscErrorCode Exp(Vec) noexcept;
97 static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
98 static PetscErrorCode Log(Vec) noexcept;
99 static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
100 static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
101 static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
102 static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
103 static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
104 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
105 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
106 static PetscErrorCode Set(Vec, PetscScalar) noexcept;
107 static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
108 static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
109 static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
110 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
111 static PetscErrorCode Copy(Vec, Vec) noexcept;
112 static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
113 static PetscErrorCode Swap(Vec, Vec) noexcept;
114 static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
115 static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
116 static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
117 static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
118 static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
119 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
120 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
121 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
122 static PetscErrorCode Conjugate(Vec) noexcept;
123 static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
124 template <PetscMemoryAccessMode>
125 static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
126 template <PetscMemoryAccessMode>
127 static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
128 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
129 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
130 static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
131 static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
132 static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
133 static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
134 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
135 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
136 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
137 };
138
139 namespace kernels
140 {
141
142 template <typename F>
add_coo_values_impl(const PetscScalar * PETSC_RESTRICT vv,PetscCount n,const PetscCount * PETSC_RESTRICT jmap,const PetscCount * PETSC_RESTRICT perm,InsertMode imode,PetscScalar * PETSC_RESTRICT xv,F && xvindex)143 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144 {
145 ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146 const auto end = jmap[i + 1];
147 const auto idx = xvindex(i);
148 PetscScalar sum = 0.0;
149
150 for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
151
152 if (imode == INSERT_VALUES) {
153 xv[idx] = sum;
154 } else {
155 xv[idx] += sum;
156 }
157 });
158 return;
159 }
160
161 namespace
162 {
163 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
add_coo_values(const PetscScalar * PETSC_RESTRICT v,PetscCount n,const PetscCount * PETSC_RESTRICT jmap1,const PetscCount * PETSC_RESTRICT perm1,InsertMode imode,PetscScalar * PETSC_RESTRICT xv)164 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165 {
166 add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167 return;
168 }
169 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
170 } // namespace
171
172 } // namespace kernels
173
174 } // namespace impl
175
176 // ==========================================================================================
177 // VecSeq_CUPM - Implementations
178 // ==========================================================================================
179
180 template <device::cupm::DeviceType T>
VecCreateSeqCUPMAsync(MPI_Comm comm,PetscInt n,Vec * v)181 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
182 {
183 PetscFunctionBegin;
184 PetscAssertPointer(v, 4);
185 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
186 PetscFunctionReturn(PETSC_SUCCESS);
187 }
188
189 template <device::cupm::DeviceType T>
VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm,PetscInt bs,PetscInt n,const PetscScalar cpuarray[],const PetscScalar gpuarray[],Vec * v)190 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
191 {
192 PetscFunctionBegin;
193 if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
194 PetscAssertPointer(v, 6);
195 PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
196 PetscFunctionReturn(PETSC_SUCCESS);
197 }
198
199 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
VecCUPMGetArrayAsync_Private(Vec v,PetscScalar ** a,PetscDeviceContext dctx)200 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
201 {
202 PetscFunctionBegin;
203 PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
204 PetscAssertPointer(a, 2);
205 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
206 PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
207 PetscFunctionReturn(PETSC_SUCCESS);
208 }
209
210 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
VecCUPMRestoreArrayAsync_Private(Vec v,PetscScalar ** a,PetscDeviceContext dctx)211 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
212 {
213 PetscFunctionBegin;
214 PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
215 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
216 PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
217 PetscFunctionReturn(PETSC_SUCCESS);
218 }
219
220 template <device::cupm::DeviceType T>
VecCUPMGetArrayAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)221 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
222 {
223 PetscFunctionBegin;
224 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
225 PetscFunctionReturn(PETSC_SUCCESS);
226 }
227
228 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)229 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
230 {
231 PetscFunctionBegin;
232 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
233 PetscFunctionReturn(PETSC_SUCCESS);
234 }
235
236 template <device::cupm::DeviceType T>
VecCUPMGetArrayReadAsync(Vec v,const PetscScalar ** a,PetscDeviceContext dctx=nullptr)237 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
238 {
239 PetscFunctionBegin;
240 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
241 PetscFunctionReturn(PETSC_SUCCESS);
242 }
243
244 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayReadAsync(Vec v,const PetscScalar ** a,PetscDeviceContext dctx=nullptr)245 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
246 {
247 PetscFunctionBegin;
248 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
249 PetscFunctionReturn(PETSC_SUCCESS);
250 }
251
252 template <device::cupm::DeviceType T>
VecCUPMGetArrayWriteAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)253 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
254 {
255 PetscFunctionBegin;
256 PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
257 PetscFunctionReturn(PETSC_SUCCESS);
258 }
259
260 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayWriteAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)261 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
262 {
263 PetscFunctionBegin;
264 PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
265 PetscFunctionReturn(PETSC_SUCCESS);
266 }
267
268 template <device::cupm::DeviceType T>
VecCUPMPlaceArrayAsync(Vec vin,const PetscScalar a[])269 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
270 {
271 PetscFunctionBegin;
272 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
273 PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
274 PetscFunctionReturn(PETSC_SUCCESS);
275 }
276
277 template <device::cupm::DeviceType T>
VecCUPMReplaceArrayAsync(Vec vin,const PetscScalar a[])278 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
279 {
280 PetscFunctionBegin;
281 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
282 PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
283 PetscFunctionReturn(PETSC_SUCCESS);
284 }
285
286 template <device::cupm::DeviceType T>
VecCUPMResetArrayAsync(Vec vin)287 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
288 {
289 PetscFunctionBegin;
290 PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
291 PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
292 PetscFunctionReturn(PETSC_SUCCESS);
293 }
294
295 } // namespace cupm
296
297 } // namespace vec
298
299 } // namespace Petsc
300
301 #if PetscDefined(HAVE_CUDA)
302 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
303 #endif
304
305 #if PetscDefined(HAVE_HIP)
306 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
307 #endif
308