xref: /petsc/src/vec/is/sf/impls/basic/cupm/sfcupm_impl.hpp (revision 4ad8454beace47809662cdae21ee081016eaa39a)
1 #pragma once
2 
3 #include "sfcupm.hpp"
4 #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
5 #include <petsc/private/cupmatomics.hpp>
6 
7 namespace Petsc
8 {
9 
10 namespace sf
11 {
12 
13 namespace cupm
14 {
15 
16 namespace kernels
17 {
18 
19 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
20 PETSC_NODISCARD static PETSC_DEVICE_INLINE_DECL PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) noexcept
21 {
22   PetscInt        i, j, k, m, n, r;
23   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
24 
25   n      = opt[0];
26   offset = opt + 1;
27   start  = opt + n + 2;
28   dx     = opt + 2 * n + 2;
29   dy     = opt + 3 * n + 2;
30   X      = opt + 5 * n + 2;
31   Y      = opt + 6 * n + 2;
32   for (r = 0; r < n; r++) {
33     if (tid < offset[r + 1]) break;
34   }
35   m = (tid - offset[r]);
36   k = m / (dx[r] * dy[r]);
37   j = (m - k * dx[r] * dy[r]) / dx[r];
38   i = m - k * dx[r] * dy[r] - j * dx[r];
39 
40   return start[r] + k * X[r] * Y[r] + j * X[r] + i;
41 }
42 
43 /*====================================================================================*/
44 /*  Templated CUPM kernels for pack/unpack. The Op can be regular or atomic           */
45 /*====================================================================================*/
46 
47 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
48    <Type> is PetscReal, which is the primitive type we operate on.
49    <bs>   is 16, which says <unit> contains 16 primitive types.
50    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
51    <EQ>   is 0, which is (bs == BS ? 1 : 0)
52 
53   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
54   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
55 */
56 template <class Type, PetscInt BS, PetscInt EQ>
57 PETSC_KERNEL_DECL static void d_Pack(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, const Type *data, Type *buf)
58 {
59   const PetscInt M   = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */
60   const PetscInt MBS = M * BS;             /* MBS=bs. We turn MBS into a compile-time const when EQ=1. */
61 
62   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
63     PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
64     PetscInt s = tid * MBS;
65     for (PetscInt i = 0; i < MBS; i++) buf[s + i] = data[t + i];
66   });
67 }
68 
69 template <class Type, class Op, PetscInt BS, PetscInt EQ>
70 PETSC_KERNEL_DECL static void d_UnpackAndOp(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, Type *data, const Type *buf)
71 {
72   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
73   Op             op;
74 
75   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
76     PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
77     PetscInt s = tid * MBS;
78     for (PetscInt i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
79   });
80 }
81 
82 template <class Type, class Op, PetscInt BS, PetscInt EQ>
83 PETSC_KERNEL_DECL static void d_FetchAndOp(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, Type *leafbuf)
84 {
85   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
86   Op             op;
87 
88   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
89     PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
90     PetscInt l = tid * MBS;
91     for (PetscInt i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
92   });
93 }
94 
95 template <class Type, class Op, PetscInt BS, PetscInt EQ>
96 PETSC_KERNEL_DECL static void d_ScatterAndOp(PetscInt bs, PetscInt count, PetscInt srcx, PetscInt srcy, PetscInt srcX, PetscInt srcY, PetscInt srcStart, const PetscInt *srcIdx, const Type *src, PetscInt dstx, PetscInt dsty, PetscInt dstX, PetscInt dstY, PetscInt dstStart, const PetscInt *dstIdx, Type *dst)
97 {
98   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
99   Op             op;
100 
101   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
102     PetscInt s, t;
103 
104     if (!srcIdx) { /* src is either contiguous or 3D */
105       PetscInt k = tid / (srcx * srcy);
106       PetscInt j = (tid - k * srcx * srcy) / srcx;
107       PetscInt i = tid - k * srcx * srcy - j * srcx;
108 
109       s = srcStart + k * srcX * srcY + j * srcX + i;
110     } else {
111       s = srcIdx[tid];
112     }
113 
114     if (!dstIdx) { /* dst is either contiguous or 3D */
115       PetscInt k = tid / (dstx * dsty);
116       PetscInt j = (tid - k * dstx * dsty) / dstx;
117       PetscInt i = tid - k * dstx * dsty - j * dstx;
118 
119       t = dstStart + k * dstX * dstY + j * dstX + i;
120     } else {
121       t = dstIdx[tid];
122     }
123 
124     s *= MBS;
125     t *= MBS;
126     for (PetscInt i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
127   });
128 }
129 
130 template <class Type, class Op, PetscInt BS, PetscInt EQ>
131 PETSC_KERNEL_DECL static void d_FetchAndOpLocal(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, PetscInt leafstart, const PetscInt *leafopt, const PetscInt *leafidx, const Type *leafdata, Type *leafupdate)
132 {
133   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
134   Op             op;
135 
136   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
137     PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
138     PetscInt l = (leafopt ? MapTidToIndex(leafopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
139     for (PetscInt i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
140   });
141 }
142 
143 /*====================================================================================*/
144 /*                             Regular operations on device                           */
145 /*====================================================================================*/
146 template <typename Type>
147 struct Insert {
148   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
149   {
150     Type old = x;
151     x        = y;
152     return old;
153   }
154 };
155 template <typename Type>
156 struct Add {
157   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
158   {
159     Type old = x;
160     x += y;
161     return old;
162   }
163 };
164 template <typename Type>
165 struct Mult {
166   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
167   {
168     Type old = x;
169     x *= y;
170     return old;
171   }
172 };
173 template <typename Type>
174 struct Min {
175   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
176   {
177     Type old = x;
178     x        = PetscMin(x, y);
179     return old;
180   }
181 };
182 template <typename Type>
183 struct Max {
184   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
185   {
186     Type old = x;
187     x        = PetscMax(x, y);
188     return old;
189   }
190 };
191 template <typename Type>
192 struct LAND {
193   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
194   {
195     Type old = x;
196     x        = x && y;
197     return old;
198   }
199 };
200 template <typename Type>
201 struct LOR {
202   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
203   {
204     Type old = x;
205     x        = x || y;
206     return old;
207   }
208 };
209 template <typename Type>
210 struct LXOR {
211   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
212   {
213     Type old = x;
214     x        = !x != !y;
215     return old;
216   }
217 };
218 template <typename Type>
219 struct BAND {
220   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
221   {
222     Type old = x;
223     x        = x & y;
224     return old;
225   }
226 };
227 template <typename Type>
228 struct BOR {
229   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
230   {
231     Type old = x;
232     x        = x | y;
233     return old;
234   }
235 };
236 template <typename Type>
237 struct BXOR {
238   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
239   {
240     Type old = x;
241     x        = x ^ y;
242     return old;
243   }
244 };
245 template <typename Type>
246 struct Minloc {
247   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
248   {
249     Type old = x;
250     if (y.a < x.a) x = y;
251     else if (y.a == x.a) x.b = min(x.b, y.b);
252     return old;
253   }
254 };
255 template <typename Type>
256 struct Maxloc {
257   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
258   {
259     Type old = x;
260     if (y.a > x.a) x = y;
261     else if (y.a == x.a) x.b = min(x.b, y.b); /* See MPI MAXLOC */
262     return old;
263   }
264 };
265 
266 } // namespace kernels
267 
268 namespace impl
269 {
270 
271 /*====================================================================================*/
272 /*  Wrapper functions of cupm kernels. Function pointers are stored in 'link'         */
273 /*====================================================================================*/
274 template <device::cupm::DeviceType T>
275 template <typename Type, PetscInt BS, PetscInt EQ>
276 inline PetscErrorCode SfInterface<T>::Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data, void *buf) noexcept
277 {
278   const PetscInt *iarray = opt ? opt->array : NULL;
279 
280   PetscFunctionBegin;
281   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
282   if (PetscDefined(USING_NVCC) && !opt && !idx) { /* It is a 'CUDA data to nvshmem buf' memory copy */
283     PetscCallCUPM(cupmMemcpyAsync(buf, (char *)data + start * link->unitbytes, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
284   } else {
285     PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_Pack<Type, BS, EQ>, link->bs, count, start, iarray, idx, (const Type *)data, (Type *)buf));
286   }
287   PetscFunctionReturn(PETSC_SUCCESS);
288 }
289 
290 template <device::cupm::DeviceType T>
291 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
292 inline PetscErrorCode SfInterface<T>::UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, const void *buf) noexcept
293 {
294   const PetscInt *iarray = opt ? opt->array : NULL;
295 
296   PetscFunctionBegin;
297   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
298   if (PetscDefined(USING_NVCC) && std::is_same<Op, kernels::Insert<Type>>::value && !opt && !idx) { /* It is a 'nvshmem buf to CUDA data' memory copy */
299     PetscCallCUPM(cupmMemcpyAsync((char *)data + start * link->unitbytes, buf, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
300   } else {
301     PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_UnpackAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
302   }
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305 
306 template <device::cupm::DeviceType T>
307 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
308 inline PetscErrorCode SfInterface<T>::FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) noexcept
309 {
310   const PetscInt *iarray = opt ? opt->array : NULL;
311 
312   PetscFunctionBegin;
313   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
314   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
315   PetscFunctionReturn(PETSC_SUCCESS);
316 }
317 
318 template <device::cupm::DeviceType T>
319 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
320 inline PetscErrorCode SfInterface<T>::ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
321 {
322   PetscInt nthreads = 256;
323   PetscInt nblocks  = (count + nthreads - 1) / nthreads;
324   PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
325 
326   PetscFunctionBegin;
327   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
328   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
329 
330   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use 3D grid and block */
331   if (srcOpt) {
332     srcx     = srcOpt->dx[0];
333     srcy     = srcOpt->dy[0];
334     srcX     = srcOpt->X[0];
335     srcY     = srcOpt->Y[0];
336     srcStart = srcOpt->start[0];
337     srcIdx   = NULL;
338   } else if (!srcIdx) {
339     srcx = srcX = count;
340     srcy = srcY = 1;
341   }
342 
343   if (dstOpt) {
344     dstx     = dstOpt->dx[0];
345     dsty     = dstOpt->dy[0];
346     dstX     = dstOpt->X[0];
347     dstY     = dstOpt->Y[0];
348     dstStart = dstOpt->start[0];
349     dstIdx   = NULL;
350   } else if (!dstIdx) {
351     dstx = dstX = count;
352     dsty = dstY = 1;
353   }
354 
355   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_ScatterAndOp<Type, Op, BS, EQ>, link->bs, count, srcx, srcy, srcX, srcY, srcStart, srcIdx, (const Type *)src, dstx, dsty, dstX, dstY, dstStart, dstIdx, (Type *)dst));
356   PetscFunctionReturn(PETSC_SUCCESS);
357 }
358 
359 template <device::cupm::DeviceType T>
360 /* Specialization for Insert since we may use cupmMemcpyAsync */
361 template <typename Type, PetscInt BS, PetscInt EQ>
362 inline PetscErrorCode SfInterface<T>::ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
363 {
364   PetscFunctionBegin;
365   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366   /*src and dst are contiguous */
367   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368     PetscCallCUPM(cupmMemcpyAsync((Type *)dst + dstStart * link->bs, (const Type *)src + srcStart * link->bs, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
369   } else {
370     PetscCall(ScatterAndOp<Type, kernels::Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
371   }
372   PetscFunctionReturn(PETSC_SUCCESS);
373 }
374 
375 template <device::cupm::DeviceType T>
376 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
377 inline PetscErrorCode SfInterface<T>::FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate) noexcept
378 {
379   const PetscInt *rarray = rootopt ? rootopt->array : NULL;
380   const PetscInt *larray = leafopt ? leafopt->array : NULL;
381 
382   PetscFunctionBegin;
383   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
384   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOpLocal<Type, Op, BS, EQ>, link->bs, count, rootstart, rarray, rootidx, (Type *)rootdata, leafstart, larray, leafidx, (const Type *)leafdata, (Type *)leafupdate));
385   PetscFunctionReturn(PETSC_SUCCESS);
386 }
387 
388 /*====================================================================================*/
389 /*  Init various types and instantiate pack/unpack function pointers                  */
390 /*====================================================================================*/
391 template <device::cupm::DeviceType T>
392 template <typename Type, PetscInt BS, PetscInt EQ>
393 inline void SfInterface<T>::PackInit_RealType(PetscSFLink link) noexcept
394 {
395   /* Pack/unpack for remote communication */
396   link->d_Pack            = Pack<Type, BS, EQ>;
397   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
398   link->d_UnpackAndAdd    = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
399   link->d_UnpackAndMult   = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
400   link->d_UnpackAndMin    = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
401   link->d_UnpackAndMax    = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
402   link->d_FetchAndAdd     = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
403 
404   /* Scatter for local communication */
405   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
406   link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
407   link->d_ScatterAndMult   = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
408   link->d_ScatterAndMin    = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
409   link->d_ScatterAndMax    = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
410   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
411 
412   /* Atomic versions when there are data-race possibilities */
413   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
414   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
415   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
416   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
417   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
418   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;
419 
420   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
421   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
422   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
423   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
424   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
425   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
426 }
427 
428 /* Have this templated class to specialize for char integers */
429 template <device::cupm::DeviceType T>
430 template <typename Type, PetscInt BS, PetscInt EQ, PetscInt size /*sizeof(Type)*/>
431 struct SfInterface<T>::PackInit_IntegerType_Atomic {
432   static inline void Init(PetscSFLink link) noexcept
433   {
434     link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
435     link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
436     link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
437     link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
438     link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
439     link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
440     link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
441     link->da_UnpackAndLXOR   = UnpackAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
442     link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
443     link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
444     link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
445     link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;
446 
447     link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
448     link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
449     link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
450     link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
451     link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
452     link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
453     link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
454     link->da_ScatterAndLXOR   = ScatterAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
455     link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
456     link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
457     link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
458     link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
459   }
460 };
461 
462 /* CUDA does not support atomics on chars. It is TBD in PETSc. */
463 template <device::cupm::DeviceType T>
464 template <typename Type, PetscInt BS, PetscInt EQ>
465 struct SfInterface<T>::PackInit_IntegerType_Atomic<Type, BS, EQ, 1> {
466   static inline void Init(PetscSFLink)
467   { /* Nothing to leave function pointers NULL */
468   }
469 };
470 
471 template <device::cupm::DeviceType T>
472 template <typename Type, PetscInt BS, PetscInt EQ>
473 inline void SfInterface<T>::PackInit_IntegerType(PetscSFLink link) noexcept
474 {
475   link->d_Pack            = Pack<Type, BS, EQ>;
476   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
477   link->d_UnpackAndAdd    = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
478   link->d_UnpackAndMult   = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
479   link->d_UnpackAndMin    = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
480   link->d_UnpackAndMax    = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
481   link->d_UnpackAndLAND   = UnpackAndOp<Type, kernels::LAND<Type>, BS, EQ>;
482   link->d_UnpackAndLOR    = UnpackAndOp<Type, kernels::LOR<Type>, BS, EQ>;
483   link->d_UnpackAndLXOR   = UnpackAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
484   link->d_UnpackAndBAND   = UnpackAndOp<Type, kernels::BAND<Type>, BS, EQ>;
485   link->d_UnpackAndBOR    = UnpackAndOp<Type, kernels::BOR<Type>, BS, EQ>;
486   link->d_UnpackAndBXOR   = UnpackAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
487   link->d_FetchAndAdd     = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
488 
489   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
490   link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
491   link->d_ScatterAndMult   = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
492   link->d_ScatterAndMin    = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
493   link->d_ScatterAndMax    = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
494   link->d_ScatterAndLAND   = ScatterAndOp<Type, kernels::LAND<Type>, BS, EQ>;
495   link->d_ScatterAndLOR    = ScatterAndOp<Type, kernels::LOR<Type>, BS, EQ>;
496   link->d_ScatterAndLXOR   = ScatterAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
497   link->d_ScatterAndBAND   = ScatterAndOp<Type, kernels::BAND<Type>, BS, EQ>;
498   link->d_ScatterAndBOR    = ScatterAndOp<Type, kernels::BOR<Type>, BS, EQ>;
499   link->d_ScatterAndBXOR   = ScatterAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
500   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
501   PackInit_IntegerType_Atomic<Type, BS, EQ, sizeof(Type)>::Init(link);
502 }
503 
504 #if defined(PETSC_HAVE_COMPLEX)
505 template <device::cupm::DeviceType T>
506 template <typename Type, PetscInt BS, PetscInt EQ>
507 inline void SfInterface<T>::PackInit_ComplexType(PetscSFLink link) noexcept
508 {
509   link->d_Pack            = Pack<Type, BS, EQ>;
510   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
511   link->d_UnpackAndAdd    = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
512   link->d_UnpackAndMult   = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
513   link->d_FetchAndAdd     = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
514 
515   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
516   link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
517   link->d_ScatterAndMult   = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
518   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
519 
520   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
521   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
522   link->da_UnpackAndMult   = NULL; /* Not implemented yet */
523   link->da_FetchAndAdd     = NULL; /* Return value of atomicAdd on complex is not atomic */
524 
525   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
526   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
527 }
528 #endif
529 
/* Short aliases for the char integer types */
using SignedChar   = signed char;
using UnsignedChar = unsigned char;
/* Pair of ints with fields a and b */
struct PairInt {
  int a;
  int b;
};
536 typedef struct {
537   PetscInt a;
538   PetscInt b;
539 } PairPetscInt;
540 
541 template <device::cupm::DeviceType T>
542 template <typename Type>
543 inline void SfInterface<T>::PackInit_PairType(PetscSFLink link) noexcept
544 {
545   link->d_Pack            = Pack<Type, 1, 1>;
546   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, 1, 1>;
547   link->d_UnpackAndMaxloc = UnpackAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
548   link->d_UnpackAndMinloc = UnpackAndOp<Type, kernels::Minloc<Type>, 1, 1>;
549 
550   link->d_ScatterAndInsert = ScatterAndOp<Type, kernels::Insert<Type>, 1, 1>;
551   link->d_ScatterAndMaxloc = ScatterAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
552   link->d_ScatterAndMinloc = ScatterAndOp<Type, kernels::Minloc<Type>, 1, 1>;
553   /* Atomics for pair types are not implemented yet */
554 }
555 
556 template <device::cupm::DeviceType T>
557 template <typename Type, PetscInt BS, PetscInt EQ>
558 inline void SfInterface<T>::PackInit_DumbType(PetscSFLink link) noexcept
559 {
560   link->d_Pack             = Pack<Type, BS, EQ>;
561   link->d_UnpackAndInsert  = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
562   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
563   /* Atomics for dumb types are not implemented yet */
564 }
565 
566 /* Some device-specific utilities */
567 template <device::cupm::DeviceType T>
568 inline PetscErrorCode SfInterface<T>::LinkSyncDevice(PetscSFLink) noexcept
569 {
570   PetscFunctionBegin;
571   PetscCallCUPM(cupmDeviceSynchronize());
572   PetscFunctionReturn(PETSC_SUCCESS);
573 }
574 
575 template <device::cupm::DeviceType T>
576 inline PetscErrorCode SfInterface<T>::LinkSyncStream(PetscSFLink link) noexcept
577 {
578   PetscFunctionBegin;
579   PetscCallCUPM(cupmStreamSynchronize(link->stream));
580   PetscFunctionReturn(PETSC_SUCCESS);
581 }
582 
583 template <device::cupm::DeviceType T>
584 inline PetscErrorCode SfInterface<T>::LinkMemcpy(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) noexcept
585 {
586   PetscFunctionBegin;
587   cupmMemcpyKind_t kinds[2][2] = {
588     {cupmMemcpyHostToHost,   cupmMemcpyHostToDevice  },
589     {cupmMemcpyDeviceToHost, cupmMemcpyDeviceToDevice}
590   };
591 
592   if (n) {
593     if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { /* Separate HostToHost so that pure-cpu code won't call cupm runtime */
594       PetscCall(PetscMemcpy(dst, src, n));
595     } else {
596       int stype = PetscMemTypeDevice(srcmtype) ? 1 : 0;
597       int dtype = PetscMemTypeDevice(dstmtype) ? 1 : 0;
598       PetscCallCUPM(cupmMemcpyAsync(dst, src, n, kinds[stype][dtype], link->stream));
599     }
600   }
601   PetscFunctionReturn(PETSC_SUCCESS);
602 }
603 
604 template <device::cupm::DeviceType T>
605 inline PetscErrorCode SfInterface<T>::Malloc(PetscMemType mtype, size_t size, void **ptr) noexcept
606 {
607   PetscFunctionBegin;
608   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609   else if (PetscMemTypeDevice(mtype)) {
610     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
611     PetscCallCUPM(cupmMalloc(ptr, size));
612   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613   PetscFunctionReturn(PETSC_SUCCESS);
614 }
615 
616 template <device::cupm::DeviceType T>
617 inline PetscErrorCode SfInterface<T>::Free(PetscMemType mtype, void *ptr) noexcept
618 {
619   PetscFunctionBegin;
620   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
621   else if (PetscMemTypeDevice(mtype)) PetscCallCUPM(cupmFree(ptr));
622   else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623   PetscFunctionReturn(PETSC_SUCCESS);
624 }
625 
626 /* Destructor when the link uses MPI for communication on CUPM device */
627 template <device::cupm::DeviceType T>
628 inline PetscErrorCode SfInterface<T>::LinkDestroy_MPI(PetscSF, PetscSFLink link) noexcept
629 {
630   PetscFunctionBegin;
631   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
632     PetscCallCUPM(cupmFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633     PetscCallCUPM(cupmFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
634   }
635   PetscFunctionReturn(PETSC_SUCCESS);
636 }
637 
638 /*====================================================================================*/
639 /*                Main driver to init MPI datatype on device                          */
640 /*====================================================================================*/
641 
642 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
643 template <device::cupm::DeviceType T>
644 inline PetscErrorCode SfInterface<T>::LinkSetUp(PetscSF sf, PetscSFLink link, MPI_Datatype unit) noexcept
645 {
646   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
647   PetscBool is2Int, is2PetscInt;
648 #if defined(PETSC_HAVE_COMPLEX)
649   PetscInt nPetscComplex = 0;
650 #endif
651 
652   PetscFunctionBegin;
653   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
654   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
655   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
656   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
657   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
658   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
659   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
660 #if defined(PETSC_HAVE_COMPLEX)
661   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
662 #endif
663   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
664   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
665 
666   if (is2Int) {
667     PackInit_PairType<PairInt>(link);
668   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
669     PackInit_PairType<PairPetscInt>(link);
670   } else if (nPetscReal) {
671 #if !defined(PETSC_HAVE_DEVICE)
672     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
673     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
674     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
675     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
676     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
677     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
678     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
679     else if (nPetscReal % 1 == 0)
680 #endif
681       PackInit_RealType<PetscReal, 1, 0>(link);
682   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
683 #if !defined(PETSC_HAVE_DEVICE)
684     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
685     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
686     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
687     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
688     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
689     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
690     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
691     else if (nPetscInt % 1 == 0)
692 #endif
693       PackInit_IntegerType<llint, 1, 0>(link);
694   } else if (nInt) {
695 #if !defined(PETSC_HAVE_DEVICE)
696     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
697     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
698     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
699     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
700     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
701     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
702     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
703     else if (nInt % 1 == 0)
704 #endif
705       PackInit_IntegerType<int, 1, 0>(link);
706   } else if (nSignedChar) {
707 #if !defined(PETSC_HAVE_DEVICE)
708     if (nSignedChar == 8) PackInit_IntegerType<SignedChar, 8, 1>(link);
709     else if (nSignedChar % 8 == 0) PackInit_IntegerType<SignedChar, 8, 0>(link);
710     else if (nSignedChar == 4) PackInit_IntegerType<SignedChar, 4, 1>(link);
711     else if (nSignedChar % 4 == 0) PackInit_IntegerType<SignedChar, 4, 0>(link);
712     else if (nSignedChar == 2) PackInit_IntegerType<SignedChar, 2, 1>(link);
713     else if (nSignedChar % 2 == 0) PackInit_IntegerType<SignedChar, 2, 0>(link);
714     else if (nSignedChar == 1) PackInit_IntegerType<SignedChar, 1, 1>(link);
715     else if (nSignedChar % 1 == 0)
716 #endif
717       PackInit_IntegerType<SignedChar, 1, 0>(link);
718   } else if (nUnsignedChar) {
719 #if !defined(PETSC_HAVE_DEVICE)
720     if (nUnsignedChar == 8) PackInit_IntegerType<UnsignedChar, 8, 1>(link);
721     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<UnsignedChar, 8, 0>(link);
722     else if (nUnsignedChar == 4) PackInit_IntegerType<UnsignedChar, 4, 1>(link);
723     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<UnsignedChar, 4, 0>(link);
724     else if (nUnsignedChar == 2) PackInit_IntegerType<UnsignedChar, 2, 1>(link);
725     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<UnsignedChar, 2, 0>(link);
726     else if (nUnsignedChar == 1) PackInit_IntegerType<UnsignedChar, 1, 1>(link);
727     else if (nUnsignedChar % 1 == 0)
728 #endif
729       PackInit_IntegerType<UnsignedChar, 1, 0>(link);
730 #if defined(PETSC_HAVE_COMPLEX)
731   } else if (nPetscComplex) {
732   #if !defined(PETSC_HAVE_DEVICE)
733     if (nPetscComplex == 8) PackInit_ComplexType<PetscComplex, 8, 1>(link);
734     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<PetscComplex, 8, 0>(link);
735     else if (nPetscComplex == 4) PackInit_ComplexType<PetscComplex, 4, 1>(link);
736     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<PetscComplex, 4, 0>(link);
737     else if (nPetscComplex == 2) PackInit_ComplexType<PetscComplex, 2, 1>(link);
738     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<PetscComplex, 2, 0>(link);
739     else if (nPetscComplex == 1) PackInit_ComplexType<PetscComplex, 1, 1>(link);
740     else if (nPetscComplex % 1 == 0)
741   #endif
742       PackInit_ComplexType<PetscComplex, 1, 0>(link);
743 #endif
744   } else {
745     MPI_Aint lb, nbyte;
746     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
747     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
748     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
749 #if !defined(PETSC_HAVE_DEVICE)
750       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
751       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
752       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
753       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
754       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
755       else if (nbyte % 1 == 0)
756 #endif
757         PackInit_DumbType<char, 1, 0>(link);
758     } else {
759       nInt = nbyte / sizeof(int);
760 #if !defined(PETSC_HAVE_DEVICE)
761       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
762       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
763       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
764       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
765       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
766       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
767       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
768       else if (nInt % 1 == 0)
769 #endif
770         PackInit_DumbType<int, 1, 0>(link);
771     }
772   }
773 
774   if (!sf->maxResidentThreadsPerGPU) { /* Not initialized */
775     int              device;
776     cupmDeviceProp_t props;
777     PetscCallCUPM(cupmGetDevice(&device));
778     PetscCallCUPM(cupmGetDeviceProperties(&props, device));
779     sf->maxResidentThreadsPerGPU = props.maxThreadsPerMultiProcessor * props.multiProcessorCount;
780   }
781   link->maxResidentThreadsPerGPU = sf->maxResidentThreadsPerGPU;
782 
783   {
784     cupmStream_t      *stream;
785     PetscDeviceContext dctx;
786 
787     PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUPM()));
788     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
789     link->stream = *stream;
790   }
791   link->Destroy      = LinkDestroy_MPI;
792   link->SyncDevice   = LinkSyncDevice;
793   link->SyncStream   = LinkSyncStream;
794   link->Memcpy       = LinkMemcpy;
795   link->deviceinited = PETSC_TRUE;
796   PetscFunctionReturn(PETSC_SUCCESS);
797 }
798 
799 } // namespace impl
800 
801 } // namespace cupm
802 
803 } // namespace sf
804 
805 } // namespace Petsc
806