xref: /petsc/src/vec/vec/impls/seq/cupm/vecseqcupm_impl.hpp (revision ede9db9363e1fdaaa09befd664c8164883ccce80)
1 #pragma once
2 
3 #include "vecseqcupm.hpp"
4 
5 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
6 
7 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
8 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
9 
10 #if PetscDefined(USE_COMPLEX)
11   #include <thrust/transform_reduce.h>
12 #endif
13 #include <thrust/transform.h>
14 #include <thrust/reduce.h>
15 #include <thrust/functional.h>
16 #include <thrust/tuple.h>
17 #include <thrust/device_ptr.h>
18 #include <thrust/iterator/zip_iterator.h>
19 #include <thrust/iterator/counting_iterator.h>
20 #include <thrust/iterator/constant_iterator.h>
21 #include <thrust/inner_product.h>
22 
23 namespace Petsc
24 {
25 
26 namespace vec
27 {
28 
29 namespace cupm
30 {
31 
32 namespace impl
33 {
34 
35 // ==========================================================================================
36 // VecSeq_CUPM - Private API
37 // ==========================================================================================
38 
39 template <device::cupm::DeviceType T>
VecIMPLCast_(Vec v)40 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
41 {
42   return static_cast<Vec_Seq *>(v->data);
43 }
44 
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  // the sequential CUPM (device) vector type name for this backend
  return VECSEQCUPM();
}
50 
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  // the plain host sequential vector type backing this implementation
  return VECSEQ;
}
56 
// Detach every composed VecAsync* method from v. Must remove exactly the set of
// functions composed in InitializeAsyncFunctions(); keep the two lists in sync.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ClearAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
85 
// Compose the VecAsync* entry points on v so the async API can dispatch to the
// device implementations below. Mirrored (entry for entry) by ClearAsyncFunctions().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::InitializeAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), VecSeq_CUPM<T>::AbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), VecSeq_CUPM<T>::AXPBYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), VecSeq_CUPM<T>::AXPBYPCZAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), VecSeq_CUPM<T>::AXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), VecSeq_CUPM<T>::AYPXAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), VecSeq_CUPM<T>::ConjugateAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), VecSeq_CUPM<T>::CopyAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), VecSeq_CUPM<T>::ExpAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), VecSeq_CUPM<T>::LogAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), VecSeq_CUPM<T>::MAXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), VecSeq_CUPM<T>::PointwiseDivideAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), VecSeq_CUPM<T>::PointwiseMaxAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), VecSeq_CUPM<T>::PointwiseMaxAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), VecSeq_CUPM<T>::PointwiseMinAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), VecSeq_CUPM<T>::PointwiseMultAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), VecSeq_CUPM<T>::ReciprocalAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), VecSeq_CUPM<T>::ScaleAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), VecSeq_CUPM<T>::SetAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), VecSeq_CUPM<T>::ShiftAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), VecSeq_CUPM<T>::SqrtAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), VecSeq_CUPM<T>::SwapAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), VecSeq_CUPM<T>::WAXPYAsync));
  PetscFunctionReturn(PETSC_SUCCESS);
}
114 
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  PetscFunctionBegin;
  // detach the composed async methods before the host destructor tears down v
  PetscCall(ClearAsyncFunctions(v));
  PetscCall(VecDestroy_Seq(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
123 
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  // delegate directly to the host (Vec_Seq) implementation
  return VecResetArray_Seq(v);
}
129 
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  // delegate directly to the host (Vec_Seq) implementation
  return VecPlaceArray_Seq(v, a);
}
135 
136 template <device::cupm::DeviceType T>
VecCreate_IMPL_Private_(Vec v,PetscBool * alloc_missing,PetscInt,PetscScalar * host_array)137 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
138 {
139   PetscMPIInt size;
140 
141   PetscFunctionBegin;
142   if (alloc_missing) *alloc_missing = PETSC_FALSE;
143   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
144   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
145   PetscCall(VecCreate_Seq_Private(v, host_array));
146   PetscCall(InitializeAsyncFunctions(v));
147   PetscFunctionReturn(PETSC_SUCCESS);
148 }
149 
// For functions with an early return based on vec size we still need to artificially
// bump the object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//   rank 0: [0],
//   rank 1: [<empty>]
// }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
166 template <device::cupm::DeviceType T>
MaybeIncrementEmptyLocalVec(Vec v)167 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
168 {
169   PetscFunctionBegin;
170   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
171   PetscFunctionReturn(PETSC_SUCCESS);
172 }
173 
// Internal constructor body: build the host Vec_Seq (adopting host_array if given)
// and then attach the CUPM device state (adopting device_array if given).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
182 
// Apply zout[i] = binary(xin[i], yin[i]) elementwise on the device via thrust.
// xin and yin are read, zout is written. Empty local vectors get the artificial
// state bump instead (see MaybeIncrementEmptyLocalVec()).
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    cupmStream_t stream;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    // one flop per entry
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
214 
215 template <device::cupm::DeviceType T>
216 template <typename BinaryFuncT>
PointwiseBinaryDispatch_(PetscErrorCode (* VecSeqFunction)(Vec,Vec,Vec),BinaryFuncT && binary,Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)217 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinaryDispatch_(PetscErrorCode (*VecSeqFunction)(Vec, Vec, Vec), BinaryFuncT &&binary, Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
218 {
219   PetscFunctionBegin;
220   if (xin->boundtocpu || yin->boundtocpu) {
221     PetscCall((*VecSeqFunction)(wout, xin, yin));
222   } else {
223     // note order of arguments! xin and yin are read, wout is written!
224     PetscCall(PointwiseBinary_(std::forward<BinaryFuncT>(binary), xin, yin, wout, dctx));
225   }
226   PetscFunctionReturn(PETSC_SUCCESS);
227 }
228 
// Apply a unary functor elementwise on the device: in place on xinout when yin is
// null or equal to xinout, otherwise reading xinout and writing yin (note: despite
// the names, yin is the OUTPUT in the out-of-place case).
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin, PetscDeviceContext dctx) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    cupmStream_t stream;
    // helper running thrust::transform over the raw device pointers; the
    // parameters deliberately shadow the outer Vec names
    const auto   apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    // bump the state of whichever vector would have been written
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
274 
275 // ==========================================================================================
276 // VecSeq_CUPM - Public API - Constructors
277 // ==========================================================================================
278 
// VecCreateSeqCUPM()
// Create a sequential CUPM vector of local (== global) size n with block size bs.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
287 
// VecCreateSeqCUPMWithArrays()
// Create a sequential CUPM vector wrapping caller-provided host and/or device
// buffers (either may be null, in which case it is allocated on demand).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
302 
// v->ops->duplicate
// Create a new vector y with the same layout and type as v (values not copied).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
314 
315 // ==========================================================================================
316 // VecSeq_CUPM - Public API - Utility
317 // ==========================================================================================
318 
// v->ops->bindtocpu
// Toggle whether operations run on the host (usehost == PETSC_TRUE) or the
// device; each VecSetOp_CUPM() line picks between the host and device callback.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device implementation, always use the host version
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}
343 
344 // ==========================================================================================
345 // VecSeq_CUPM - Public API - Mutators
346 // ==========================================================================================
347 
// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w present v's local data: for native v and VECSEQCUPM w, v's pointers are
// stolen wholesale; otherwise w borrows v's host array via VecGetArray[Read]().
// Undone by RestoreLocalVector().
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // w is about to reference v's storage, so release anything w currently owns:
    // first the host-side buffer...
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // match the allocator (pinned vs regular) the buffer was created with;
        // util::exchange also clears w->pinned_memory
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    // ...then the device-side buffer and its container struct
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: transfer v's pointers directly (returned in RestoreLocalVector())
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // slow path: borrow v's host array through the public accessors
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      // make sure w has device storage so subsequent device ops can run
      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
