xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision 517f4e3460cd8c7e73c19b8bf533f1efe47f30a7)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <petsc_kokkos.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
10 typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;
11 
12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template <typename Type>
19 struct Insert {
20   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
21   {
22     Type old = x;
23     x        = y;
24     return old;
25   }
26 };
27 template <typename Type>
28 struct Add {
29   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
30   {
31     Type old = x;
32     x += y;
33     return old;
34   }
35 };
36 template <typename Type>
37 struct Mult {
38   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
39   {
40     Type old = x;
41     x *= y;
42     return old;
43   }
44 };
45 template <typename Type>
46 struct Min {
47   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
48   {
49     Type old = x;
50     x        = PetscMin(x, y);
51     return old;
52   }
53 };
54 template <typename Type>
55 struct Max {
56   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
57   {
58     Type old = x;
59     x        = PetscMax(x, y);
60     return old;
61   }
62 };
63 template <typename Type>
64 struct LAND {
65   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
66   {
67     Type old = x;
68     x        = x && y;
69     return old;
70   }
71 };
72 template <typename Type>
73 struct LOR {
74   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
75   {
76     Type old = x;
77     x        = x || y;
78     return old;
79   }
80 };
81 template <typename Type>
82 struct LXOR {
83   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
84   {
85     Type old = x;
86     x        = !x != !y;
87     return old;
88   }
89 };
90 template <typename Type>
91 struct BAND {
92   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
93   {
94     Type old = x;
95     x        = x & y;
96     return old;
97   }
98 };
99 template <typename Type>
100 struct BOR {
101   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
102   {
103     Type old = x;
104     x        = x | y;
105     return old;
106   }
107 };
108 template <typename Type>
109 struct BXOR {
110   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
111   {
112     Type old = x;
113     x        = x ^ y;
114     return old;
115   }
116 };
117 template <typename PairType>
118 struct Minloc {
119   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
120   {
121     PairType old = x;
122     if (y.first < x.first) x = y;
123     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
124     return old;
125   }
126 };
127 template <typename PairType>
128 struct Maxloc {
129   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
130   {
131     PairType old = x;
132     if (y.first > x.first) x = y;
133     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
134     return old;
135   }
136 };
137 
138 /*====================================================================================*/
139 /*                             Atomic operations                            */
140 /*====================================================================================*/
141 template <typename Type>
142 struct AtomicInsert {
143   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
144 };
145 template <typename Type>
146 struct AtomicAdd {
147   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
148 };
149 template <typename Type>
150 struct AtomicBAND {
151   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
152 };
153 template <typename Type>
154 struct AtomicBOR {
155   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
156 };
157 template <typename Type>
158 struct AtomicBXOR {
159   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
160 };
161 template <typename Type>
162 struct AtomicLAND {
163   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
164   {
165     const Type zero = 0, one = ~0;
166     Kokkos::atomic_and(&x, y ? one : zero);
167   }
168 };
169 template <typename Type>
170 struct AtomicLOR {
171   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
172   {
173     const Type zero = 0, one = 1;
174     Kokkos::atomic_or(&x, y ? one : zero);
175   }
176 };
177 template <typename Type>
178 struct AtomicMult {
179   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
180 };
181 template <typename Type>
182 struct AtomicMin {
183   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
184 };
185 template <typename Type>
186 struct AtomicMax {
187   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
188 };
189 /* TODO: struct AtomicLXOR  */
190 template <typename Type>
191 struct AtomicFetchAdd {
192   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
193 };
194 
195 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
196 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
197 {
198   PetscInt        i, j, k, m, n, r;
199   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
200 
201   n      = opt[0];
202   offset = opt + 1;
203   start  = opt + n + 2;
204   dx     = opt + 2 * n + 2;
205   dy     = opt + 3 * n + 2;
206   X      = opt + 5 * n + 2;
207   Y      = opt + 6 * n + 2;
208   for (r = 0; r < n; r++) {
209     if (tid < offset[r + 1]) break;
210   }
211   m = (tid - offset[r]);
212   k = m / (dx[r] * dy[r]);
213   j = (m - k * dx[r] * dy[r]) / dx[r];
214   i = m - k * dx[r] * dy[r] - j * dx[r];
215 
216   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
217 }
218 
219 /*====================================================================================*/
220 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
221 /*====================================================================================*/
222 
223 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
224    <Type> is PetscReal, which is the primitive type we operate on.
225    <bs>   is 16, which says <unit> contains 16 primitive types.
226    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
227    <EQ>   is 0, which is (bs == BS ? 1 : 0)
228 
  If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, making MBS below a compile-time constant.
