1 #pragma once
2
3 #include "vecseqcupm.hpp"
4
5 #include <petsc/private/randomimpl.h> // for _p_PetscRandom
6
7 #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
8 #include "../src/sys/objects/device/impls/cupm/kernels.hpp"
9
10 #if PetscDefined(USE_COMPLEX)
11 #include <thrust/transform_reduce.h>
12 #endif
13 #include <thrust/transform.h>
14 #include <thrust/reduce.h>
15 #include <thrust/functional.h>
16 #include <thrust/tuple.h>
17 #include <thrust/device_ptr.h>
18 #include <thrust/iterator/zip_iterator.h>
19 #include <thrust/iterator/counting_iterator.h>
20 #include <thrust/iterator/constant_iterator.h>
21 #include <thrust/inner_product.h>
22
23 namespace Petsc
24 {
25
26 namespace vec
27 {
28
29 namespace cupm
30 {
31
32 namespace impl
33 {
34
35 // ==========================================================================================
36 // VecSeq_CUPM - Private API
37 // ==========================================================================================
38
39 template <device::cupm::DeviceType T>
VecIMPLCast_(Vec v)40 inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
41 {
42 return static_cast<Vec_Seq *>(v->data);
43 }
44
45 template <device::cupm::DeviceType T>
VECIMPLCUPM_()46 inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
47 {
48 return VECSEQCUPM();
49 }
50
51 template <device::cupm::DeviceType T>
VECIMPL_()52 inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
53 {
54 return VECSEQ;
55 }
56
57 template <device::cupm::DeviceType T>
ClearAsyncFunctions(Vec v)58 inline PetscErrorCode VecSeq_CUPM<T>::ClearAsyncFunctions(Vec v) noexcept
59 {
60 PetscFunctionBegin;
61 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), nullptr));
62 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), nullptr));
63 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), nullptr));
64 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), nullptr));
65 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), nullptr));
66 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), nullptr));
67 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), nullptr));
68 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), nullptr));
69 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), nullptr));
70 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), nullptr));
71 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), nullptr));
72 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), nullptr));
73 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), nullptr));
74 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), nullptr));
75 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), nullptr));
76 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), nullptr));
77 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), nullptr));
78 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), nullptr));
79 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), nullptr));
80 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), nullptr));
81 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), nullptr));
82 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), nullptr));
83 PetscFunctionReturn(PETSC_SUCCESS);
84 }
85
86 template <device::cupm::DeviceType T>
InitializeAsyncFunctions(Vec v)87 inline PetscErrorCode VecSeq_CUPM<T>::InitializeAsyncFunctions(Vec v) noexcept
88 {
89 PetscFunctionBegin;
90 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), VecSeq_CUPM<T>::AbsAsync));
91 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), VecSeq_CUPM<T>::AXPBYAsync));
92 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), VecSeq_CUPM<T>::AXPBYPCZAsync));
93 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), VecSeq_CUPM<T>::AXPYAsync));
94 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), VecSeq_CUPM<T>::AYPXAsync));
95 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), VecSeq_CUPM<T>::ConjugateAsync));
96 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), VecSeq_CUPM<T>::CopyAsync));
97 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), VecSeq_CUPM<T>::ExpAsync));
98 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), VecSeq_CUPM<T>::LogAsync));
99 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), VecSeq_CUPM<T>::MAXPYAsync));
100 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), VecSeq_CUPM<T>::PointwiseDivideAsync));
101 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), VecSeq_CUPM<T>::PointwiseMaxAsync));
102 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), VecSeq_CUPM<T>::PointwiseMaxAbsAsync));
103 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), VecSeq_CUPM<T>::PointwiseMinAsync));
104 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), VecSeq_CUPM<T>::PointwiseMultAsync));
105 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), VecSeq_CUPM<T>::ReciprocalAsync));
106 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), VecSeq_CUPM<T>::ScaleAsync));
107 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), VecSeq_CUPM<T>::SetAsync));
108 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), VecSeq_CUPM<T>::ShiftAsync));
109 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), VecSeq_CUPM<T>::SqrtAbsAsync));
110 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), VecSeq_CUPM<T>::SwapAsync));
111 PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), VecSeq_CUPM<T>::WAXPYAsync));
112 PetscFunctionReturn(PETSC_SUCCESS);
113 }
114
115 template <device::cupm::DeviceType T>
VecDestroy_IMPL_(Vec v)116 inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
117 {
118 PetscFunctionBegin;
119 PetscCall(ClearAsyncFunctions(v));
120 PetscCall(VecDestroy_Seq(v));
121 PetscFunctionReturn(PETSC_SUCCESS);
122 }
123
124 template <device::cupm::DeviceType T>
VecResetArray_IMPL_(Vec v)125 inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
126 {
127 return VecResetArray_Seq(v);
128 }
129
130 template <device::cupm::DeviceType T>
VecPlaceArray_IMPL_(Vec v,const PetscScalar * a)131 inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
132 {
133 return VecPlaceArray_Seq(v, a);
134 }
135
136 template <device::cupm::DeviceType T>
VecCreate_IMPL_Private_(Vec v,PetscBool * alloc_missing,PetscInt,PetscScalar * host_array)137 inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
138 {
139 PetscMPIInt size;
140
141 PetscFunctionBegin;
142 if (alloc_missing) *alloc_missing = PETSC_FALSE;
143 PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
144 PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
145 PetscCall(VecCreate_Seq_Private(v, host_array));
146 PetscCall(InitializeAsyncFunctions(v));
147 PetscFunctionReturn(PETSC_SUCCESS);
148 }
149
// For functions with an early return based on vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//      rank 0: [0],
//      rank 1: [<empty>]
//    }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
166 template <device::cupm::DeviceType T>
MaybeIncrementEmptyLocalVec(Vec v)167 inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
168 {
169 PetscFunctionBegin;
170 if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
171 PetscFunctionReturn(PETSC_SUCCESS);
172 }
173
174 template <device::cupm::DeviceType T>
CreateSeqCUPM_(Vec v,PetscDeviceContext dctx,PetscScalar * host_array,PetscScalar * device_array)175 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
176 {
177 PetscFunctionBegin;
178 PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
179 PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
180 PetscFunctionReturn(PETSC_SUCCESS);
181 }
182
183 template <device::cupm::DeviceType T>
184 template <typename BinaryFuncT>
PointwiseBinary_(BinaryFuncT && binary,Vec xin,Vec yin,Vec zout,PetscDeviceContext dctx)185 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout, PetscDeviceContext dctx) noexcept
186 {
187 PetscFunctionBegin;
188 if (const auto n = zout->map->n) {
189 cupmStream_t stream;
190
191 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
192 PetscCall(GetHandlesFrom_(dctx, &stream));
193 // clang-format off
194 PetscCallThrust(
195 const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());
196
197 THRUST_CALL(
198 thrust::transform,
199 stream,
200 dxptr, dxptr + n,
201 thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
202 thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
203 std::forward<BinaryFuncT>(binary)
204 )
205 );
206 // clang-format on
207 PetscCall(PetscLogGpuFlops(n));
208 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
209 } else {
210 PetscCall(MaybeIncrementEmptyLocalVec(zout));
211 }
212 PetscFunctionReturn(PETSC_SUCCESS);
213 }
214
215 template <device::cupm::DeviceType T>
216 template <typename BinaryFuncT>
PointwiseBinaryDispatch_(PetscErrorCode (* VecSeqFunction)(Vec,Vec,Vec),BinaryFuncT && binary,Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)217 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinaryDispatch_(PetscErrorCode (*VecSeqFunction)(Vec, Vec, Vec), BinaryFuncT &&binary, Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
218 {
219 PetscFunctionBegin;
220 if (xin->boundtocpu || yin->boundtocpu) {
221 PetscCall((*VecSeqFunction)(wout, xin, yin));
222 } else {
223 // note order of arguments! xin and yin are read, wout is written!
224 PetscCall(PointwiseBinary_(std::forward<BinaryFuncT>(binary), xin, yin, wout, dctx));
225 }
226 PetscFunctionReturn(PETSC_SUCCESS);
227 }
228
229 template <device::cupm::DeviceType T>
230 template <typename UnaryFuncT>
PointwiseUnary_(UnaryFuncT && unary,Vec xinout,Vec yin,PetscDeviceContext dctx)231 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin, PetscDeviceContext dctx) noexcept
232 {
233 const auto inplace = !yin || (xinout == yin);
234
235 PetscFunctionBegin;
236 if (const auto n = xinout->map->n) {
237 cupmStream_t stream;
238 const auto apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
239 PetscFunctionBegin;
240 // clang-format off
241 PetscCallThrust(
242 const auto xptr = thrust::device_pointer_cast(xinout);
243
244 THRUST_CALL(
245 thrust::transform,
246 stream,
247 xptr, xptr + n,
248 (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
249 std::forward<UnaryFuncT>(unary)
250 )
251 );
252 // clang-format on
253 PetscFunctionReturn(PETSC_SUCCESS);
254 };
255
256 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
257 PetscCall(GetHandlesFrom_(dctx, &stream));
258 if (inplace) {
259 PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
260 } else {
261 PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
262 }
263 PetscCall(PetscLogGpuFlops(n));
264 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
265 } else {
266 if (inplace) {
267 PetscCall(MaybeIncrementEmptyLocalVec(xinout));
268 } else {
269 PetscCall(MaybeIncrementEmptyLocalVec(yin));
270 }
271 }
272 PetscFunctionReturn(PETSC_SUCCESS);
273 }
274
275 // ==========================================================================================
276 // VecSeq_CUPM - Public API - Constructors
277 // ==========================================================================================
278
279 // VecCreateSeqCUPM()
280 template <device::cupm::DeviceType T>
CreateSeqCUPM(MPI_Comm comm,PetscInt bs,PetscInt n,Vec * v,PetscBool call_set_type)281 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
282 {
283 PetscFunctionBegin;
284 PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
285 PetscFunctionReturn(PETSC_SUCCESS);
286 }
287
288 // VecCreateSeqCUPMWithArrays()
289 template <device::cupm::DeviceType T>
CreateSeqCUPMWithBothArrays(MPI_Comm comm,PetscInt bs,PetscInt n,const PetscScalar host_array[],const PetscScalar device_array[],Vec * v)290 inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
291 {
292 PetscDeviceContext dctx;
293
294 PetscFunctionBegin;
295 PetscCall(GetHandles_(&dctx));
296 // do NOT call VecSetType(), otherwise ops->create() -> create() ->
297 // CreateSeqCUPM_() is called!
298 PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
299 PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
300 PetscFunctionReturn(PETSC_SUCCESS);
301 }
302
303 // v->ops->duplicate
304 template <device::cupm::DeviceType T>
Duplicate(Vec v,Vec * y)305 inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
306 {
307 PetscDeviceContext dctx;
308
309 PetscFunctionBegin;
310 PetscCall(GetHandles_(&dctx));
311 PetscCall(Duplicate_CUPMBase(v, y, dctx));
312 PetscFunctionReturn(PETSC_SUCCESS);
313 }
314
315 // ==========================================================================================
316 // VecSeq_CUPM - Public API - Utility
317 // ==========================================================================================
318
319 // v->ops->bindtocpu
320 template <device::cupm::DeviceType T>
BindToCPU(Vec v,PetscBool usehost)321 inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
322 {
323 PetscDeviceContext dctx;
324
325 PetscFunctionBegin;
326 PetscCall(GetHandles_(&dctx));
327 PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));
328
329 // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess
330 VecSetOp_CUPM(dot, VecDot_Seq, Dot);
331 VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
332 VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
333 VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
334 VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
335 VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
336 v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
337 VecSetOp_CUPM(max, VecMax_Seq, Max);
338 VecSetOp_CUPM(min, VecMin_Seq, Min);
339 VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
340 VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
341 PetscFunctionReturn(PETSC_SUCCESS);
342 }
343
344 // ==========================================================================================
345 // VecSeq_CUPM - Public API - Mutators
346 // ==========================================================================================
347
348 // v->ops->getlocalvector or v->ops->getlocalvectorread
349 template <device::cupm::DeviceType T>
350 template <PetscMemoryAccessMode access>
GetLocalVector(Vec v,Vec w)351 inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
352 {
353 PetscBool wisseqcupm;
354
355 PetscFunctionBegin;
356 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
357 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
358 if (wisseqcupm) {
359 if (const auto wseq = VecIMPLCast(w)) {
360 if (auto &alloced = wseq->array_allocated) {
361 const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));
362
363 PetscCall(PetscFree(alloced));
364 }
365 wseq->array = nullptr;
366 wseq->unplacedarray = nullptr;
367 }
368 if (const auto wcu = VecCUPMCast(w)) {
369 if (auto &device_array = wcu->array_d) {
370 cupmStream_t stream;
371
372 PetscCall(GetHandles_(&stream));
373 PetscCallCUPM(cupmFreeAsync(device_array, stream));
374 }
375 PetscCall(PetscFree(w->spptr /* wcu */));
376 }
377 }
378 if (v->petscnative && wisseqcupm) {
379 PetscCall(PetscFree(w->data));
380 w->data = v->data;
381 w->offloadmask = v->offloadmask;
382 w->pinned_memory = v->pinned_memory;
383 w->spptr = v->spptr;
384 PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
385 } else {
386 const auto array = &VecIMPLCast(w)->array;
387
388 if (access == PETSC_MEMORY_ACCESS_READ) {
389 PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
390 } else {
391 PetscCall(VecGetArray(v, array));
392 }
393 w->offloadmask = PETSC_OFFLOAD_CPU;
394 if (wisseqcupm) {
395 PetscDeviceContext dctx;
396
397 PetscCall(GetHandles_(&dctx));
398 PetscCall(DeviceAllocateCheck_(dctx, w));
399 }
400 }
401 PetscFunctionReturn(PETSC_SUCCESS);
402 }
403
404 // v->ops->restorelocalvector or v->ops->restorelocalvectorread
405 template <device::cupm::DeviceType T>
406 template <PetscMemoryAccessMode access>
RestoreLocalVector(Vec v,Vec w)407 inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
408 {
409 PetscBool wisseqcupm;
410
411 PetscFunctionBegin;
412 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
413 PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
414 if (v->petscnative && wisseqcupm) {
415 // the assignments to nullptr are __critical__, as w may persist after this call returns
416 // and shouldn't share data with v!
417 v->pinned_memory = w->pinned_memory;
418 v->offloadmask = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
419 v->data = util::exchange(w->data, nullptr);
420 v->spptr = util::exchange(w->spptr, nullptr);
421 } else {
422 const auto array = &VecIMPLCast(w)->array;
423
424 if (access == PETSC_MEMORY_ACCESS_READ) {
425 PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
426 } else {
427 PetscCall(VecRestoreArray(v, array));
428 }
429 if (w->spptr && wisseqcupm) {
430 cupmStream_t stream;
431
432 PetscCall(GetHandles_(&stream));
433 PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
434 PetscCall(PetscFree(w->spptr));
435 }
436 }
437 PetscFunctionReturn(PETSC_SUCCESS);
438 }
439
440 // ==========================================================================================
441 // VecSeq_CUPM - Public API - Compute Methods
442 // ==========================================================================================
443
444 // VecAYPXAsync_Private
445 template <device::cupm::DeviceType T>
AYPXAsync(Vec yin,PetscScalar alpha,Vec xin,PetscDeviceContext dctx)446 inline PetscErrorCode VecSeq_CUPM<T>::AYPXAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
447 {
448 const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
449 PetscBool xiscupm;
450
451 PetscFunctionBegin;
452 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
453 if (!xiscupm) {
454 PetscCall(VecAYPX_Seq(yin, alpha, xin));
455 PetscFunctionReturn(PETSC_SUCCESS);
456 }
457 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
458 if (alpha == PetscScalar(0.0)) {
459 cupmStream_t stream;
460
461 PetscCall(GetHandlesFrom_(dctx, &stream));
462 PetscCall(PetscLogGpuTimeBegin());
463 PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
464 PetscCall(PetscLogGpuTimeEnd());
465 } else if (n) {
466 const auto alphaIsOne = alpha == PetscScalar(1.0);
467 const auto calpha = cupmScalarPtrCast(&alpha);
468 cupmBlasHandle_t cupmBlasHandle;
469
470 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
471 {
472 const auto yptr = DeviceArrayReadWrite(dctx, yin);
473 const auto xptr = DeviceArrayRead(dctx, xin);
474
475 PetscCall(PetscLogGpuTimeBegin());
476 if (alphaIsOne) {
477 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
478 } else {
479 const auto one = cupmScalarCast(1.0);
480
481 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
482 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
483 }
484 PetscCall(PetscLogGpuTimeEnd());
485 }
486 PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
487 }
488 if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
489 PetscFunctionReturn(PETSC_SUCCESS);
490 }
491
492 // v->ops->aypx
493 template <device::cupm::DeviceType T>
AYPX(Vec yin,PetscScalar alpha,Vec xin)494 inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
495 {
496 PetscFunctionBegin;
497 PetscCall(AYPXAsync(yin, alpha, xin, nullptr));
498 PetscFunctionReturn(PETSC_SUCCESS);
499 }
500
501 // VecAXPYAsync_Private
502 template <device::cupm::DeviceType T>
AXPYAsync(Vec yin,PetscScalar alpha,Vec xin,PetscDeviceContext dctx)503 inline PetscErrorCode VecSeq_CUPM<T>::AXPYAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
504 {
505 PetscBool xiscupm;
506
507 PetscFunctionBegin;
508 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
509 if (xiscupm) {
510 const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
511 cupmBlasHandle_t cupmBlasHandle;
512
513 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
514 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
515 PetscCall(PetscLogGpuTimeBegin());
516 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
517 PetscCall(PetscLogGpuTimeEnd());
518 PetscCall(PetscLogGpuFlops(2 * n));
519 if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
520 } else {
521 PetscCall(VecAXPY_Seq(yin, alpha, xin));
522 }
523 PetscFunctionReturn(PETSC_SUCCESS);
524 }
525
526 // v->ops->axpy
527 template <device::cupm::DeviceType T>
AXPY(Vec yin,PetscScalar alpha,Vec xin)528 inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
529 {
530 PetscFunctionBegin;
531 PetscCall(AXPYAsync(yin, alpha, xin, nullptr));
532 PetscFunctionReturn(PETSC_SUCCESS);
533 }
534
535 namespace detail
536 {
537
538 struct divides {
operator ()Petsc::vec::cupm::impl::detail::divides539 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return rhs == PetscScalar{0.0} ? rhs : lhs / rhs; }
540 };
541
542 } // namespace detail
543
544 // VecPointwiseDivideAsync_Private
545 template <device::cupm::DeviceType T>
PointwiseDivideAsync(Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)546 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivideAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
547 {
548 PetscFunctionBegin;
549 PetscCall(PointwiseBinaryDispatch_(VecPointwiseDivide_Seq, detail::divides{}, wout, xin, yin, dctx));
550 PetscFunctionReturn(PETSC_SUCCESS);
551 }
552
553 // v->ops->pointwisedivide
554 template <device::cupm::DeviceType T>
PointwiseDivide(Vec wout,Vec xin,Vec yin)555 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec wout, Vec xin, Vec yin) noexcept
556 {
557 PetscFunctionBegin;
558 PetscCall(PointwiseDivideAsync(wout, xin, yin, nullptr));
559 PetscFunctionReturn(PETSC_SUCCESS);
560 }
561
562 // VecPointwiseMultAsync_Private
563 template <device::cupm::DeviceType T>
PointwiseMultAsync(Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)564 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMultAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
565 {
566 PetscFunctionBegin;
567 PetscCall(PointwiseBinaryDispatch_(VecPointwiseMult_Seq, thrust::multiplies<PetscScalar>{}, wout, xin, yin, dctx));
568 PetscFunctionReturn(PETSC_SUCCESS);
569 }
570
571 // v->ops->pointwisemult
572 template <device::cupm::DeviceType T>
PointwiseMult(Vec wout,Vec xin,Vec yin)573 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec wout, Vec xin, Vec yin) noexcept
574 {
575 PetscFunctionBegin;
576 PetscCall(PointwiseMultAsync(wout, xin, yin, nullptr));
577 PetscFunctionReturn(PETSC_SUCCESS);
578 }
579
580 namespace detail
581 {
582
583 struct MaximumRealPart {
operator ()Petsc::vec::cupm::impl::detail::MaximumRealPart584 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
585 };
586
587 } // namespace detail
588
589 // VecPointwiseMaxAsync_Private
590 template <device::cupm::DeviceType T>
PointwiseMaxAsync(Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)591 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
592 {
593 PetscFunctionBegin;
594 PetscCall(PointwiseBinaryDispatch_(VecPointwiseMax_Seq, detail::MaximumRealPart{}, wout, xin, yin, dctx));
595 PetscFunctionReturn(PETSC_SUCCESS);
596 }
597
598 // v->ops->pointwisemax
599 template <device::cupm::DeviceType T>
PointwiseMax(Vec wout,Vec xin,Vec yin)600 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMax(Vec wout, Vec xin, Vec yin) noexcept
601 {
602 PetscFunctionBegin;
603 PetscCall(PointwiseMaxAsync(wout, xin, yin, nullptr));
604 PetscFunctionReturn(PETSC_SUCCESS);
605 }
606
607 namespace detail
608 {
609
610 struct MaximumAbsoluteValue {
operator ()Petsc::vec::cupm::impl::detail::MaximumAbsoluteValue611 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscAbsScalar(lhs), PetscAbsScalar(rhs)); }
612 };
613
614 } // namespace detail
615
616 // VecPointwiseMaxAbsAsync_Private
617 template <device::cupm::DeviceType T>
PointwiseMaxAbsAsync(Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)618 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbsAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
619 {
620 PetscFunctionBegin;
621 PetscCall(PointwiseBinaryDispatch_(VecPointwiseMaxAbs_Seq, detail::MaximumAbsoluteValue{}, wout, xin, yin, dctx));
622 PetscFunctionReturn(PETSC_SUCCESS);
623 }
624
625 // v->ops->pointwisemaxabs
626 template <device::cupm::DeviceType T>
PointwiseMaxAbs(Vec wout,Vec xin,Vec yin)627 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbs(Vec wout, Vec xin, Vec yin) noexcept
628 {
629 PetscFunctionBegin;
630 PetscCall(PointwiseMaxAbsAsync(wout, xin, yin, nullptr));
631 PetscFunctionReturn(PETSC_SUCCESS);
632 }
633
634 namespace detail
635 {
636
637 struct MinimumRealPart {
operator ()Petsc::vec::cupm::impl::detail::MinimumRealPart638 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::minimum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
639 };
640
641 } // namespace detail
642
643 // VecPointwiseMinAsync_Private
644 template <device::cupm::DeviceType T>
PointwiseMinAsync(Vec wout,Vec xin,Vec yin,PetscDeviceContext dctx)645 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMinAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
646 {
647 PetscFunctionBegin;
648 PetscCall(PointwiseBinaryDispatch_(VecPointwiseMin_Seq, detail::MinimumRealPart{}, wout, xin, yin, dctx));
649 PetscFunctionReturn(PETSC_SUCCESS);
650 }
651
652 // v->ops->pointwisemin
653 template <device::cupm::DeviceType T>
PointwiseMin(Vec wout,Vec xin,Vec yin)654 inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMin(Vec wout, Vec xin, Vec yin) noexcept
655 {
656 PetscFunctionBegin;
657 PetscCall(PointwiseMinAsync(wout, xin, yin, nullptr));
658 PetscFunctionReturn(PETSC_SUCCESS);
659 }
660
661 namespace detail
662 {
663
664 struct Reciprocal {
operator ()Petsc::vec::cupm::impl::detail::Reciprocal665 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept
666 {
667 // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
668 // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
669 // everything in PetscScalar...
670 return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
671 }
672 };
673
674 } // namespace detail
675
676 // VecReciprocalAsync_Private
677 template <device::cupm::DeviceType T>
ReciprocalAsync(Vec xin,PetscDeviceContext dctx)678 inline PetscErrorCode VecSeq_CUPM<T>::ReciprocalAsync(Vec xin, PetscDeviceContext dctx) noexcept
679 {
680 PetscFunctionBegin;
681 PetscCall(PointwiseUnary_(detail::Reciprocal{}, xin, nullptr, dctx));
682 PetscFunctionReturn(PETSC_SUCCESS);
683 }
684
685 // v->ops->reciprocal
686 template <device::cupm::DeviceType T>
Reciprocal(Vec xin)687 inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
688 {
689 PetscFunctionBegin;
690 PetscCall(ReciprocalAsync(xin, nullptr));
691 PetscFunctionReturn(PETSC_SUCCESS);
692 }
693
694 namespace detail
695 {
696
697 struct AbsoluteValue {
operator ()Petsc::vec::cupm::impl::detail::AbsoluteValue698 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscAbsScalar(s); }
699 };
700
701 } // namespace detail
702
703 // VecAbsAsync_Private
704 template <device::cupm::DeviceType T>
AbsAsync(Vec xin,PetscDeviceContext dctx)705 inline PetscErrorCode VecSeq_CUPM<T>::AbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
706 {
707 PetscFunctionBegin;
708 PetscCall(PointwiseUnary_(detail::AbsoluteValue{}, xin, nullptr, dctx));
709 PetscFunctionReturn(PETSC_SUCCESS);
710 }
711
712 // v->ops->abs
713 template <device::cupm::DeviceType T>
Abs(Vec xin)714 inline PetscErrorCode VecSeq_CUPM<T>::Abs(Vec xin) noexcept
715 {
716 PetscFunctionBegin;
717 PetscCall(AbsAsync(xin, nullptr));
718 PetscFunctionReturn(PETSC_SUCCESS);
719 }
720
721 namespace detail
722 {
723
724 struct SquareRootAbsoluteValue {
operator ()Petsc::vec::cupm::impl::detail::SquareRootAbsoluteValue725 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscSqrtReal(PetscAbsScalar(s)); }
726 };
727
728 } // namespace detail
729
730 // VecSqrtAbsAsync_Private
731 template <device::cupm::DeviceType T>
SqrtAbsAsync(Vec xin,PetscDeviceContext dctx)732 inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
733 {
734 PetscFunctionBegin;
735 PetscCall(PointwiseUnary_(detail::SquareRootAbsoluteValue{}, xin, nullptr, dctx));
736 PetscFunctionReturn(PETSC_SUCCESS);
737 }
738
739 // v->ops->sqrt
740 template <device::cupm::DeviceType T>
SqrtAbs(Vec xin)741 inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbs(Vec xin) noexcept
742 {
743 PetscFunctionBegin;
744 PetscCall(SqrtAbsAsync(xin, nullptr));
745 PetscFunctionReturn(PETSC_SUCCESS);
746 }
747
748 namespace detail
749 {
750
751 struct Exponent {
operator ()Petsc::vec::cupm::impl::detail::Exponent752 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscExpScalar(s); }
753 };
754
755 } // namespace detail
756
757 // VecExpAsync_Private
758 template <device::cupm::DeviceType T>
ExpAsync(Vec xin,PetscDeviceContext dctx)759 inline PetscErrorCode VecSeq_CUPM<T>::ExpAsync(Vec xin, PetscDeviceContext dctx) noexcept
760 {
761 PetscFunctionBegin;
762 PetscCall(PointwiseUnary_(detail::Exponent{}, xin, nullptr, dctx));
763 PetscFunctionReturn(PETSC_SUCCESS);
764 }
765
766 // v->ops->exp
767 template <device::cupm::DeviceType T>
Exp(Vec xin)768 inline PetscErrorCode VecSeq_CUPM<T>::Exp(Vec xin) noexcept
769 {
770 PetscFunctionBegin;
771 PetscCall(ExpAsync(xin, nullptr));
772 PetscFunctionReturn(PETSC_SUCCESS);
773 }
774
775 namespace detail
776 {
777
778 struct Logarithm {
operator ()Petsc::vec::cupm::impl::detail::Logarithm779 PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscLogScalar(s); }
780 };
781
782 } // namespace detail
783
784 // VecLogAsync_Private
785 template <device::cupm::DeviceType T>
LogAsync(Vec xin,PetscDeviceContext dctx)786 inline PetscErrorCode VecSeq_CUPM<T>::LogAsync(Vec xin, PetscDeviceContext dctx) noexcept
787 {
788 PetscFunctionBegin;
789 PetscCall(PointwiseUnary_(detail::Logarithm{}, xin, nullptr, dctx));
790 PetscFunctionReturn(PETSC_SUCCESS);
791 }
792
793 // v->ops->log
794 template <device::cupm::DeviceType T>
Log(Vec xin)795 inline PetscErrorCode VecSeq_CUPM<T>::Log(Vec xin) noexcept
796 {
797 PetscFunctionBegin;
798 PetscCall(LogAsync(xin, nullptr));
799 PetscFunctionReturn(PETSC_SUCCESS);
800 }
801
802 // v->ops->waxpy
803 template <device::cupm::DeviceType T>
WAXPYAsync(Vec win,PetscScalar alpha,Vec xin,Vec yin,PetscDeviceContext dctx)804 inline PetscErrorCode VecSeq_CUPM<T>::WAXPYAsync(Vec win, PetscScalar alpha, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
805 {
806 PetscBool xiscupm, yiscupm;
807
808 PetscFunctionBegin;
809 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
810 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
811 if (!xiscupm || !yiscupm) {
812 PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
813 PetscFunctionReturn(PETSC_SUCCESS);
814 }
815 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
816 if (alpha == PetscScalar(0.0)) {
817 PetscCall(CopyAsync(yin, win, dctx));
818 } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
819 cupmBlasHandle_t cupmBlasHandle;
820 cupmStream_t stream;
821 PetscBool xiscupm, yiscupm;
822
823 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
824 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
825 if (!xiscupm || !yiscupm) {
826 PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
827 PetscFunctionReturn(PETSC_SUCCESS);
828 }
829 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle, NULL, &stream));
830 {
831 const auto wptr = DeviceArrayWrite(dctx, win);
832
833 PetscCall(PetscLogGpuTimeBegin());
834 PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
835 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
836 PetscCall(PetscLogGpuTimeEnd());
837 }
838 PetscCall(PetscLogGpuFlops(2 * n));
839 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
840 }
841 PetscFunctionReturn(PETSC_SUCCESS);
842 }
843
844 // v->ops->waxpy
845 template <device::cupm::DeviceType T>
WAXPY(Vec win,PetscScalar alpha,Vec xin,Vec yin)846 inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
847 {
848 PetscFunctionBegin;
849 PetscCall(WAXPYAsync(win, alpha, xin, yin, nullptr));
850 PetscFunctionReturn(PETSC_SUCCESS);
851 }
852
853 namespace kernels
854 {
855
856 template <typename... Args>
MAXPY_kernel(const PetscInt size,PetscScalar * PETSC_RESTRICT xptr,const PetscScalar * PETSC_RESTRICT aptr,Args...yptr)857 PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
858 {
859 constexpr int N = sizeof...(Args);
860 const auto tx = threadIdx.x;
861 const PetscScalar *yptr_p[] = {yptr...};
862
863 PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];
864
865 // load a to shared memory
866 if (tx < N) aptr_shmem[tx] = aptr[tx];
867 __syncthreads();
868
869 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
870 // these may look the same but give different results!
871 #if 0
872 PetscScalar sum = 0.0;
873
874 #pragma unroll
875 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
876 xptr[i] += sum;
877 #else
878 auto sum = xptr[i];
879
880 #pragma unroll
881 for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
882 xptr[i] = sum;
883 #endif
884 });
885 return;
886 }
887
888 } // namespace kernels
889
namespace detail
{

// Maps any (type, index) pair back to the type, discarding the index. Its purpose is to turn
// an index pack into a pack of identical types:
//
//   typename repeat_type<MyType, IdxParamPack>::type...
//
// expands to
//
//   MyType, MyType, MyType, ... [sizeof...(IdxParamPack) repetitions]
template <typename T, std::size_t /* ignored index */>
struct repeat_type {
  using type = T;
};

} // namespace detail
904
905 template <device::cupm::DeviceType T>
906 template <std::size_t... Idx>
MAXPY_kernel_dispatch_(PetscDeviceContext dctx,cupmStream_t stream,PetscScalar * xptr,const PetscScalar * aptr,const Vec * yin,PetscInt size,util::index_sequence<Idx...>)907 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
908 {
909 PetscFunctionBegin;
910 // clang-format off
911 PetscCall(
912 PetscCUPMLaunchKernel1D(
913 size, 0, stream,
914 kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
915 size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
916 )
917 );
918 // clang-format on
919 PetscFunctionReturn(PETSC_SUCCESS);
920 }
921
922 template <device::cupm::DeviceType T>
923 template <int N>
MAXPY_kernel_dispatch_(PetscDeviceContext dctx,cupmStream_t stream,PetscScalar * xptr,const PetscScalar * aptr,const Vec * yin,PetscInt size,PetscInt & yidx)924 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
925 {
926 PetscFunctionBegin;
927 PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
928 yidx += N;
929 PetscFunctionReturn(PETSC_SUCCESS);
930 }
931
932 // VecMAXPYAsync_Private
933 template <device::cupm::DeviceType T>
MAXPYAsync(Vec xin,PetscInt nv,const PetscScalar * alpha,Vec * yin,PetscDeviceContext dctx)934 inline PetscErrorCode VecSeq_CUPM<T>::MAXPYAsync(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin, PetscDeviceContext dctx) noexcept
935 {
936 const auto n = xin->map->n;
937 cupmStream_t stream;
938 PetscBool yiscupm = PETSC_TRUE;
939
940 PetscFunctionBegin;
941 for (PetscInt i = 0; i < nv && yiscupm; i++) PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin[i]), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
942 if (!yiscupm) {
943 PetscCall(VecMAXPY_Seq(xin, nv, alpha, yin));
944 PetscFunctionReturn(PETSC_SUCCESS);
945 }
946 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
947 PetscCall(GetHandlesFrom_(dctx, &stream));
948 {
949 const auto xptr = DeviceArrayReadWrite(dctx, xin);
950 PetscScalar *d_alpha = nullptr;
951 PetscInt yidx = 0;
952
953 // placement of early-return is deliberate, we would like to capture the
954 // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
955 if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
956 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
957 PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
958 PetscCall(PetscLogGpuTimeBegin());
959 do {
960 switch (nv - yidx) {
961 case 7:
962 PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
963 break;
964 case 6:
965 PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
966 break;
967 case 5:
968 PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
969 break;
970 case 4:
971 PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
972 break;
973 case 3:
974 PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
975 break;
976 case 2:
977 PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
978 break;
979 case 1:
980 PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
981 break;
982 default: // 8 or more
983 PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
984 break;
985 }
986 } while (yidx < nv);
987 PetscCall(PetscLogGpuTimeEnd());
988 PetscCall(PetscDeviceFree(dctx, d_alpha));
989 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
990 }
991 PetscCall(PetscLogGpuFlops(nv * 2 * n));
992 PetscFunctionReturn(PETSC_SUCCESS);
993 }
994
995 // v->ops->maxpy
996 template <device::cupm::DeviceType T>
MAXPY(Vec xin,PetscInt nv,const PetscScalar * alpha,Vec * yin)997 inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
998 {
999 PetscFunctionBegin;
1000 PetscCall(MAXPYAsync(xin, nv, alpha, yin, nullptr));
1001 PetscFunctionReturn(PETSC_SUCCESS);
1002 }
1003
1004 template <device::cupm::DeviceType T>
Dot(Vec xin,Vec yin,PetscScalar * z)1005 inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
1006 {
1007 PetscBool yiscupm;
1008
1009 PetscFunctionBegin;
1010 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1011 if (!yiscupm) {
1012 PetscCall(VecDot_Seq(xin, yin, z));
1013 PetscFunctionReturn(PETSC_SUCCESS);
1014 }
1015 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1016 PetscDeviceContext dctx;
1017 cupmBlasHandle_t cupmBlasHandle;
1018
1019 PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1020 // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
1021 // second
1022 PetscCall(PetscLogGpuTimeBegin());
1023 PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
1024 PetscCall(PetscLogGpuTimeEnd());
1025 PetscCall(PetscLogGpuFlops(2 * n - 1));
1026 } else {
1027 *z = 0.0;
1028 }
1029 PetscFunctionReturn(PETSC_SUCCESS);
1030 }
1031
// Launch geometry of the MDot kernel below: MDOT_WORKGROUP_NUM blocks, each of
// MDOT_WORKGROUP_SIZE threads. The two are deliberately identical.
#define MDOT_WORKGROUP_NUM  128
#define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
1034
1035 namespace kernels
1036 {
1037
EntriesPerGroup(const PetscInt size)1038 PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
1039 {
1040 const auto group_entries = (size - 1) / gridDim.x + 1;
1041 // for very small vectors, a group should still do some work
1042 return group_entries ? group_entries : 1;
1043 }
1044
1045 template <typename... ConstPetscScalarPointer>
MDot_kernel(const PetscScalar * PETSC_RESTRICT x,const PetscInt size,PetscScalar * PETSC_RESTRICT results,ConstPetscScalarPointer...y)1046 PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
1047 {
1048 constexpr int N = sizeof...(ConstPetscScalarPointer);
1049 const PetscScalar *ylocal[] = {y...};
1050 PetscScalar sumlocal[N];
1051
1052 PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];
1053
1054 // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
1055 // types, so each of these go on separate lines...
1056 const auto tx = threadIdx.x;
1057 const auto bx = blockIdx.x;
1058 const auto bdx = blockDim.x;
1059 const auto gdx = gridDim.x;
1060 const auto worksize = EntriesPerGroup(size);
1061 const auto begin = tx + bx * worksize;
1062 const auto end = min((bx + 1) * worksize, size);
1063
1064 #pragma unroll
1065 for (auto i = 0; i < N; ++i) sumlocal[i] = 0;
1066
1067 for (auto i = begin; i < end; i += bdx) {
1068 const auto xi = x[i]; // load only once from global memory!
1069
1070 #pragma unroll
1071 for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
1072 }
1073
1074 #pragma unroll
1075 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];
1076
1077 // parallel reduction
1078 for (auto stride = bdx / 2; stride > 0; stride /= 2) {
1079 __syncthreads();
1080 if (tx < stride) {
1081 #pragma unroll
1082 for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
1083 }
1084 }
1085 // bottom N threads per block write to global memory
1086 // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
1087 // writes to the same sections in the above loop that it is about to read from below, but
1088 // running this under the racecheck tool of compute-sanitizer reports a write-after-write hazard.
1089 __syncthreads();
1090 if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
1091 return;
1092 }
1093
1094 namespace
1095 {
1096
sum_kernel(const PetscInt size,PetscScalar * PETSC_RESTRICT results)1097 PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
1098 {
1099 int local_i = 0;
1100 PetscScalar local_results[8];
1101
1102 // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
1103 //
1104 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
1105 // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
1106 // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
1107 // | ______________________________________________________/
1108 // | / <- MDOT_WORKGROUP_NUM ->
1109 // |/
1110 // +
1111 // v
1112 // *-*-*
1113 // | | | ...
1114 // *-*-*
1115 //
1116 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
1117 PetscScalar z_sum = 0;
1118
1119 for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
1120 local_results[local_i++] = z_sum;
1121 });
1122 // if we needed more than 1 workgroup to handle the vector we should sync since other threads
1123 // may currently be reading from results
1124 if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
1125 // Local buffer is now written to global memory
1126 ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
1127 const auto j = --local_i;
1128
1129 if (j >= 0) results[i] = local_results[j];
1130 });
1131 return;
1132 }
1133
1134 } // namespace
1135
1136 #if PetscDefined(USING_HCC)
1137 namespace do_not_use
1138 {
1139
silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()1140 inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
1141 {
1142 (void)sum_kernel;
1143 }
1144
1145 } // namespace do_not_use
1146 #endif
1147
1148 } // namespace kernels
1149
1150 template <device::cupm::DeviceType T>
1151 template <std::size_t... Idx>
MDot_kernel_dispatch_(PetscDeviceContext dctx,cupmStream_t stream,const PetscScalar * xarr,const Vec yin[],PetscInt size,PetscScalar * results,util::index_sequence<Idx...>)1152 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
1153 {
1154 PetscFunctionBegin;
1155 // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
1156 // 128 blocks of 128 threads every time which may be wasteful
1157 // clang-format off
1158 PetscCallCUPM(
1159 cupmLaunchKernel(
1160 kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
1161 MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
1162 xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
1163 )
1164 );
1165 // clang-format on
1166 PetscFunctionReturn(PETSC_SUCCESS);
1167 }
1168
1169 template <device::cupm::DeviceType T>
1170 template <int N>
MDot_kernel_dispatch_(PetscDeviceContext dctx,cupmStream_t stream,const PetscScalar * xarr,const Vec yin[],PetscInt size,PetscScalar * results,PetscInt & yidx)1171 inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
1172 {
1173 PetscFunctionBegin;
1174 PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
1175 yidx += N;
1176 PetscFunctionReturn(PETSC_SUCCESS);
1177 }
1178
1179 template <device::cupm::DeviceType T>
MDot_(std::false_type,Vec xin,PetscInt nv,const Vec yin[],PetscScalar * z,PetscDeviceContext dctx)1180 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
1181 {
1182 // the largest possible size of a batch
1183 constexpr PetscInt batchsize = 8;
1184 // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
1185 // do not create substreams. Note we don't create more than 8 streams, in practice we could
1186 // not get more parallelism with higher numbers.
1187 const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
1188 const auto n = xin->map->n;
1189 const auto nwork = nv * MDOT_WORKGROUP_NUM;
1190 PetscScalar *d_results;
1191 cupmStream_t stream;
1192
1193 PetscFunctionBegin;
1194 PetscCall(GetHandlesFrom_(dctx, &stream));
1195 // allocate scratchpad memory for the results of individual work groups
1196 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
1197 {
1198 const auto xptr = DeviceArrayRead(dctx, xin);
1199 PetscInt yidx = 0;
1200 auto subidx = 0;
1201 auto cur_stream = stream;
1202 auto cur_ctx = dctx;
1203 PetscDeviceContext *sub = nullptr;
1204 PetscStreamType stype;
1205
1206 // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
1207 // sub. Ideally the parent context should also join in on the fork, but it is extremely
1208 // fiddly to do so presently
1209 PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
1210 if (stype == PETSC_STREAM_DEFAULT || stype == PETSC_STREAM_DEFAULT_WITH_BARRIER) stype = PETSC_STREAM_NONBLOCKING;
1211 // If we have a default stream create nonblocking streams instead (as we can
1212 // locally exploit the parallelism). Otherwise use the prescribed stream type.
1213 PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
1214 PetscCall(PetscLogGpuTimeBegin());
1215 do {
1216 if (num_sub_streams) {
1217 cur_ctx = sub[subidx++ % num_sub_streams];
1218 PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
1219 }
1220 // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
1221 // it is very likely better to do 4+5 rather than 8+1
1222 switch (nv - yidx) {
1223 case 7:
1224 PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1225 break;
1226 case 6:
1227 PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1228 break;
1229 case 5:
1230 PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1231 break;
1232 case 4:
1233 PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1234 break;
1235 case 3:
1236 PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1237 break;
1238 case 2:
1239 PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1240 break;
1241 case 1:
1242 PetscCall(MDot_kernel_dispatch_<1>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1243 break;
1244 default: // 8 or more
1245 PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1246 break;
1247 }
1248 } while (yidx < nv);
1249 PetscCall(PetscLogGpuTimeEnd());
1250 PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
1251 }
1252
1253 PetscCall(PetscCUPMLaunchKernel1D(nv, 0, stream, kernels::sum_kernel, nv, d_results));
1254 // copy result of device reduction to host
1255 PetscCall(PetscCUPMMemcpyAsync(z, d_results, nv, cupmMemcpyDeviceToHost, stream));
1256 // do these now while final reduction is in flight
1257 PetscCall(PetscLogGpuFlops(nwork));
1258 PetscCall(PetscDeviceFree(dctx, d_results));
1259 PetscFunctionReturn(PETSC_SUCCESS);
1260 }
1261
// the MDot launch-geometry macros are private to the code above, do not leak them
#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE
1264
1265 template <device::cupm::DeviceType T>
MDot_(std::true_type,Vec xin,PetscInt nv,const Vec yin[],PetscScalar * z,PetscDeviceContext dctx)1266 inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
1267 {
1268 // probably not worth it to run more than 8 of these at a time?
1269 const auto n_sub = PetscMin(nv, 8);
1270 const auto n = static_cast<cupmBlasInt_t>(xin->map->n);
1271 const auto xptr = DeviceArrayRead(dctx, xin);
1272 PetscScalar *d_z;
1273 PetscDeviceContext *subctx;
1274 cupmStream_t stream;
1275
1276 PetscFunctionBegin;
1277 PetscCall(GetHandlesFrom_(dctx, &stream));
1278 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
1279 PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
1280 PetscCall(PetscLogGpuTimeBegin());
1281 for (PetscInt i = 0; i < nv; ++i) {
1282 const auto sub = subctx[i % n_sub];
1283 cupmBlasHandle_t handle;
1284 cupmBlasPointerMode_t old_mode;
1285
1286 PetscCall(GetHandlesFrom_(sub, &handle));
1287 PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
1288 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
1289 PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
1290 if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
1291 }
1292 PetscCall(PetscLogGpuTimeEnd());
1293 PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
1294 PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
1295 PetscCall(PetscDeviceFree(dctx, d_z));
1296 // REVIEW ME: flops?????
1297 PetscFunctionReturn(PETSC_SUCCESS);
1298 }
1299
1300 // v->ops->mdot
1301 template <device::cupm::DeviceType T>
MDot(Vec xin,PetscInt nv,const Vec yin[],PetscScalar * z)1302 inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
1303 {
1304 PetscFunctionBegin;
1305 if (PetscUnlikely(nv == 1)) {
1306 // dot handles nv = 0 correctly
1307 PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
1308 } else if (const auto n = xin->map->n) {
1309 PetscDeviceContext dctx;
1310
1311 PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
1312 PetscCall(GetHandles_(&dctx));
1313 PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
1314 // REVIEW ME: double count of flops??
1315 PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
1316 PetscCall(PetscDeviceContextSynchronize(dctx));
1317 } else {
1318 PetscCall(PetscArrayzero(z, nv));
1319 }
1320 PetscFunctionReturn(PETSC_SUCCESS);
1321 }
1322
1323 // VecSetAsync_Private
1324 template <device::cupm::DeviceType T>
SetAsync(Vec xin,PetscScalar alpha,PetscDeviceContext dctx)1325 inline PetscErrorCode VecSeq_CUPM<T>::SetAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1326 {
1327 const auto n = xin->map->n;
1328 cupmStream_t stream;
1329
1330 PetscFunctionBegin;
1331 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1332 PetscCall(GetHandlesFrom_(dctx, &stream));
1333 {
1334 const auto xptr = DeviceArrayWrite(dctx, xin);
1335
1336 if (alpha == PetscScalar(0.0)) {
1337 PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1338 } else {
1339 const auto dptr = thrust::device_pointer_cast(xptr.data());
1340
1341 PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1342 }
1343 }
1344 if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1345 PetscFunctionReturn(PETSC_SUCCESS);
1346 }
1347
1348 // v->ops->set
1349 template <device::cupm::DeviceType T>
Set(Vec xin,PetscScalar alpha)1350 inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1351 {
1352 PetscFunctionBegin;
1353 PetscCall(SetAsync(xin, alpha, nullptr));
1354 PetscFunctionReturn(PETSC_SUCCESS);
1355 }
1356
1357 // VecScaleAsync_Private
1358 template <device::cupm::DeviceType T>
ScaleAsync(Vec xin,PetscScalar alpha,PetscDeviceContext dctx)1359 inline PetscErrorCode VecSeq_CUPM<T>::ScaleAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1360 {
1361 PetscFunctionBegin;
1362 if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1363 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1364 if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1365 PetscCall(SetAsync(xin, alpha, dctx));
1366 } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1367 cupmBlasHandle_t cupmBlasHandle;
1368
1369 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1370 PetscCall(PetscLogGpuTimeBegin());
1371 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1372 PetscCall(PetscLogGpuTimeEnd());
1373 PetscCall(PetscLogGpuFlops(n));
1374 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1375 } else {
1376 PetscCall(MaybeIncrementEmptyLocalVec(xin));
1377 }
1378 PetscFunctionReturn(PETSC_SUCCESS);
1379 }
1380
1381 // v->ops->scale
1382 template <device::cupm::DeviceType T>
Scale(Vec xin,PetscScalar alpha)1383 inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1384 {
1385 PetscFunctionBegin;
1386 PetscCall(ScaleAsync(xin, alpha, nullptr));
1387 PetscFunctionReturn(PETSC_SUCCESS);
1388 }
1389
1390 // v->ops->tdot
1391 template <device::cupm::DeviceType T>
TDot(Vec xin,Vec yin,PetscScalar * z)1392 inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1393 {
1394 PetscBool yiscupm;
1395
1396 PetscFunctionBegin;
1397 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1398 if (!yiscupm) {
1399 PetscCall(VecTDot_Seq(xin, yin, z));
1400 PetscFunctionReturn(PETSC_SUCCESS);
1401 }
1402 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1403 PetscDeviceContext dctx;
1404 cupmBlasHandle_t cupmBlasHandle;
1405
1406 PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1407 PetscCall(PetscLogGpuTimeBegin());
1408 PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1409 PetscCall(PetscLogGpuTimeEnd());
1410 PetscCall(PetscLogGpuFlops(2 * n - 1));
1411 } else {
1412 *z = 0.0;
1413 }
1414 PetscFunctionReturn(PETSC_SUCCESS);
1415 }
1416
1417 // VecCopyAsync_Private
1418 template <device::cupm::DeviceType T>
CopyAsync(Vec xin,Vec yout,PetscDeviceContext dctx)1419 inline PetscErrorCode VecSeq_CUPM<T>::CopyAsync(Vec xin, Vec yout, PetscDeviceContext dctx) noexcept
1420 {
1421 PetscFunctionBegin;
1422 if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
1423 if (const auto n = xin->map->n) {
1424 const auto xmask = xin->offloadmask;
1425 // silence buggy gcc warning: mode may be used uninitialized in this function
1426 auto mode = cupmMemcpyDeviceToDevice;
1427 cupmStream_t stream;
1428
1429 // translate from PetscOffloadMask to cupmMemcpyKind
1430 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1431 switch (const auto ymask = yout->offloadmask) {
1432 case PETSC_OFFLOAD_UNALLOCATED: {
1433 PetscBool yiscupm;
1434
1435 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1436 if (yiscupm) {
1437 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
1438 break;
1439 }
1440 } // fall-through if unallocated and not cupm
1441 #if PETSC_CPP_VERSION >= 17
1442 [[fallthrough]];
1443 #endif
1444 case PETSC_OFFLOAD_CPU: {
1445 PetscBool yiscupm;
1446
1447 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1448 if (yiscupm) {
1449 mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToDevice : cupmMemcpyDeviceToDevice;
1450 } else {
1451 mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
1452 }
1453 break;
1454 }
1455 case PETSC_OFFLOAD_BOTH:
1456 case PETSC_OFFLOAD_GPU:
1457 mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
1458 break;
1459 default:
1460 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
1461 }
1462
1463 PetscCall(GetHandlesFrom_(dctx, &stream));
1464 switch (mode) {
1465 case cupmMemcpyDeviceToDevice: // the best case
1466 case cupmMemcpyHostToDevice: { // not terrible
1467 const auto yptr = DeviceArrayWrite(dctx, yout);
1468 const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
1469
1470 PetscCall(PetscLogGpuTimeBegin());
1471 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
1472 PetscCall(PetscLogGpuTimeEnd());
1473 } break;
1474 case cupmMemcpyDeviceToHost: // not great
1475 case cupmMemcpyHostToHost: { // worst case
1476 const auto xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
1477 PetscScalar *yptr;
1478
1479 PetscCall(VecGetArrayWrite(yout, &yptr));
1480 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
1481 PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
1482 if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
1483 PetscCall(VecRestoreArrayWrite(yout, &yptr));
1484 } break;
1485 default:
1486 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
1487 }
1488 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1489 } else {
1490 PetscCall(MaybeIncrementEmptyLocalVec(yout));
1491 }
1492 PetscFunctionReturn(PETSC_SUCCESS);
1493 }
1494
1495 // v->ops->copy
1496 template <device::cupm::DeviceType T>
Copy(Vec xin,Vec yout)1497 inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
1498 {
1499 PetscFunctionBegin;
1500 PetscCall(CopyAsync(xin, yout, nullptr));
1501 PetscFunctionReturn(PETSC_SUCCESS);
1502 }
1503
1504 // VecSwapAsync_Private
1505 template <device::cupm::DeviceType T>
SwapAsync(Vec xin,Vec yin,PetscDeviceContext dctx)1506 inline PetscErrorCode VecSeq_CUPM<T>::SwapAsync(Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1507 {
1508 PetscBool yiscupm;
1509
1510 PetscFunctionBegin;
1511 if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1512 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1513 PetscCheck(yiscupm, PetscObjectComm(PetscObjectCast(yin)), PETSC_ERR_SUP, "Cannot swap with Y of type %s", PetscObjectCast(yin)->type_name);
1514 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1515 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1516 cupmBlasHandle_t cupmBlasHandle;
1517
1518 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1519 PetscCall(PetscLogGpuTimeBegin());
1520 PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1521 PetscCall(PetscLogGpuTimeEnd());
1522 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1523 } else {
1524 PetscCall(MaybeIncrementEmptyLocalVec(xin));
1525 PetscCall(MaybeIncrementEmptyLocalVec(yin));
1526 }
1527 PetscFunctionReturn(PETSC_SUCCESS);
1528 }
1529
1530 // v->ops->swap
1531 template <device::cupm::DeviceType T>
Swap(Vec xin,Vec yin)1532 inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1533 {
1534 PetscFunctionBegin;
1535 PetscCall(SwapAsync(xin, yin, nullptr));
1536 PetscFunctionReturn(PETSC_SUCCESS);
1537 }
1538
1539 // VecAXPYBYAsync_Private
1540 template <device::cupm::DeviceType T>
AXPBYAsync(Vec yin,PetscScalar alpha,PetscScalar beta,Vec xin,PetscDeviceContext dctx)1541 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYAsync(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin, PetscDeviceContext dctx) noexcept
1542 {
1543 PetscBool xiscupm;
1544
1545 PetscFunctionBegin;
1546 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1547 if (!xiscupm) {
1548 PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1549 PetscFunctionReturn(PETSC_SUCCESS);
1550 }
1551 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1552 if (alpha == PetscScalar(0.0)) {
1553 PetscCall(ScaleAsync(yin, beta, dctx));
1554 } else if (beta == PetscScalar(1.0)) {
1555 PetscCall(AXPYAsync(yin, alpha, xin, dctx));
1556 } else if (alpha == PetscScalar(1.0)) {
1557 PetscCall(AYPXAsync(yin, beta, xin, dctx));
1558 } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1559 PetscBool xiscupm;
1560
1561 PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1562 if (!xiscupm) {
1563 PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1564 PetscFunctionReturn(PETSC_SUCCESS);
1565 }
1566
1567 const auto betaIsZero = beta == PetscScalar(0.0);
1568 const auto aptr = cupmScalarPtrCast(&alpha);
1569 cupmBlasHandle_t cupmBlasHandle;
1570
1571 PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1572 {
1573 const auto xptr = DeviceArrayRead(dctx, xin);
1574
1575 if (betaIsZero /* beta = 0 */) {
1576 // here we can get away with purely write-only as we memcpy into it first
1577 const auto yptr = DeviceArrayWrite(dctx, yin);
1578 cupmStream_t stream;
1579
1580 PetscCall(GetHandlesFrom_(dctx, &stream));
1581 PetscCall(PetscLogGpuTimeBegin());
1582 PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1583 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1584 } else {
1585 const auto yptr = DeviceArrayReadWrite(dctx, yin);
1586
1587 PetscCall(PetscLogGpuTimeBegin());
1588 PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1589 PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1590 }
1591 }
1592 PetscCall(PetscLogGpuTimeEnd());
1593 PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1594 PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1595 } else {
1596 PetscCall(MaybeIncrementEmptyLocalVec(yin));
1597 }
1598 PetscFunctionReturn(PETSC_SUCCESS);
1599 }
1600
1601 // v->ops->axpby
1602 template <device::cupm::DeviceType T>
AXPBY(Vec yin,PetscScalar alpha,PetscScalar beta,Vec xin)1603 inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1604 {
1605 PetscFunctionBegin;
1606 PetscCall(AXPBYAsync(yin, alpha, beta, xin, nullptr));
1607 PetscFunctionReturn(PETSC_SUCCESS);
1608 }
1609
1610 // VecAXPBYPCZAsync_Private
1611 template <device::cupm::DeviceType T>
AXPBYPCZAsync(Vec zin,PetscScalar alpha,PetscScalar beta,PetscScalar gamma,Vec xin,Vec yin,PetscDeviceContext dctx)1612 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZAsync(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1613 {
1614 PetscFunctionBegin;
1615 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1616 if (gamma != PetscScalar(1.0)) PetscCall(ScaleAsync(zin, gamma, dctx));
1617 PetscCall(AXPYAsync(zin, alpha, xin, dctx));
1618 PetscCall(AXPYAsync(zin, beta, yin, dctx));
1619 PetscFunctionReturn(PETSC_SUCCESS);
1620 }
1621
1622 // v->ops->axpbypcz
1623 template <device::cupm::DeviceType T>
AXPBYPCZ(Vec zin,PetscScalar alpha,PetscScalar beta,PetscScalar gamma,Vec xin,Vec yin)1624 inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1625 {
1626 PetscFunctionBegin;
1627 PetscCall(AXPBYPCZAsync(zin, alpha, beta, gamma, xin, yin, nullptr));
1628 PetscFunctionReturn(PETSC_SUCCESS);
1629 }
1630
1631 // v->ops->norm
1632 template <device::cupm::DeviceType T>
Norm(Vec xin,NormType type,PetscReal * z)1633 inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
1634 {
1635 PetscDeviceContext dctx;
1636 cupmBlasHandle_t cupmBlasHandle;
1637
1638 PetscFunctionBegin;
1639 PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1640 if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1641 const auto xptr = DeviceArrayRead(dctx, xin);
1642 PetscInt flopCount = 0;
1643
1644 PetscCall(PetscLogGpuTimeBegin());
1645 switch (type) {
1646 case NORM_1_AND_2:
1647 case NORM_1:
1648 PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1649 flopCount = std::max(n - 1, 0);
1650 if (type == NORM_1) break;
1651 ++z; // fall-through
1652 #if PETSC_CPP_VERSION >= 17
1653 [[fallthrough]];
1654 #endif
1655 case NORM_2:
1656 case NORM_FROBENIUS:
1657 PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1658 flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
1659 break;
1660 case NORM_INFINITY: {
1661 cupmBlasInt_t max_loc = 0;
1662 PetscScalar xv = 0.;
1663 cupmStream_t stream;
1664
1665 PetscCall(GetHandlesFrom_(dctx, &stream));
1666 PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
1667 PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
1668 *z = PetscAbsScalar(xv);
1669 // REVIEW ME: flopCount = ???
1670 } break;
1671 }
1672 PetscCall(PetscLogGpuTimeEnd());
1673 PetscCall(PetscLogGpuFlops(flopCount));
1674 } else {
1675 z[0] = 0.0;
1676 z[type == NORM_1_AND_2] = 0.0;
1677 }
1678 PetscFunctionReturn(PETSC_SUCCESS);
1679 }
1680
1681 namespace detail
1682 {
1683
1684 template <NormType wnormtype>
1685 class ErrorWNormTransformBase {
1686 public:
1687 using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;
1688
ErrorWNormTransformBase(PetscReal v)1689 constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }
1690
1691 protected:
1692 struct NormTuple {
1693 PetscReal norm;
1694 PetscInt loc;
1695 };
1696
compute_norm_(PetscReal err,PetscReal tol)1697 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
1698 {
1699 if (tol > 0.) {
1700 const auto val = err / tol;
1701
1702 return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
1703 } else {
1704 return {0.0, 0};
1705 }
1706 }
1707
1708 PetscReal ignore_max_;
1709 };
1710
1711 template <NormType wnormtype>
1712 struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
1713 using base_type = ErrorWNormTransformBase<wnormtype>;
1714 using result_type = typename base_type::result_type;
1715 using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;
1716
1717 using base_type::base_type;
1718
operator ()Petsc::vec::cupm::impl::detail::ErrorWNormTransform1719 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1720 {
1721 const auto u = thrust::get<0>(x); // with x.get<0>(), cuda-12.4.0 gives error: class "cuda::std::__4::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>" has no member "get"
1722 const auto y = thrust::get<1>(x);
1723 const auto au = PetscAbsScalar(u);
1724 const auto ay = PetscAbsScalar(y);
1725 const auto skip = au < this->ignore_max_ || ay < this->ignore_max_;
1726 const auto tola = skip ? 0.0 : PetscRealPart(thrust::get<2>(x));
1727 const auto tolr = skip ? 0.0 : PetscRealPart(thrust::get<3>(x)) * PetscMax(au, ay);
1728 const auto tol = tola + tolr;
1729 const auto err = PetscAbsScalar(u - y);
1730 const auto tup_a = this->compute_norm_(err, tola);
1731 const auto tup_r = this->compute_norm_(err, tolr);
1732 const auto tup_n = this->compute_norm_(err, tol);
1733
1734 return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1735 }
1736 };
1737
1738 template <NormType wnormtype>
1739 struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
1740 using base_type = ErrorWNormTransformBase<wnormtype>;
1741 using result_type = typename base_type::result_type;
1742 using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;
1743
1744 using base_type::base_type;
1745
operator ()Petsc::vec::cupm::impl::detail::ErrorWNormETransform1746 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1747 {
1748 const auto au = PetscAbsScalar(thrust::get<0>(x));
1749 const auto ay = PetscAbsScalar(thrust::get<1>(x));
1750 const auto skip = au < this->ignore_max_ || ay < this->ignore_max_;
1751 const auto tola = skip ? 0.0 : PetscRealPart(thrust::get<3>(x));
1752 const auto tolr = skip ? 0.0 : PetscRealPart(thrust::get<4>(x)) * PetscMax(au, ay);
1753 const auto tol = tola + tolr;
1754 const auto err = PetscAbsScalar(thrust::get<2>(x));
1755 const auto tup_a = this->compute_norm_(err, tola);
1756 const auto tup_r = this->compute_norm_(err, tolr);
1757 const auto tup_n = this->compute_norm_(err, tol);
1758
1759 return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1760 }
1761 };
1762
1763 template <NormType wnormtype>
1764 struct ErrorWNormReduce {
1765 using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;
1766
operator ()Petsc::vec::cupm::impl::detail::ErrorWNormReduce1767 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
1768 {
1769 // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
1770 // result_type is a template, so in order to fix this we would need to write:
1771 //
1772 // lhs.template get<0>()
1773 //
1774 // which is unseemly.
1775 if (wnormtype == NORM_INFINITY) {
1776 // clang-format off
1777 return {
1778 PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
1779 PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
1780 PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
1781 thrust::get<3>(lhs) + thrust::get<3>(rhs),
1782 thrust::get<4>(lhs) + thrust::get<4>(rhs),
1783 thrust::get<5>(lhs) + thrust::get<5>(rhs)
1784 };
1785 // clang-format on
1786 } else {
1787 // clang-format off
1788 return {
1789 thrust::get<0>(lhs) + thrust::get<0>(rhs),
1790 thrust::get<1>(lhs) + thrust::get<1>(rhs),
1791 thrust::get<2>(lhs) + thrust::get<2>(rhs),
1792 thrust::get<3>(lhs) + thrust::get<3>(rhs),
1793 thrust::get<4>(lhs) + thrust::get<4>(rhs),
1794 thrust::get<5>(lhs) + thrust::get<5>(rhs)
1795 };
1796 // clang-format on
1797 }
1798 }
1799 };
1800
1801 template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
ExecuteWNorm(Tuple && first,Tuple && last,NormType wnormtype,cupmStream_t stream,PetscReal ignore_max,PetscReal * norm,PetscInt * norm_loc,PetscReal * norma,PetscInt * norma_loc,PetscReal * normr,PetscInt * normr_loc)1802 inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
1803 {
1804 auto begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
1805 auto end = thrust::make_zip_iterator(std::forward<Tuple>(last));
1806 PetscReal n = 0, na = 0, nr = 0;
1807 PetscInt n_loc = 0, na_loc = 0, nr_loc = 0;
1808
1809 PetscFunctionBegin;
1810 // clang-format off
1811 if (wnormtype == NORM_INFINITY) {
1812 PetscCallThrust(
1813 thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1814 thrust::transform_reduce,
1815 stream,
1816 std::move(begin),
1817 std::move(end),
1818 WNormTransformType<NORM_INFINITY>{ignore_max},
1819 thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1820 ErrorWNormReduce<NORM_INFINITY>{}
1821 )
1822 );
1823 } else {
1824 PetscCallThrust(
1825 thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1826 thrust::transform_reduce,
1827 stream,
1828 std::move(begin),
1829 std::move(end),
1830 WNormTransformType<NORM_2>{ignore_max},
1831 thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1832 ErrorWNormReduce<NORM_2>{}
1833 )
1834 );
1835 }
1836 // clang-format on
1837 if (wnormtype == NORM_2) {
1838 *norm = PetscSqrtReal(*norm);
1839 *norma = PetscSqrtReal(*norma);
1840 *normr = PetscSqrtReal(*normr);
1841 }
1842 PetscFunctionReturn(PETSC_SUCCESS);
1843 }
1844
1845 } // namespace detail
1846
// v->ops->errorwnorm
//
// Computes the weighted error norms of U and Y (and optionally E) on the device by dispatching
// to detail::ExecuteWNorm().
//
// - E, vatol and vrtol may each be null.
// - When E is given, detail::ErrorWNormETransform is used and E is part of the zipped input;
//   otherwise detail::ErrorWNormTransform is used.
// - A null vatol (vrtol) is replaced by a constant iterator yielding atol (rtol); a non-null one
//   is read from the device array of that vector.
// - The six results are written through norm, norm_loc, norma, norma_loc, normr, normr_loc.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  // local length of U; every range below is [ptr, ptr + nl)
  const auto         nl  = U->map->n;
  // constant iterators standing in for the tolerance vectors when those are not supplied
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // returns a device pointer to v's array, or a null device pointer when v is null
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    // each branch below differs only in which of (ait | aptr) and (rit | rptr) is zipped in, and
    // in whether eptr is part of the tuple
    if (!vatol && !vrtol) {
      // scalar atol, scalar rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vatol) {
      // scalar atol, vector rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else if (!vrtol) {
      // vector atol, scalar rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    } else {
      // vector atol, vector rtol
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, aptr, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1968
1969 namespace detail
1970 {
1971 struct dotnorm2_mult {
operator ()Petsc::vec::cupm::impl::detail::dotnorm2_mult1972 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1973 {
1974 const auto conjt = PetscConj(t);
1975
1976 return {s * conjt, t * conjt};
1977 }
1978 };
1979
1980 // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1981 // would do it myself but now I am worried that they do so on purpose...
1982 struct dotnorm2_tuple_plus {
1983 using value_type = thrust::tuple<PetscScalar, PetscScalar>;
1984
operator ()Petsc::vec::cupm::impl::detail::dotnorm2_tuple_plus1985 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {thrust::get<0>(lhs) + thrust::get<0>(rhs), thrust::get<1>(lhs) + thrust::get<1>(rhs)}; }
1986 };
1987
1988 } // namespace detail
1989
1990 // v->ops->dotnorm2
1991 template <device::cupm::DeviceType T>
DotNorm2(Vec s,Vec t,PetscScalar * dp,PetscScalar * nm)1992 inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
1993 {
1994 PetscDeviceContext dctx;
1995 cupmStream_t stream;
1996
1997 PetscFunctionBegin;
1998 PetscCall(GetHandles_(&dctx, &stream));
1999 {
2000 PetscScalar dpt = 0.0, nmt = 0.0;
2001 const auto sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());
2002
2003 // clang-format off
2004 PetscCallThrust(
2005 thrust::tie(*dp, *nm) = THRUST_CALL(
2006 thrust::inner_product,
2007 stream,
2008 sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
2009 thrust::make_tuple(dpt, nmt),
2010 detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
2011 );
2012 );
2013 // clang-format on
2014 }
2015 PetscFunctionReturn(PETSC_SUCCESS);
2016 }
2017
2018 namespace detail
2019 {
2020 struct conjugate {
operator ()Petsc::vec::cupm::impl::detail::conjugate2021 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); }
2022 };
2023
2024 } // namespace detail
2025
2026 // v->ops->conjugate
2027 template <device::cupm::DeviceType T>
ConjugateAsync(Vec xin,PetscDeviceContext dctx)2028 inline PetscErrorCode VecSeq_CUPM<T>::ConjugateAsync(Vec xin, PetscDeviceContext dctx) noexcept
2029 {
2030 PetscFunctionBegin;
2031 if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin, nullptr, dctx));
2032 PetscFunctionReturn(PETSC_SUCCESS);
2033 }
2034
2035 // v->ops->conjugate
2036 template <device::cupm::DeviceType T>
Conjugate(Vec xin)2037 inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
2038 {
2039 PetscFunctionBegin;
2040 PetscCall(ConjugateAsync(xin, nullptr));
2041 PetscFunctionReturn(PETSC_SUCCESS);
2042 }
2043
2044 namespace detail
2045 {
2046
2047 struct real_part {
operator ()Petsc::vec::cupm::impl::detail::real_part2048 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const noexcept { return {PetscRealPart(thrust::get<0>(x)), thrust::get<1>(x)}; }
2049
operator ()Petsc::vec::cupm::impl::detail::real_part2050 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(const PetscScalar &x) const noexcept { return PetscRealPart(x); }
2051 };
2052
2053 // deriving from Operator allows us to "store" an instance of the operator in the class but
2054 // also take advantage of empty base class optimization if the operator is stateless
2055 template <typename Operator>
2056 class tuple_compare : Operator {
2057 public:
2058 using tuple_type = thrust::tuple<PetscReal, PetscInt>;
2059 using operator_type = Operator;
2060
operator ()(const tuple_type & x,const tuple_type & y) const2061 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
2062 {
2063 if (op_()(thrust::get<0>(y), thrust::get<0>(x))) {
2064 // if y is strictly greater/less than x, return y
2065 return y;
2066 } else if (thrust::get<0>(y) == thrust::get<0>(x)) {
2067 // if equal, prefer lower index
2068 return thrust::get<1>(y) < thrust::get<1>(x) ? y : x;
2069 }
2070 // otherwise return x
2071 return x;
2072 }
2073
2074 private:
op_() const2075 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
2076 };
2077
2078 } // namespace detail
2079
2080 template <device::cupm::DeviceType T>
2081 template <typename TupleFuncT, typename UnaryFuncT>
MinMax_(TupleFuncT && tuple_ftr,UnaryFuncT && unary_ftr,Vec v,PetscInt * p,PetscReal * m)2082 inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
2083 {
2084 PetscFunctionBegin;
2085 PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
2086 if (p) *p = -1;
2087 if (const auto n = v->map->n) {
2088 PetscDeviceContext dctx;
2089 cupmStream_t stream;
2090
2091 PetscCall(GetHandles_(&dctx, &stream));
2092 // needed to:
2093 // 1. switch between transform_reduce and reduce
2094 // 2. strip the real_part functor from the arguments
2095 #if PetscDefined(USE_COMPLEX)
2096 #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
2097 #else
2098 #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
2099 #endif
2100 {
2101 const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
2102
2103 if (p) {
2104 // clang-format off
2105 const auto zip = thrust::make_zip_iterator(
2106 thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
2107 );
2108 // clang-format on
2109 // need to use preprocessor conditionals since otherwise thrust complains about not being
2110 // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
2111 // builds...
2112 // clang-format off
2113 PetscCallThrust(
2114 thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
2115 stream, zip, zip + n, detail::real_part{},
2116 thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
2117 );
2118 );
2119 // clang-format on
2120 } else {
2121 // clang-format off
2122 PetscCallThrust(
2123 *m = THRUST_MINMAX_REDUCE(
2124 stream, vptr, vptr + n, detail::real_part{},
2125 *m, std::forward<UnaryFuncT>(unary_ftr)
2126 );
2127 );
2128 // clang-format on
2129 }
2130 }
2131 #undef THRUST_MINMAX_REDUCE
2132 }
2133 // REVIEW ME: flops?
2134 PetscFunctionReturn(PETSC_SUCCESS);
2135 }
2136
2137 // v->ops->max
2138 template <device::cupm::DeviceType T>
Max(Vec v,PetscInt * p,PetscReal * m)2139 inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
2140 {
2141 #if CCCL_VERSION >= 3001000
2142 using tuple_functor = detail::tuple_compare<cuda::std::greater<PetscReal>>;
2143 using unary_functor = cuda::maximum<PetscReal>;
2144 #else
2145 using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
2146 using unary_functor = thrust::maximum<PetscReal>;
2147 #endif
2148
2149 PetscFunctionBegin;
2150 *m = PETSC_MIN_REAL;
2151 // use {} constructor syntax otherwise most vexing parse
2152 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2153 PetscFunctionReturn(PETSC_SUCCESS);
2154 }
2155
2156 // v->ops->min
2157 template <device::cupm::DeviceType T>
Min(Vec v,PetscInt * p,PetscReal * m)2158 inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
2159 {
2160 #if CCCL_VERSION >= 3001000
2161 using tuple_functor = detail::tuple_compare<cuda::std::less<PetscReal>>;
2162 using unary_functor = cuda::minimum<PetscReal>;
2163 #else
2164 using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
2165 using unary_functor = thrust::minimum<PetscReal>;
2166 #endif
2167
2168 PetscFunctionBegin;
2169 *m = PETSC_MAX_REAL;
2170 // use {} constructor syntax otherwise most vexing parse
2171 PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2172 PetscFunctionReturn(PETSC_SUCCESS);
2173 }
2174
2175 // v->ops->sum
2176 template <device::cupm::DeviceType T>
Sum(Vec v,PetscScalar * sum)2177 inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
2178 {
2179 PetscFunctionBegin;
2180 if (const auto n = v->map->n) {
2181 PetscDeviceContext dctx;
2182 cupmStream_t stream;
2183
2184 PetscCall(GetHandles_(&dctx, &stream));
2185 const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
2186 // REVIEW ME: why not cupmBlasXasum()?
2187 PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
2188 // REVIEW ME: must be at least n additions
2189 PetscCall(PetscLogGpuFlops(n));
2190 } else {
2191 *sum = 0.0;
2192 }
2193 PetscFunctionReturn(PETSC_SUCCESS);
2194 }
2195
2196 template <device::cupm::DeviceType T>
ShiftAsync(Vec v,PetscScalar shift,PetscDeviceContext dctx)2197 inline PetscErrorCode VecSeq_CUPM<T>::ShiftAsync(Vec v, PetscScalar shift, PetscDeviceContext dctx) noexcept
2198 {
2199 PetscFunctionBegin;
2200 PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v, nullptr, dctx));
2201 PetscFunctionReturn(PETSC_SUCCESS);
2202 }
2203
2204 template <device::cupm::DeviceType T>
Shift(Vec v,PetscScalar shift)2205 inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
2206 {
2207 PetscFunctionBegin;
2208 PetscCall(ShiftAsync(v, shift, nullptr));
2209 PetscFunctionReturn(PETSC_SUCCESS);
2210 }
2211
2212 template <device::cupm::DeviceType T>
SetRandom(Vec v,PetscRandom rand)2213 inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
2214 {
2215 PetscFunctionBegin;
2216 if (const auto n = v->map->n) {
2217 PetscBool iscurand;
2218 PetscDeviceContext dctx;
2219
2220 PetscCall(GetHandles_(&dctx));
2221 PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
2222 if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
2223 else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
2224 } else {
2225 PetscCall(MaybeIncrementEmptyLocalVec(v));
2226 }
2227 // REVIEW ME: flops????
2228 // REVIEW ME: Timing???
2229 PetscFunctionReturn(PETSC_SUCCESS);
2230 }
2231
2232 // v->ops->setpreallocation
2233 template <device::cupm::DeviceType T>
SetPreallocationCOO(Vec v,PetscCount ncoo,const PetscInt coo_i[])2234 inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
2235 {
2236 PetscDeviceContext dctx;
2237
2238 PetscFunctionBegin;
2239 PetscCall(GetHandles_(&dctx));
2240 PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
2241 PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
2242 PetscFunctionReturn(PETSC_SUCCESS);
2243 }
2244
2245 // v->ops->setvaluescoo
2246 template <device::cupm::DeviceType T>
SetValuesCOO(Vec x,const PetscScalar v[],InsertMode imode)2247 inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
2248 {
2249 auto vv = const_cast<PetscScalar *>(v);
2250 PetscMemType memtype;
2251 PetscDeviceContext dctx;
2252 cupmStream_t stream;
2253
2254 PetscFunctionBegin;
2255 PetscCall(GetHandles_(&dctx, &stream));
2256 PetscCall(PetscGetMemType(v, &memtype));
2257 if (PetscMemTypeHost(memtype)) {
2258 const auto size = VecIMPLCast(x)->coo_n;
2259
2260 // If user gave v[] in host, we might need to copy it to device if any
2261 PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
2262 PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
2263 }
2264
2265 if (const auto n = x->map->n) {
2266 const auto vcu = VecCUPMCast(x);
2267
2268 PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
2269 } else {
2270 PetscCall(MaybeIncrementEmptyLocalVec(x));
2271 }
2272
2273 if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
2274 PetscCall(PetscDeviceContextSynchronize(dctx));
2275 PetscFunctionReturn(PETSC_SUCCESS);
2276 }
2277
2278 } // namespace impl
2279
2280 } // namespace cupm
2281
2282 } // namespace vec
2283
2284 } // namespace Petsc
2285