403 
// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Inverse of GetLocalVector(): hand the borrowed/stolen data back to v and leave
// w with no references to v's storage.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    // return the borrowed host array through the matching accessor
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    // drop the device storage allocated for w in GetLocalVector()
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
439 
440 // ==========================================================================================
441 // VecSeq_CUPM - Public API - Compute Methods
442 // ==========================================================================================
443 
// VecAYPXAsync_Private
// y <- x + alpha*y. Falls back to the host kernel when x is not a CUPM vector.
// alpha == 0 degenerates to a device-to-device copy; alpha == 1 to a single axpy;
// otherwise scal followed by axpy.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPXAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
{
  const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
  PetscBool  xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!xiscupm) {
    PetscCall(VecAYPX_Seq(yin, alpha, xin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (alpha == PetscScalar(0.0)) {
    cupmStream_t stream;

    // y <- x
    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      // scope the array accessors so they are restored before the flop logging
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y <- x + y
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        // y <- alpha*y; y <- x + y
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
491 
// v->ops->aypx
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AYPXAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
500 
501 // VecAXPYAsync_Private
502 template <device::cupm::DeviceType T>
AXPYAsync(Vec yin,PetscScalar alpha,Vec xin,PetscDeviceContext dctx)503 inline PetscErrorCode VecSeq_CUPM<T>::AXPYAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
504 {
505   PetscBool xiscupm;
506 
507   PetscFunctionBegin;
508   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
509   if (xiscupm) {
510     const auto       n = static_cast<cupmBlasInt_t>(yin->map->n);
511     cupmBlasHandle_t cupmBlasHandle;
512 
513     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
514     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
515     PetscCall(PetscLogGpuTimeBegin());
516     PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
517     PetscCall(PetscLogGpuTimeEnd());
518     PetscCall(PetscLogGpuFlops(2 * n));
519     if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
520   } else {
521     PetscCall(VecAXPY_Seq(yin, alpha, xin));
522   }
523   PetscFunctionReturn(PETSC_SUCCESS);
524 }
525 
// v->ops->axpy
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPYAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
534 
535 namespace detail
536 {
537 
538 struct divides {
operator ()Petsc::vec::cupm::impl::detail::divides539   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return rhs == PetscScalar{0.0} ? rhs : lhs / rhs; }
540 };
541 
542 } // namespace detail
543 
// VecPointwiseDivideAsync_Private
// w[i] = x[i] / y[i] (0 where y[i] == 0), dispatched host/device as appropriate.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivideAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseDivide_Seq, detail::divides{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
552 
// v->ops->pointwisedivide
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseDivideAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
561 
// VecPointwiseMultAsync_Private
// w[i] = x[i] * y[i], dispatched host/device as appropriate.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMultAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMult_Seq, thrust::multiplies<PetscScalar>{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
570 
// v->ops->pointwisemult
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMultAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
579 
580 namespace detail
581 {
582 
583 struct MaximumRealPart {
operator ()Petsc::vec::cupm::impl::detail::MaximumRealPart584   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
585 };
586 
587 } // namespace detail
588 
// VecPointwiseMaxAsync_Private
// w[i] = max(Re(x[i]), Re(y[i])), dispatched host/device as appropriate.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMax_Seq, detail::MaximumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
597 
// v->ops->pointwisemax
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMax(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
606 
607 namespace detail
608 {
609 
610 struct MaximumAbsoluteValue {
operator ()Petsc::vec::cupm::impl::detail::MaximumAbsoluteValue611   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscAbsScalar(lhs), PetscAbsScalar(rhs)); }
612 };
613 
614 } // namespace detail
615 
// VecPointwiseMaxAbsAsync_Private
// w[i] = max(|x[i]|, |y[i]|), dispatched host/device as appropriate.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbsAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMaxAbs_Seq, detail::MaximumAbsoluteValue{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
624 
// v->ops->pointwisemaxabs
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbs(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAbsAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
633 
634 namespace detail
635 {
636 
637 struct MinimumRealPart {
operator ()Petsc::vec::cupm::impl::detail::MinimumRealPart638   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::minimum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
639 };
640 
641 } // namespace detail
642 
// VecPointwiseMinAsync_Private
// w[i] = min(Re(x[i]), Re(y[i])), dispatched host/device as appropriate.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMinAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMin_Seq, detail::MinimumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
651 
// v->ops->pointwisemin
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMin(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMinAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
660 
namespace detail
{

// Reciprocal functor following the VecReciprocal() convention: zero entries stay
// zero instead of producing inf.
struct Reciprocal {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept
  {
    // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
    // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
    // everything in PetscScalar...
    return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
  }
};

} // namespace detail
675 
// VecReciprocalAsync_Private
// x[i] = 1/x[i] in place (zero entries remain zero).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ReciprocalAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Reciprocal{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
684 
// v->ops->reciprocal
// Synchronous wrapper: a null device context selects the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ReciprocalAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
693 
namespace detail
{

// Unary host/device functor: s -> |s| (complex modulus when PetscScalar is complex).
struct AbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscAbsScalar(s); }
};

} // namespace detail
702 
// VecAbsAsync_Private
//
// In-place xin[i] = |xin[i]| (nullptr output => in-place PointwiseUnary_).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::AbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
711 
// v->ops->abs
//
// Synchronous wrapper: runs AbsAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Abs(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
720 
namespace detail
{

// Unary host/device functor: s -> sqrt(|s|). Note the result is real-valued
// (PetscSqrtReal of the modulus), then converted back to PetscScalar.
struct SquareRootAbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscSqrtReal(PetscAbsScalar(s)); }
};

} // namespace detail
729 
// VecSqrtAbsAsync_Private
//
// In-place xin[i] = sqrt(|xin[i]|) (nullptr output => in-place PointwiseUnary_).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::SquareRootAbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
738 
// v->ops->sqrt
//
// Synchronous wrapper: runs SqrtAbsAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbs(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(SqrtAbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
747 
namespace detail
{

// Unary host/device functor: s -> exp(s).
struct Exponent {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscExpScalar(s); }
};

} // namespace detail
756 
// VecExpAsync_Private
//
// In-place xin[i] = exp(xin[i]) (nullptr output => in-place PointwiseUnary_).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ExpAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Exponent{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
765 
// v->ops->exp
//
// Synchronous wrapper: runs ExpAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Exp(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ExpAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
774 
namespace detail
{

// Unary host/device functor: s -> log(s).
struct Logarithm {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscLogScalar(s); }
};

} // namespace detail
783 
// VecLogAsync_Private
//
// In-place xin[i] = log(xin[i]) (nullptr output => in-place PointwiseUnary_).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::LogAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Logarithm{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
792 
// v->ops->log
//
// Synchronous wrapper: runs LogAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Log(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(LogAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
801 
802 // v->ops->waxpy
803 template <device::cupm::DeviceType T>
WAXPYAsync(Vec win,PetscScalar alpha,Vec xin,Vec yin,PetscDeviceContext dctx)804 inline PetscErrorCode VecSeq_CUPM<T>::WAXPYAsync(Vec win, PetscScalar alpha, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
805 {
806   PetscBool xiscupm, yiscupm;
807 
808   PetscFunctionBegin;
809   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
810   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
811   if (!xiscupm || !yiscupm) {
812     PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
813     PetscFunctionReturn(PETSC_SUCCESS);
814   }
815   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
816   if (alpha == PetscScalar(0.0)) {
817     PetscCall(CopyAsync(yin, win, dctx));
818   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
819     cupmBlasHandle_t cupmBlasHandle;
820     cupmStream_t     stream;
821     PetscBool        xiscupm, yiscupm;
822 
823     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
824     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
825     if (!xiscupm || !yiscupm) {
826       PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
827       PetscFunctionReturn(PETSC_SUCCESS);
828     }
829     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle, NULL, &stream));
830     {
831       const auto wptr = DeviceArrayWrite(dctx, win);
832 
833       PetscCall(PetscLogGpuTimeBegin());
834       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
835       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
836       PetscCall(PetscLogGpuTimeEnd());
837     }
838     PetscCall(PetscLogGpuFlops(2 * n));
839     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
840   }
841   PetscFunctionReturn(PETSC_SUCCESS);
842 }
843 
// v->ops->waxpy
//
// Synchronous wrapper: runs WAXPYAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(WAXPYAsync(win, alpha, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
852 
namespace kernels
{

// Fused MAXPY device kernel: x[i] += sum_j a[j] * y_j[i] for a compile-time pack
// of N = sizeof...(Args) input vectors, making a single pass over x. The N
// coefficients are staged in shared memory so every grid-stride iteration reads
// them from fast storage instead of global memory.
//
// NOTE(review): the shared-memory load assumes blockDim.x >= N; the dispatchers
// below launch with N <= 8 -- confirm if that changes.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    // accumulate starting from x[i] -- floating-point addition is not
    // associative, so this ordering is deliberately kept
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels
889 
namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
// typename repeat_type<MyType, IdxParamPack>...
// expands to
// MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail
904 
// Launch MAXPY_kernel for exactly sizeof...(Idx) input vectors. The index
// sequence is expanded twice: once to build the kernel's template argument list
// (sizeof...(Idx) copies of const PetscScalar*) and once to pass the device
// pointer of each yin[Idx].
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
921 
// Dispatch the next batch of N vectors starting at offset yidx, then advance
// yidx by N so the caller's loop can consume the remainder.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
931 
// VecMAXPYAsync_Private
//
// x += sum_i alpha[i] * y[i] over nv vectors. Falls back to VecMAXPY_Seq if any
// y[i] is not a CUPM vector. The alpha array is copied to device memory once,
// then the vectors are processed in batches of up to 8 by the fused MAXPY kernel
// (the switch below picks the largest batch size <= remaining count).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPYAsync(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin, PetscDeviceContext dctx) noexcept
{
  const auto   n = xin->map->n;
  cupmStream_t stream;
  PetscBool    yiscupm = PETSC_TRUE;

  PetscFunctionBegin;
  // stop scanning as soon as one y is found to be non-CUPM
  for (PetscInt i = 0; i < nv && yiscupm; i++) PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin[i]), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecMAXPY_Seq(xin, nv, alpha, yin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(GetHandlesFrom_(dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    // stage the coefficients on the device so the kernel can read them
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // each dispatch consumes its batch and advances yidx
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscFunctionReturn(PETSC_SUCCESS);
}
994 
// v->ops->maxpy
//
// Synchronous wrapper: runs MAXPYAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPYAsync(xin, nv, alpha, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1003 
// v->ops->dot
//
// z = conj(y)^T x via cuBLAS/hipBLAS dot. Falls back to VecDot_Seq when y is not
// a CUPM vector; an empty vector yields z = 0 without touching the device.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecDot_Seq(xin, yin, z));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1031 
1032 #define MDOT_WORKGROUP_NUM  128
1033 #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
1034 
1035 namespace kernels
1036 {
1037 
EntriesPerGroup(const PetscInt size)1038 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
1039 {
1040   const auto group_entries = (size - 1) / gridDim.x + 1;
1041   // for very small vectors, a group should still do some work
1042   return group_entries ? group_entries : 1;
1043 }
1044 
// First stage of the multi-dot reduction: for N = sizeof...(y) vectors, each
// block computes partial sums of y_j[i] * x[i] over its slice of the vector,
// reduces them in shared memory, and writes one partial result per (block,
// vector) pair to results[bx + j * gridDim.x]. A second kernel (sum_kernel)
// finishes the reduction across blocks.
//
// NOTE(review): for complex PetscScalar this computes an unconjugated product
// y*x here -- conjugation handling appears to live elsewhere (the complex build
// uses the cuBLAS MDot_ path instead); confirm before reusing this kernel.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stash each thread's N partial sums in shared memory, one stripe per vector
#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of compute-sanitizer reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}
1093 
namespace
{

// Second stage of the multi-dot reduction: compresses the per-block partial
// results (MDOT_WORKGROUP_NUM values per dot product) into one value per dot
// product, in place, so the host only has to copy back `size` scalars.
//
// NOTE(review): local_results[8] bounds how many grid-stride iterations a single
// thread may perform; launch configuration must keep that per-thread count <= 8.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // pop entries in reverse fill order; grid_stride_1D revisits the same i's
    // for this thread, so the pairing is consistent
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace
1135 
#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Dummy ODR-use of sum_kernel to suppress a HIP/HCC "defined but not emitted"
// warning; never called.
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

} // namespace do_not_use
#endif
1147 
1148 } // namespace kernels
1149 
// Launch MDot_kernel for exactly sizeof...(Idx) y-vectors on the given stream.
// The index sequence expands both into the kernel's template argument list and
// into the per-vector device pointers.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}
1168 
// Dispatch the next batch of N dot products starting at offset yidx (both the
// vector array and the per-batch results region are offset), then advance yidx.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1178 
// MDot_ -- real-scalar implementation (selected via std::false_type when
// PETSC_USE_COMPLEX is off).
//
// Computes z[i] = y[i]^T x for nv vectors using the custom two-stage reduction
// kernels: batches of up to 8 dot products per launch (MDot_kernel), optionally
// spread across forked substreams for concurrency, followed by sum_kernel to
// collapse the per-block partials, then a device-to-host copy of z. Note the
// copy is async on `stream`; the caller (MDot) synchronizes the context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto   n               = xin->map->n;
  const auto   nwork           = nv * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_DEFAULT || stype == PETSC_STREAM_DEFAULT_WITH_BARRIER) stype = PETSC_STREAM_NONBLOCKING;
    // If we have a default stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        // round-robin the batches over the forked substreams
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1:
        PetscCall(MDot_kernel_dispatch_<1>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    // joining serializes the substreams back behind dctx before the final reduction
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  PetscCall(PetscCUPMLaunchKernel1D(nv, 0, stream, kernels::sum_kernel, nv, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nv, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogGpuFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1261 
1262 #undef MDOT_WORKGROUP_NUM
1263 #undef MDOT_WORKGROUP_SIZE
1264 
// MDot_ -- complex-scalar implementation (selected via std::true_type when
// PETSC_USE_COMPLEX is on).
//
// Uses one cuBLAS/hipBLAS conjugated dot per vector (which handles the complex
// conjugation the custom kernel does not), with results written directly to a
// device buffer via CUPMBLAS_POINTER_MODE_DEVICE, distributed over up to 8
// forked subcontexts. The device-to-host copy of z is async on `stream`; the
// caller (MDot) synchronizes the context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // temporarily switch the handle to device pointer mode so the dot result
    // lands in d_z without a host round-trip; restore the mode afterwards
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    // y first: BLAS conjugates the first argument, matching PETSc's convention
    // of conjugating y (the second VecDot argument)
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}
1299 
// v->ops->mdot
//
// z[i] = conj(y[i])^T x for nv vectors. nv == 1 (or 0) is delegated to Dot; an
// empty local vector short-circuits to zeros. Otherwise dispatches at compile
// time to the complex (cuBLAS) or real (custom kernel) MDot_ overload, then
// synchronizes so the async device-to-host copy of z has landed before return.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1322 
// VecSetAsync_Private
//
// x[i] = alpha for all i. Zero uses a plain async memset; any other value uses
// thrust::fill on the context's stream.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
{
  const auto   n = xin->map->n;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(GetHandlesFrom_(dctx, &stream));
  {
    // write-only access: previous contents are discarded
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      const auto dptr = thrust::device_pointer_cast(xptr.data());

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
    }
  }
  if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1347 
// v->ops->set
//
// Synchronous wrapper: runs SetAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  PetscCall(SetAsync(xin, alpha, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1356 
// VecScaleAsync_Private
//
// x *= alpha. alpha == 1 is a no-op, alpha == 0 degenerates to SetAsync(0),
// otherwise a cuBLAS/hipBLAS scal is used. An empty local vector still bumps
// the object state via MaybeIncrementEmptyLocalVec.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ScaleAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(SetAsync(xin, alpha, dctx));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1380 
// v->ops->scale
//
// Synchronous wrapper: runs ScaleAsync on the default (null) device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  PetscCall(ScaleAsync(xin, alpha, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1389 
// v->ops->tdot
//
// Indefinite (unconjugated) dot product z = y^T x via cuBLAS/hipBLAS dotu --
// unlike Dot, no conjugation, so the argument order need not be swapped. Falls
// back to VecTDot_Seq when y is not a CUPM vector; an empty vector yields z = 0.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecTDot_Seq(xin, yin, z));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1416 
// VecCopyAsync_Private
//
// Copies x into y, choosing the memcpy direction (H2H/H2D/D2H/D2D) from the offload
// masks of both vectors so that data is moved from wherever x is currently valid to
// wherever y should live. No-op when x and y alias; empty vectors only bump the
// object state counter.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CopyAsync(Vec xin, Vec yout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto         mode = cupmMemcpyDeviceToDevice;
    cupmStream_t stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        // y has no allocation yet; put the copy wherever x currently lives
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        // y is a CUPM vector currently valid on host: prefer ending up on device
        mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToDevice : cupmMemcpyDeviceToDevice;
      } else {
        // y is a host-only vector: the copy must land on the host
        mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      }
      break;
    }
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      // y is valid on device: copy to device from wherever x is valid
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandlesFrom_(dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // destination is the host-side array; go through the generic accessor so the
      // offload mask is updated correctly
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1494 
// v->ops->copy
//
// Synchronous entry point for VecCopy(): y <- x. Defers to CopyAsync() with a null
// device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  PetscCall(CopyAsync(xin, yout, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1503 
// VecSwapAsync_Private
//
// Exchanges the contents of x and y on the device via BLAS swap. Unlike TDot/AXPBY
// there is no host fallback: y must also be a CUPM vector, otherwise this errors.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SwapAsync(Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS); // swapping with self is a no-op
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  PetscCheck(yiscupm, PetscObjectComm(PetscObjectCast(yin)), PETSC_ERR_SUP, "Cannot swap with Y of type %s", PetscObjectCast(yin)->type_name);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    // empty vectors: both objects still logically changed state
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1529 
// v->ops->swap
//
// Synchronous entry point for VecSwap(): exchanges x and y. Defers to SwapAsync()
// with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(SwapAsync(xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1538 
1539 // VecAXPYBYAsync_Private
1540 template <device::cupm::DeviceType T>
AXPBYAsync(Vec yin,PetscScalar alpha,PetscScalar beta,Vec xin,PetscDeviceContext dctx)1541 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYAsync(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin, PetscDeviceContext dctx) noexcept
1542 {
1543   PetscBool xiscupm;
1544 
1545   PetscFunctionBegin;
1546   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1547   if (!xiscupm) {
1548     PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1549     PetscFunctionReturn(PETSC_SUCCESS);
1550   }
1551   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1552   if (alpha == PetscScalar(0.0)) {
1553     PetscCall(ScaleAsync(yin, beta, dctx));
1554   } else if (beta == PetscScalar(1.0)) {
1555     PetscCall(AXPYAsync(yin, alpha, xin, dctx));
1556   } else if (alpha == PetscScalar(1.0)) {
1557     PetscCall(AYPXAsync(yin, beta, xin, dctx));
1558   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1559     PetscBool xiscupm;
1560 
1561     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1562     if (!xiscupm) {
1563       PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1564       PetscFunctionReturn(PETSC_SUCCESS);
1565     }
1566 
1567     const auto       betaIsZero = beta == PetscScalar(0.0);
1568     const auto       aptr       = cupmScalarPtrCast(&alpha);
1569     cupmBlasHandle_t cupmBlasHandle;
1570 
1571     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1572     {
1573       const auto xptr = DeviceArrayRead(dctx, xin);
1574 
1575       if (betaIsZero /* beta = 0 */) {
1576         // here we can get away with purely write-only as we memcpy into it first
1577         const auto   yptr = DeviceArrayWrite(dctx, yin);
1578         cupmStream_t stream;
1579 
1580         PetscCall(GetHandlesFrom_(dctx, &stream));
1581         PetscCall(PetscLogGpuTimeBegin());
1582         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1583         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1584       } else {
1585         const auto yptr = DeviceArrayReadWrite(dctx, yin);
1586 
1587         PetscCall(PetscLogGpuTimeBegin());
1588         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1589         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1590       }
1591     }
1592     PetscCall(PetscLogGpuTimeEnd());
1593     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1594     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1595   } else {
1596     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1597   }
1598   PetscFunctionReturn(PETSC_SUCCESS);
1599 }
1600 
// v->ops->axpby
//
// Synchronous entry point for VecAXPBY(): y = alpha*x + beta*y. Defers to
// AXPBYAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPBYAsync(yin, alpha, beta, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1609 
// VecAXPBYPCZAsync_Private
//
// Computes z = alpha*x + beta*y + gamma*z by composing the simpler async kernels:
// scale z by gamma (skipped when gamma == 1), then two axpy accumulations. Flop and
// GPU-time logging are handled inside the callees.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZAsync(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (gamma != PetscScalar(1.0)) PetscCall(ScaleAsync(zin, gamma, dctx));
  PetscCall(AXPYAsync(zin, alpha, xin, dctx));
  PetscCall(AXPYAsync(zin, beta, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1621 
// v->ops->axpbypcz
//
// Synchronous entry point for VecAXPBYPCZ(): z = alpha*x + beta*y + gamma*z. Defers
// to AXPBYPCZAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPBYPCZAsync(zin, alpha, beta, gamma, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1630 
// v->ops->norm
//
// Computes the requested norm(s) of x on the device using BLAS asum/nrm2/amax.
// For NORM_1_AND_2 the switch deliberately falls through: the 1-norm is written to
// z[0], the pointer is bumped, and the 2-norm lands in z[1].
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      // note asum computes sum |Re(x_i)| + |Im(x_i)|, the standard BLAS 1-norm
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0); // n-1 additions (clamped for n == 0 safety)
      if (type == NORM_1) break;
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax returns a 1-based index (Fortran BLAS convention), hence the -1 below
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all norms are zero; the second assignment writes z[1] only
    // when both norms were requested (index is 1 iff type == NORM_1_AND_2)
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1680 
1681 namespace detail
1682 {
1683 
// Common state and helpers for the weighted-error-norm transform functors below.
// Stores the ignore_max threshold and provides the per-entry norm contribution:
// for NORM_INFINITY the raw ratio err/tol, otherwise its square (summed and
// square-rooted later for the 2-norm).
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  // (norm, norma, normr, norm_loc, norma_loc, normr_loc) partial-result tuple
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  // one entry's contribution: its norm value and a 0/1 count of contributing entries
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      // squared for the 2-norm accumulation; raw ratio for the max-norm
      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      // non-positive tolerance: entry does not contribute
      return {0.0, 0};
    }
  }

  PetscReal ignore_max_; // entries with |u| or |y| below this are skipped
};
1710 
// Transform functor for the error-norm without an explicit error vector: the error
// of each entry is |u - y|. Input tuple is (u, y, atol, rtol); output is the
// six-way partial result (combined, absolute, relative norms and their counts).
template <NormType wnormtype>
struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto u     = thrust::get<0>(x); // with x.get<0>(), cuda-12.4.0 gives error: class "cuda::std::__4::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>" has no member "get"
    const auto y     = thrust::get<1>(x);
    const auto au    = PetscAbsScalar(u);
    const auto ay    = PetscAbsScalar(y);
    // entries where either solution component is below ignore_max_ are skipped by
    // zeroing their tolerances (compute_norm_ then contributes nothing)
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<2>(x));
    const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x)) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(u - y);
    const auto tup_a = this->compute_norm_(err, tola); // absolute-tolerance norm
    const auto tup_r = this->compute_norm_(err, tolr); // relative-tolerance norm
    const auto tup_n = this->compute_norm_(err, tol);  // combined norm
    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};