230   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
231 */
232 template <typename Type, PetscInt BS, PetscInt EQ>
233 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
234 {
235   const PetscInt       *iopt = opt ? opt->array : NULL;
236   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
237   const Type           *data = static_cast<const Type *>(data_);
238   Type                 *buf  = static_cast<Type *>(buf_);
239   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
240 
241   PetscFunctionBegin;
242   Kokkos::parallel_for(
243     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
244       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
245        iopt == NULL && idx == NULL ==> the indices are contiguous;
246      */
247       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
248       PetscInt s = tid * MBS;
249       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
250     });
251   PetscFunctionReturn(PETSC_SUCCESS);
252 }
253 
254 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
255 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
256 {
257   Op                    op;
258   const PetscInt       *iopt = opt ? opt->array : NULL;
259   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
260   Type                 *data = static_cast<Type *>(data_);
261   const Type           *buf  = static_cast<const Type *>(buf_);
262   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
263 
264   PetscFunctionBegin;
265   Kokkos::parallel_for(
266     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
267       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
268       PetscInt s = tid * MBS;
269       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
270     });
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
274 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
275 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
276 {
277   Op                    op;
278   const PetscInt       *ropt = opt ? opt->array : NULL;
279   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
280   Type                 *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
281   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
282 
283   PetscFunctionBegin;
284   Kokkos::parallel_for(
285     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
286       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
287       PetscInt l = tid * MBS;
288       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
289     });
290   PetscFunctionReturn(PETSC_SUCCESS);
291 }
292 
293 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
294 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
295 {
296   PetscInt              srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
297   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
298   const Type           *src  = static_cast<const Type *>(src_);
299   Type                 *dst  = static_cast<Type *>(dst_);
300   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
301 
302   PetscFunctionBegin;
303   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
304   if (srcOpt) {
305     srcx     = srcOpt->dx[0];
306     srcy     = srcOpt->dy[0];
307     srcX     = srcOpt->X[0];
308     srcY     = srcOpt->Y[0];
309     srcStart = srcOpt->start[0];
310     srcIdx   = NULL;
311   } else if (!srcIdx) {
312     srcx = srcX = count;
313     srcy = srcY = 1;
314   }
315 
316   if (dstOpt) {
317     dstx     = dstOpt->dx[0];
318     dsty     = dstOpt->dy[0];
319     dstX     = dstOpt->X[0];
320     dstY     = dstOpt->Y[0];
321     dstStart = dstOpt->start[0];
322     dstIdx   = NULL;
323   } else if (!dstIdx) {
324     dstx = dstX = count;
325     dsty = dstY = 1;
326   }
327 
328   Kokkos::parallel_for(
329     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
330       PetscInt i, j, k, s, t;
331       Op       op;
332       if (!srcIdx) { /* src is in 3D */
333         k = tid / (srcx * srcy);
334         j = (tid - k * srcx * srcy) / srcx;
335         i = tid - k * srcx * srcy - j * srcx;
336         s = srcStart + k * srcX * srcY + j * srcX + i;
337       } else { /* src is contiguous */
338         s = srcIdx[tid];
339       }
340 
341       if (!dstIdx) { /* 3D */
342         k = tid / (dstx * dsty);
343         j = (tid - k * dstx * dsty) / dstx;
344         i = tid - k * dstx * dsty - j * dstx;
345         t = dstStart + k * dstX * dstY + j * dstX + i;
346       } else { /* contiguous */
347         t = dstIdx[tid];
348       }
349 
350       s *= MBS;
351       t *= MBS;
352       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
353     });
354   PetscFunctionReturn(PETSC_SUCCESS);
355 }
356 
357 /* Specialization for Insert since we may use memcpy */
358 template <typename Type, PetscInt BS, PetscInt EQ>
359 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
360 {
361   const Type           *src  = static_cast<const Type *>(src_);
362   Type                 *dst  = static_cast<Type *>(dst_);
363   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
364 
365   PetscFunctionBegin;
366   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
367   /*src and dst are contiguous */
368   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
369     size_t              sz = count * link->unitbytes;
370     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
371     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
372     Kokkos::deep_copy(exec, dbuf, sbuf);
373   } else {
374     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
375   }
376   PetscFunctionReturn(PETSC_SUCCESS);
377 }
378 
379 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
380 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
381 {
382   Op                    op;
383   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
384   const PetscInt       *ropt     = rootopt ? rootopt->array : NULL;
385   const PetscInt       *lopt     = leafopt ? leafopt->array : NULL;
386   Type                 *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
387   const Type           *leafdata = static_cast<const Type *>(leafdata_);
388   DeviceExecutionSpace &exec     = PetscGetKokkosExecutionSpace();
389 
390   PetscFunctionBegin;
391   Kokkos::parallel_for(
392     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
393       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
394       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
395       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
396     });
397   PetscFunctionReturn(PETSC_SUCCESS);
398 }
399 
400 /*====================================================================================*/
401 /*  Init various types and instantiate pack/unpack function pointers                  */
402 /*====================================================================================*/
403 template <typename Type, PetscInt BS, PetscInt EQ>
404 static void PackInit_RealType(PetscSFLink link)
405 {
406   /* Pack/unpack for remote communication */
407   link->d_Pack            = Pack<Type, BS, EQ>;
408   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
409   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
410   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
411   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
412   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
413   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
414   /* Scatter for local communication */
415   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
416   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
417   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
418   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
419   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
420   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
421   /* Atomic versions when there are data-race possibilities */
422   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
423   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
424   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
425   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
426   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
427   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
428 
429   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
430   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
431   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
432   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
433   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
434   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
435 }
436 
437 template <typename Type, PetscInt BS, PetscInt EQ>
438 static void PackInit_IntegerType(PetscSFLink link)
439 {
440   link->d_Pack            = Pack<Type, BS, EQ>;
441   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
442   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
443   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
444   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
445   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
446   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
447   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
448   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
449   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
450   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
451   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
452   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
453 
454   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
455   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
456   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
457   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
458   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
459   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
460   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
461   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
462   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
463   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
464   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
465   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
466 
467   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
468   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
469   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
470   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
471   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
472   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
473   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
474   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
475   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
476   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
477   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
478 
479   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
480   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
481   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
482   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
483   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
484   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
485   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
486   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
487   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
488   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
489   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
490 }
491 
492 #if defined(PETSC_HAVE_COMPLEX)
493 template <typename Type, PetscInt BS, PetscInt EQ>
494 static void PackInit_ComplexType(PetscSFLink link)
495 {
496   link->d_Pack            = Pack<Type, BS, EQ>;
497   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
498   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
499   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
500   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
501 
502   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
503   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
504   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
505   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
506 
507   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
508   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
509   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
510   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
511 
512   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
513   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
514   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
515   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
516 }
517 #endif
518 
519 template <typename Type>
520 static void PackInit_PairType(PetscSFLink link)
521 {
522   link->d_Pack            = Pack<Type, 1, 1>;
523   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
524   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
525   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
526 
527   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
528   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
529   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
530   /* Atomics for pair types are not implemented yet */
531 }
532 
533 template <typename Type, PetscInt BS, PetscInt EQ>
534 static void PackInit_DumbType(PetscSFLink link)
535 {
536   link->d_Pack             = Pack<Type, BS, EQ>;
537   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
538   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
539   /* Atomics for dumb types are not implemented yet */
540 }
541 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug: one cannot
  repeatedly create and destroy the object. SF's original design gave each SFLink a stream (NULL or not)
  and hence an execution space object. The bug prevents us from destroying multiple SFLinks that have a
  NULL stream and the default execution space object. To avoid memory leaks, SF_Kokkos only supports the
  NULL stream, which is also PETSc's default scheme. SF_Kokkos does not do its own new/delete; it just
  uses Kokkos::DefaultExecutionSpace(), which is a singleton object in Kokkos.
  NOTE(review): the kernels in this file now obtain their execution space via PetscGetKokkosExecutionSpace();
  confirm this note is still accurate.
