xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm.hpp (revision 3c6c2bd607ccbc6dbafa3d0fbe34551e4abda79c)
1 #pragma once
2 
3 #include <petsc/private/veccupmimpl.h>
4 #include <petsc/private/cpp/utility.hpp> // util::index_sequence
5 
6 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> // grid_stride_1D()
7 #include <../src/vec/vec/impls/dvecimpl.h>                  // Vec_Seq
8 
9 namespace Petsc
10 {
11 
12 namespace vec
13 {
14 
15 namespace cupm
16 {
17 
18 namespace impl
19 {
20 
21 // ==========================================================================================
22 // VecSeq_CUPM
23 // ==========================================================================================
24 
25 template <device::cupm::DeviceType T>
26 class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
27 public:
28   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
29 
30 private:
31   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
32   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
33   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
34 
35   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
36   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
37   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
38   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
39 
40   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
41 
42   // common core for min and max
43   template <typename TupleFuncT, typename UnaryFuncT>
44   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
45   // common core for pointwise binary and pointwise unary thrust functions
46   template <typename BinaryFuncT>
47   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
48   template <typename BinaryFuncT>
49   static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
50   template <typename UnaryFuncT>
51   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
52   // mdot dispatchers
53   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55   template <std::size_t... Idx>
56   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
57   template <int>
58   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
59   template <std::size_t... Idx>
60   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
61   template <int>
62   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
63   // common core for the various create routines
64   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
65 
66 public:
67   // callable directly via a bespoke function
68   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
69   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
70 
71   static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
72   static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
73 
74   // callable indirectly via function pointers
75   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
76   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
77   static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
78   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
79   static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
80   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
81   static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
82   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
83   static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
84   static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
85   static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
86   static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
87   static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
88   static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
89   static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
90   static PetscErrorCode Reciprocal(Vec) noexcept;
91   static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
92   static PetscErrorCode Abs(Vec) noexcept;
93   static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
94   static PetscErrorCode SqrtAbs(Vec) noexcept;
95   static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
96   static PetscErrorCode Exp(Vec) noexcept;
97   static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
98   static PetscErrorCode Log(Vec) noexcept;
99   static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
100   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
101   static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
102   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
103   static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
104   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
105   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
106   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
107   static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
108   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
109   static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
110   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
111   static PetscErrorCode Copy(Vec, Vec) noexcept;
112   static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
113   static PetscErrorCode Swap(Vec, Vec) noexcept;
114   static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
115   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
116   static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
117   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
118   static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
119   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
120   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
121   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
122   static PetscErrorCode Conjugate(Vec) noexcept;
123   static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
124   template <PetscMemoryAccessMode>
125   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
126   template <PetscMemoryAccessMode>
127   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
128   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
129   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
130   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
131   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
132   static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
133   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
134   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
135   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
136   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
137 };
138 
139 namespace kernels
140 {
141 
142 template <typename F>
add_coo_values_impl(const PetscScalar * PETSC_RESTRICT vv,PetscCount n,const PetscCount * PETSC_RESTRICT jmap,const PetscCount * PETSC_RESTRICT perm,InsertMode imode,PetscScalar * PETSC_RESTRICT xv,F && xvindex)143 PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144 {
145   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146     const auto  end = jmap[i + 1];
147     const auto  idx = xvindex(i);
148     PetscScalar sum = 0.0;
149 
150     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
151 
152     if (imode == INSERT_VALUES) {
153       xv[idx] = sum;
154     } else {
155       xv[idx] += sum;
156     }
157   });
158   return;
159 }
160 
161 namespace
162 {
163 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
add_coo_values(const PetscScalar * PETSC_RESTRICT v,PetscCount n,const PetscCount * PETSC_RESTRICT jmap1,const PetscCount * PETSC_RESTRICT perm1,InsertMode imode,PetscScalar * PETSC_RESTRICT xv)164 PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165 {
166   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167   return;
168 }
169 PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
170 } // namespace
171 
172 } // namespace kernels
173 
174 } // namespace impl
175 
176 // ==========================================================================================
177 // VecSeq_CUPM - Implementations
178 // ==========================================================================================
179 
180 template <device::cupm::DeviceType T>
VecCreateSeqCUPMAsync(MPI_Comm comm,PetscInt n,Vec * v)181 inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
182 {
183   PetscFunctionBegin;
184   PetscAssertPointer(v, 4);
185   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
186   PetscFunctionReturn(PETSC_SUCCESS);
187 }
188 
189 template <device::cupm::DeviceType T>
VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm,PetscInt bs,PetscInt n,const PetscScalar cpuarray[],const PetscScalar gpuarray[],Vec * v)190 inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
191 {
192   PetscFunctionBegin;
193   if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
194   PetscAssertPointer(v, 6);
195   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
196   PetscFunctionReturn(PETSC_SUCCESS);
197 }
198 
199 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
VecCUPMGetArrayAsync_Private(Vec v,PetscScalar ** a,PetscDeviceContext dctx)200 inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
201 {
202   PetscFunctionBegin;
203   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
204   PetscAssertPointer(a, 2);
205   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
206   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
207   PetscFunctionReturn(PETSC_SUCCESS);
208 }
209 
210 template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
VecCUPMRestoreArrayAsync_Private(Vec v,PetscScalar ** a,PetscDeviceContext dctx)211 inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
212 {
213   PetscFunctionBegin;
214   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
215   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
216   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
217   PetscFunctionReturn(PETSC_SUCCESS);
218 }
219 
220 template <device::cupm::DeviceType T>
VecCUPMGetArrayAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)221 inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
222 {
223   PetscFunctionBegin;
224   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
225   PetscFunctionReturn(PETSC_SUCCESS);
226 }
227 
228 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)229 inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
230 {
231   PetscFunctionBegin;
232   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
233   PetscFunctionReturn(PETSC_SUCCESS);
234 }
235 
236 template <device::cupm::DeviceType T>
VecCUPMGetArrayReadAsync(Vec v,const PetscScalar ** a,PetscDeviceContext dctx=nullptr)237 inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
238 {
239   PetscFunctionBegin;
240   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
241   PetscFunctionReturn(PETSC_SUCCESS);
242 }
243 
244 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayReadAsync(Vec v,const PetscScalar ** a,PetscDeviceContext dctx=nullptr)245 inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
246 {
247   PetscFunctionBegin;
248   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
252 template <device::cupm::DeviceType T>
VecCUPMGetArrayWriteAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)253 inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
254 {
255   PetscFunctionBegin;
256   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
257   PetscFunctionReturn(PETSC_SUCCESS);
258 }
259 
260 template <device::cupm::DeviceType T>
VecCUPMRestoreArrayWriteAsync(Vec v,PetscScalar ** a,PetscDeviceContext dctx=nullptr)261 inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
262 {
263   PetscFunctionBegin;
264   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
265   PetscFunctionReturn(PETSC_SUCCESS);
266 }
267 
268 template <device::cupm::DeviceType T>
VecCUPMPlaceArrayAsync(Vec vin,const PetscScalar a[])269 inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
270 {
271   PetscFunctionBegin;
272   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
273   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
274   PetscFunctionReturn(PETSC_SUCCESS);
275 }
276 
277 template <device::cupm::DeviceType T>
VecCUPMReplaceArrayAsync(Vec vin,const PetscScalar a[])278 inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
279 {
280   PetscFunctionBegin;
281   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
282   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
283   PetscFunctionReturn(PETSC_SUCCESS);
284 }
285 
286 template <device::cupm::DeviceType T>
VecCUPMResetArrayAsync(Vec vin)287 inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
288 {
289   PetscFunctionBegin;
290   PetscValidHeaderSpecific(vin, VEC_CLASSID, 1);
291   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
295 } // namespace cupm
296 
297 } // namespace vec
298 
299 } // namespace Petsc
300 
301 #if PetscDefined(HAVE_CUDA)
302 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
303 #endif
304 
305 #if PetscDefined(HAVE_HIP)
306 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
307 #endif
308