1737 
// Transform functor for the error-norm with an explicit error vector E: the error of
// each entry is |e| rather than |u - y|. Input tuple is (u, y, e, atol, rtol).
template <NormType wnormtype>
struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto au    = PetscAbsScalar(thrust::get<0>(x));
    const auto ay    = PetscAbsScalar(thrust::get<1>(x));
    // same skip logic as ErrorWNormTransform: tiny entries contribute nothing
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x));
    const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<4>(x)) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(thrust::get<2>(x)); // error supplied directly in E
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};
1762 
// Reduction functor combining two six-way partial results: max-combine the three
// norm values for NORM_INFINITY, sum them otherwise; the three contribution counts
// are always summed.
template <NormType wnormtype>
struct ErrorWNormReduce {
  using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
  {
    // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
    // result_type is a template, so in order to fix this we would need to write:
    //
    // lhs.template get<0>()
    //
    // which is unseemly.
    if (wnormtype == NORM_INFINITY) {
      // clang-format off
      return {
        PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
        PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
        PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    } else {
      // clang-format off
      return {
        thrust::get<0>(lhs) + thrust::get<0>(rhs),
        thrust::get<1>(lhs) + thrust::get<1>(rhs),
        thrust::get<2>(lhs) + thrust::get<2>(rhs),
        thrust::get<3>(lhs) + thrust::get<3>(rhs),
        thrust::get<4>(lhs) + thrust::get<4>(rhs),
        thrust::get<5>(lhs) + thrust::get<5>(rhs)
      };
      // clang-format on
    }
  }
};
1800 
// Runs the actual thrust transform_reduce over a zip of the input iterator tuples,
// dispatching on wnormtype at runtime to the statically-typed NORM_INFINITY or
// NORM_2 functor instantiations, then finishes the 2-norm with square roots.
//
// first/last: thrust iterator tuples delimiting the zipped input range
// wnormtype:  NORM_INFINITY or NORM_2 (anything else is treated as NORM_2 here)
// outputs:    combined/absolute/relative norms and their contribution counts
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  PetscReal n = 0, na = 0, nr = 0;
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  // the reduction accumulated squared ratios for the 2-norm; take roots here
  if (wnormtype == NORM_2) {
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1844 
1845 } // namespace detail
1846 
// v->ops->errorwnorm
//
// Computes weighted error norms between solution U and approximation Y (optionally
// using a precomputed error vector E, and optionally per-entry tolerance vectors
// vatol/vrtol in place of the scalar atol/rtol). Dispatches to one of eight
// ExecuteWNorm instantiations depending on which of E, vatol, vrtol are present:
// absent tolerance vectors are replaced by constant iterators over the scalar
// values, and the transform functor differs with/without E.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  const auto         nl  = U->map->n;
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol)); // scalar atol broadcast
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol)); // scalar rtol broadcast
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // returns a null device pointer for absent optional vectors (E, vatol, vrtol)
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    if (!vatol && !vrtol) {
      // scalar atol and rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      // scalar atol, vector rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      // vector atol, scalar rtol
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    } else {
      // both tolerance vectors present
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1968 
1969 namespace detail
1970 {
// Element-wise functor for DotNorm2: given s_i and t_i, produces the pair
// (s_i * conj(t_i), t_i * conj(t_i)), i.e. the per-entry contributions to the dot
// product and to |t|^2, computed in a single pass.
struct dotnorm2_mult {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
  {
    const auto conjt = PetscConj(t); // conjugate once, use for both products

    return {s * conjt, t * conjt};
  }
};
1979 
// it is positively __bananas__ that thrust does not define default operator+ for tuples... I
// would do it myself but now I am worried that they do so on purpose...
//
// Component-wise addition of (dot, norm) pairs, used as the reduction operator in
// DotNorm2's inner_product.
struct dotnorm2_tuple_plus {
  using value_type = thrust::tuple<PetscScalar, PetscScalar>;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {thrust::get<0>(lhs) + thrust::get<0>(rhs), thrust::get<1>(lhs) + thrust::get<1>(rhs)}; }
};
1987 
1988 } // namespace detail
1989 
// v->ops->dotnorm2
//
// Computes dp = conj(t).s and nm = conj(t).t in one fused device pass via
// thrust::inner_product with the functors above, instead of separate dot and norm
// kernels.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    PetscScalar dpt = 0.0, nmt = 0.0; // initial values for the pairwise reduction
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
2017 
2018 namespace detail
2019 {
// Element-wise complex-conjugation functor for ConjugateAsync's pointwise kernel.
struct conjugate {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); }
};
2023 
2024 } // namespace detail
2025 
// v->ops->conjugate
//
// Conjugates every entry of x in place. A no-op in real builds, where conjugation is
// the identity.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ConjugateAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  // second argument is the output vector slot of PointwiseUnary_; null here means
  // operate on xin in place (presumably -- confirm against PointwiseUnary_)
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2034 
// v->ops->conjugate
//
// Synchronous entry point for VecConjugate(). Defers to ConjugateAsync() with a null
// device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ConjugateAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2043 
2044 namespace detail
2045 {
2046 
// Extracts the real part of a scalar, used by MinMax_ on complex builds where
// min/max is defined over real parts. Two overloads: one for the (value, index)
// zipped form, one for a bare scalar.
struct real_part {
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const noexcept { return {PetscRealPart(thrust::get<0>(x)), thrust::get<1>(x)}; }

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(const PetscScalar &x) const noexcept { return PetscRealPart(x); }
};
2052 
// deriving from Operator allows us to "store" an instance of the operator in the class but
// also take advantage of empty base class optimization if the operator is stateless
//
// Reduction functor over (value, index) pairs: picks the pair whose value wins under
// Operator (greater -> max-reduce, less -> min-reduce), breaking ties by the lower
// index so results are deterministic regardless of reduction order.
template <typename Operator>
class tuple_compare : Operator {
public:
  using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
  using operator_type = Operator;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
  {
    if (op_()(thrust::get<0>(y), thrust::get<0>(x))) {
      // if y is strictly greater/less than x, return y
      return y;
    } else if (thrust::get<0>(y) == thrust::get<0>(x)) {
      // if equal, prefer lower index
      return thrust::get<1>(y) < thrust::get<1>(x) ? y : x;
    }
    // otherwise return x
    return x;
  }

private:
  // access the stored (possibly empty-base) operator instance
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
};
2077 
2078 } // namespace detail
2079 
// Shared implementation for Max() and Min(). tuple_ftr reduces (value, index) pairs
// (used when the caller wants the location p), unary_ftr reduces bare values (used
// when p is null). *m must be pre-seeded by the caller with the reduction identity
// (PETSC_MIN_REAL for max, PETSC_MAX_REAL for min); it is left untouched for empty
// vectors, and *p is set to -1 unless a winning index is found.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1; // default location for empty vectors
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
  #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
  #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // zip values with their indices so the reduction can track the location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}