*/
551 /*
552 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
553 {
554   PetscFunctionBegin;
555   PetscFunctionReturn(PETSC_SUCCESS);
556 }
557 */
558 
559 /* Some device-specific utilities */
560 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
561 {
562   PetscFunctionBegin;
563   Kokkos::fence();
564   PetscFunctionReturn(PETSC_SUCCESS);
565 }
566 
567 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
568 {
569   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
570   PetscFunctionBegin;
571   exec.fence();
572   PetscFunctionReturn(PETSC_SUCCESS);
573 }
574 
575 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
576 {
577   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
578 
579   PetscFunctionBegin;
580   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
581   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
582     PetscCall(PetscMemcpy(dst, src, n));
583   } else {
584     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
585       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
586       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
587       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
588       PetscCall(PetscLogCpuToGpu(n));
589     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
590       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
591       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
592       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
593       PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
594       PetscCall(PetscLogGpuToCpu(n));
595     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
596       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
597       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
598       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
599     }
600   }
601   PetscFunctionReturn(PETSC_SUCCESS);
602 }
603 
604 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
605 {
606   PetscFunctionBegin;
607   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
608   else if (PetscMemTypeDevice(mtype)) {
609     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
610     PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size));
611   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
612   PetscFunctionReturn(PETSC_SUCCESS);
613 }
614 
615 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
616 {
617   PetscFunctionBegin;
618   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
619   else if (PetscMemTypeDevice(mtype)) {
620     PetscCallCXX(Kokkos::kokkos_free<DeviceMemorySpace>(ptr));
621   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
622   PetscFunctionReturn(PETSC_SUCCESS);
623 }
624 
625 /* Destructor when the link uses MPI for communication */
626 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
627 {
628   PetscFunctionBegin;
629   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
630     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
631     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632   }
633   PetscFunctionReturn(PETSC_SUCCESS);
634 }
635 
/* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on the device */
637 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
638 {
639   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
640   PetscBool is2Int, is2PetscInt;
641 #if defined(PETSC_HAVE_COMPLEX)
642   PetscInt nPetscComplex = 0;
643 #endif
644 
645   PetscFunctionBegin;
646   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
647   PetscCall(PetscKokkosInitializeCheck());
648   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
649   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
650   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
651   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
652   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
653   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
654 #if defined(PETSC_HAVE_COMPLEX)
655   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
656 #endif
657   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
658   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
659 
660   if (is2Int) {
661     PackInit_PairType<Kokkos::pair<int, int>>(link);
662   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
663     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
664   } else if (nPetscReal) {
665 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
666     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
667     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
668     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
669     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
670     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
671     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
672     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
673     else if (nPetscReal % 1 == 0)
674 #endif
675       PackInit_RealType<PetscReal, 1, 0>(link);
676   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
677 #if !defined(PETSC_HAVE_DEVICE)
678     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
679     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
680     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
681     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
682     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
683     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
684     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
685     else if (nPetscInt % 1 == 0)
686 #endif
687       PackInit_IntegerType<llint, 1, 0>(link);
688   } else if (nInt) {
689 #if !defined(PETSC_HAVE_DEVICE)
690     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
691     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
692     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
693     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
694     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
695     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
696     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
697     else if (nInt % 1 == 0)
698 #endif
699       PackInit_IntegerType<int, 1, 0>(link);
700   } else if (nSignedChar) {
701 #if !defined(PETSC_HAVE_DEVICE)
702     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
703     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
704     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
705     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
706     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
707     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
708     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
709     else if (nSignedChar % 1 == 0)
710 #endif
711       PackInit_IntegerType<char, 1, 0>(link);
712   } else if (nUnsignedChar) {
713 #if !defined(PETSC_HAVE_DEVICE)
714     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
715     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
716     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
717     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
718     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
719     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
720     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
721     else if (nUnsignedChar % 1 == 0)
722 #endif
723       PackInit_IntegerType<unsigned char, 1, 0>(link);
724 #if defined(PETSC_HAVE_COMPLEX)
725   } else if (nPetscComplex) {
726   #if !defined(PETSC_HAVE_DEVICE)
727     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
728     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
729     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
730     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
731     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
732     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
733     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
734     else if (nPetscComplex % 1 == 0)
735   #endif
736       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
737 #endif
738   } else {
739     MPI_Aint lb, nbyte;
740     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
741     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
742     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
743 #if !defined(PETSC_HAVE_DEVICE)
744       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
745       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
746       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
747       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
748       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
749       else if (nbyte % 1 == 0)
750 #endif
751         PackInit_DumbType<char, 1, 0>(link);
752     } else {
753       nInt = nbyte / sizeof(int);
754 #if !defined(PETSC_HAVE_DEVICE)
755       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
756       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
757       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
758       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
759       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
760       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
761       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
762       else if (nInt % 1 == 0)
763 #endif
764         PackInit_DumbType<int, 1, 0>(link);
765     }
766   }
767 
768   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
769   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
770   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
771   link->Destroy      = PetscSFLinkDestroy_Kokkos;
772   link->deviceinited = PETSC_TRUE;
773   PetscFunctionReturn(PETSC_SUCCESS);
774 }
775