2136 
2137 // v->ops->max
2138 template <device::cupm::DeviceType T>
Max(Vec v,PetscInt * p,PetscReal * m)2139 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
2140 {
2141 #if CCCL_VERSION >= 3001000
2142   using tuple_functor = detail::tuple_compare<cuda::std::greater<PetscReal>>;
2143   using unary_functor = cuda::maximum<PetscReal>;
2144 #else
2145   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
2146   using unary_functor = thrust::maximum<PetscReal>;
2147 #endif
2148 
2149   PetscFunctionBegin;
2150   *m = PETSC_MIN_REAL;
2151   // use {} constructor syntax otherwise most vexing parse
2152   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2153   PetscFunctionReturn(PETSC_SUCCESS);
2154 }
2155 
2156 // v->ops->min
2157 template <device::cupm::DeviceType T>
Min(Vec v,PetscInt * p,PetscReal * m)2158 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
2159 {
2160 #if CCCL_VERSION >= 3001000
2161   using tuple_functor = detail::tuple_compare<cuda::std::less<PetscReal>>;
2162   using unary_functor = cuda::minimum<PetscReal>;
2163 #else
2164   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
2165   using unary_functor = thrust::minimum<PetscReal>;
2166 #endif
2167 
2168   PetscFunctionBegin;
2169   *m = PETSC_MAX_REAL;
2170   // use {} constructor syntax otherwise most vexing parse
2171   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2172   PetscFunctionReturn(PETSC_SUCCESS);
2173 }
2174 
2175 // v->ops->sum
2176 template <device::cupm::DeviceType T>
Sum(Vec v,PetscScalar * sum)2177 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
2178 {
2179   PetscFunctionBegin;
2180   if (const auto n = v->map->n) {
2181     PetscDeviceContext dctx;
2182     cupmStream_t       stream;
2183 
2184     PetscCall(GetHandles_(&dctx, &stream));
2185     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
2186     // REVIEW ME: why not cupmBlasXasum()?
2187     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
2188     // REVIEW ME: must be at least n additions
2189     PetscCall(PetscLogGpuFlops(n));
2190   } else {
2191     *sum = 0.0;
2192   }
2193   PetscFunctionReturn(PETSC_SUCCESS);
2194 }
2195 
// Asynchronously adds the scalar shift to every entry of v, i.e. v[i] += shift, as a
// pointwise unary kernel launched on (possibly NULL) dctx.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ShiftAsync(Vec v, PetscScalar shift, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2203 
// v->ops->shift
//
// Synchronous entry point for VecShift(): forwards to ShiftAsync() with a NULL device
// context (i.e. the default/current context).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(ShiftAsync(v, shift, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2211 
2212 template <device::cupm::DeviceType T>
SetRandom(Vec v,PetscRandom rand)2213 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
2214 {
2215   PetscFunctionBegin;
2216   if (const auto n = v->map->n) {
2217     PetscBool          iscurand;
2218     PetscDeviceContext dctx;
2219 
2220     PetscCall(GetHandles_(&dctx));
2221     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
2222     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
2223     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
2224   } else {
2225     PetscCall(MaybeIncrementEmptyLocalVec(v));
2226   }
2227   // REVIEW ME: flops????
2228   // REVIEW ME: Timing???
2229   PetscFunctionReturn(PETSC_SUCCESS);
2230 }
2231 
// v->ops->setpreallocation
//
// Preallocates for COO-style assembly with ncoo entries at indices coo_i. First builds
// the host-side COO mapping via the sequential implementation, then sets up the matching
// device-side state; the order matters since the CUPM setup consumes the host mapping.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // host-side COO setup (builds jmap/perm on the host)
  PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
  // mirror the COO mapping onto the device
  PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2244 
// v->ops->setvaluescoo
//
// Assembles the COO values v[] (one per entry passed to SetPreallocationCOO()) into x,
// either inserting or adding according to imode. If v[] lives in host memory it is first
// staged into a temporary device buffer; the kernel then gathers/accumulates the values
// through the precomputed jmap1/perm1 device mappings.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  auto               vv = const_cast<PetscScalar *>(v); // device-visible copy of v (aliases v when already on device)
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites x entirely, so the old device values need not be read back
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // free the staging buffer only if we allocated one above
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
2277 
2278 } // namespace impl
2279 
2280 } // namespace cupm
2281 
2282 } // namespace vec
2283 
2284 } // namespace Petsc
2285