xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <Kokkos_Core.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
10 typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;
11 
12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template <typename Type>
19 struct Insert {
20   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
21     Type old = x;
22     x        = y;
23     return old;
24   }
25 };
26 template <typename Type>
27 struct Add {
28   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
29     Type old = x;
30     x += y;
31     return old;
32   }
33 };
34 template <typename Type>
35 struct Mult {
36   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
37     Type old = x;
38     x *= y;
39     return old;
40   }
41 };
42 template <typename Type>
43 struct Min {
44   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
45     Type old = x;
46     x        = PetscMin(x, y);
47     return old;
48   }
49 };
50 template <typename Type>
51 struct Max {
52   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
53     Type old = x;
54     x        = PetscMax(x, y);
55     return old;
56   }
57 };
58 template <typename Type>
59 struct LAND {
60   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
61     Type old = x;
62     x        = x && y;
63     return old;
64   }
65 };
66 template <typename Type>
67 struct LOR {
68   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
69     Type old = x;
70     x        = x || y;
71     return old;
72   }
73 };
74 template <typename Type>
75 struct LXOR {
76   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
77     Type old = x;
78     x        = !x != !y;
79     return old;
80   }
81 };
82 template <typename Type>
83 struct BAND {
84   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
85     Type old = x;
86     x        = x & y;
87     return old;
88   }
89 };
90 template <typename Type>
91 struct BOR {
92   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
93     Type old = x;
94     x        = x | y;
95     return old;
96   }
97 };
98 template <typename Type>
99 struct BXOR {
100   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
101     Type old = x;
102     x        = x ^ y;
103     return old;
104   }
105 };
106 template <typename PairType>
107 struct Minloc {
108   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const {
109     PairType old = x;
110     if (y.first < x.first) x = y;
111     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
112     return old;
113   }
114 };
115 template <typename PairType>
116 struct Maxloc {
117   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const {
118     PairType old = x;
119     if (y.first > x.first) x = y;
120     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
121     return old;
122   }
123 };
124 
125 /*====================================================================================*/
126 /*                             Atomic operations                            */
127 /*====================================================================================*/
128 template <typename Type>
129 struct AtomicInsert {
130   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
131 };
132 template <typename Type>
133 struct AtomicAdd {
134   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
135 };
136 template <typename Type>
137 struct AtomicBAND {
138   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
139 };
140 template <typename Type>
141 struct AtomicBOR {
142   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
143 };
144 template <typename Type>
145 struct AtomicBXOR {
146   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
147 };
148 template <typename Type>
149 struct AtomicLAND {
150   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const {
151     const Type zero = 0, one = ~0;
152     Kokkos::atomic_and(&x, y ? one : zero);
153   }
154 };
155 template <typename Type>
156 struct AtomicLOR {
157   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const {
158     const Type zero = 0, one = 1;
159     Kokkos::atomic_or(&x, y ? one : zero);
160   }
161 };
162 template <typename Type>
163 struct AtomicMult {
164   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
165 };
166 template <typename Type>
167 struct AtomicMin {
168   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
169 };
170 template <typename Type>
171 struct AtomicMax {
172   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
173 };
174 /* TODO: struct AtomicLXOR  */
175 template <typename Type>
176 struct AtomicFetchAdd {
177   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
178 };
179 
180 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
181 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) {
182   PetscInt        i, j, k, m, n, r;
183   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
184 
185   n      = opt[0];
186   offset = opt + 1;
187   start  = opt + n + 2;
188   dx     = opt + 2 * n + 2;
189   dy     = opt + 3 * n + 2;
190   X      = opt + 5 * n + 2;
191   Y      = opt + 6 * n + 2;
192   for (r = 0; r < n; r++) {
193     if (tid < offset[r + 1]) break;
194   }
195   m = (tid - offset[r]);
196   k = m / (dx[r] * dy[r]);
197   j = (m - k * dx[r] * dy[r]) / dx[r];
198   i = m - k * dx[r] * dy[r] - j * dx[r];
199 
200   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
201 }
202 
203 /*====================================================================================*/
204 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
205 /*====================================================================================*/
206 
207 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
208    <Type> is PetscReal, which is the primitive type we operate on.
209    <bs>   is 16, which says <unit> contains 16 primitive types.
210    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
211    <EQ>   is 0, which is (bs == BS ? 1 : 0)
212 
213   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
214   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
215 */
216 template <typename Type, PetscInt BS, PetscInt EQ>
217 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_) {
218   const PetscInt      *iopt = opt ? opt->array : NULL;
219   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
220   const Type          *data = static_cast<const Type *>(data_);
221   Type                *buf  = static_cast<Type *>(buf_);
222   DeviceExecutionSpace exec;
223 
224   PetscFunctionBegin;
225   Kokkos::parallel_for(
226     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
227       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
228        iopt == NULL && idx == NULL ==> the indices are contiguous;
229      */
230       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
231       PetscInt s = tid * MBS;
232       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
233     });
234   PetscFunctionReturn(0);
235 }
236 
237 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
238 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_) {
239   Op                   op;
240   const PetscInt      *iopt = opt ? opt->array : NULL;
241   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
242   Type                *data = static_cast<Type *>(data_);
243   const Type          *buf  = static_cast<const Type *>(buf_);
244   DeviceExecutionSpace exec;
245 
246   PetscFunctionBegin;
247   Kokkos::parallel_for(
248     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
249       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
250       PetscInt s = tid * MBS;
251       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
252     });
253   PetscFunctionReturn(0);
254 }
255 
256 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
257 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) {
258   Op                   op;
259   const PetscInt      *ropt = opt ? opt->array : NULL;
260   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
261   Type                *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
262   DeviceExecutionSpace exec;
263 
264   PetscFunctionBegin;
265   Kokkos::parallel_for(
266     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
267       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
268       PetscInt l = tid * MBS;
269       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
270     });
271   PetscFunctionReturn(0);
272 }
273 
274 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
275 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) {
276   PetscInt             srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
277   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
278   const Type          *src = static_cast<const Type *>(src_);
279   Type                *dst = static_cast<Type *>(dst_);
280   DeviceExecutionSpace exec;
281 
282   PetscFunctionBegin;
283   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
284   if (srcOpt) {
285     srcx     = srcOpt->dx[0];
286     srcy     = srcOpt->dy[0];
287     srcX     = srcOpt->X[0];
288     srcY     = srcOpt->Y[0];
289     srcStart = srcOpt->start[0];
290     srcIdx   = NULL;
291   } else if (!srcIdx) {
292     srcx = srcX = count;
293     srcy = srcY = 1;
294   }
295 
296   if (dstOpt) {
297     dstx     = dstOpt->dx[0];
298     dsty     = dstOpt->dy[0];
299     dstX     = dstOpt->X[0];
300     dstY     = dstOpt->Y[0];
301     dstStart = dstOpt->start[0];
302     dstIdx   = NULL;
303   } else if (!dstIdx) {
304     dstx = dstX = count;
305     dsty = dstY = 1;
306   }
307 
308   Kokkos::parallel_for(
309     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
310       PetscInt i, j, k, s, t;
311       Op       op;
312       if (!srcIdx) { /* src is in 3D */
313         k = tid / (srcx * srcy);
314         j = (tid - k * srcx * srcy) / srcx;
315         i = tid - k * srcx * srcy - j * srcx;
316         s = srcStart + k * srcX * srcY + j * srcX + i;
317       } else { /* src is contiguous */
318         s = srcIdx[tid];
319       }
320 
321       if (!dstIdx) { /* 3D */
322         k = tid / (dstx * dsty);
323         j = (tid - k * dstx * dsty) / dstx;
324         i = tid - k * dstx * dsty - j * dstx;
325         t = dstStart + k * dstX * dstY + j * dstX + i;
326       } else { /* contiguous */
327         t = dstIdx[tid];
328       }
329 
330       s *= MBS;
331       t *= MBS;
332       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
333     });
334   PetscFunctionReturn(0);
335 }
336 
337 /* Specialization for Insert since we may use memcpy */
338 template <typename Type, PetscInt BS, PetscInt EQ>
339 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) {
340   const Type          *src = static_cast<const Type *>(src_);
341   Type                *dst = static_cast<Type *>(dst_);
342   DeviceExecutionSpace exec;
343 
344   PetscFunctionBegin;
345   if (!count) PetscFunctionReturn(0);
346   /*src and dst are contiguous */
347   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
348     size_t              sz = count * link->unitbytes;
349     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
350     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
351     Kokkos::deep_copy(exec, dbuf, sbuf);
352   } else {
353     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
354   }
355   PetscFunctionReturn(0);
356 }
357 
358 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
359 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_) {
360   Op                   op;
361   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
362   const PetscInt      *ropt     = rootopt ? rootopt->array : NULL;
363   const PetscInt      *lopt     = leafopt ? leafopt->array : NULL;
364   Type                *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
365   const Type          *leafdata = static_cast<const Type *>(leafdata_);
366   DeviceExecutionSpace exec;
367 
368   PetscFunctionBegin;
369   Kokkos::parallel_for(
370     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
371       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
372       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
373       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
374     });
375   PetscFunctionReturn(0);
376 }
377 
378 /*====================================================================================*/
379 /*  Init various types and instantiate pack/unpack function pointers                  */
380 /*====================================================================================*/
381 template <typename Type, PetscInt BS, PetscInt EQ>
382 static void PackInit_RealType(PetscSFLink link) {
383   /* Pack/unpack for remote communication */
384   link->d_Pack             = Pack<Type, BS, EQ>;
385   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
386   link->d_UnpackAndAdd     = UnpackAndOp<Type, Add<Type>, BS, EQ>;
387   link->d_UnpackAndMult    = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
388   link->d_UnpackAndMin     = UnpackAndOp<Type, Min<Type>, BS, EQ>;
389   link->d_UnpackAndMax     = UnpackAndOp<Type, Max<Type>, BS, EQ>;
390   link->d_FetchAndAdd      = FetchAndOp<Type, Add<Type>, BS, EQ>;
391   /* Scatter for local communication */
392   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
393   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
394   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
395   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
396   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
397   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
398   /* Atomic versions when there are data-race possibilities */
399   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
400   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
401   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
402   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
403   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
404   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
405 
406   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
407   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
408   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
409   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
410   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
411   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
412 }
413 
414 template <typename Type, PetscInt BS, PetscInt EQ>
415 static void PackInit_IntegerType(PetscSFLink link) {
416   link->d_Pack            = Pack<Type, BS, EQ>;
417   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
418   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
419   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
420   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
421   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
422   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
423   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
424   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
425   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
426   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
427   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
428   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
429 
430   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
431   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
432   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
433   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
434   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
435   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
436   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
437   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
438   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
439   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
440   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
441   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
442 
443   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
444   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
445   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
446   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
447   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
448   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
449   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
450   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
451   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
452   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
453   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
454 
455   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
456   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
457   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
458   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
459   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
460   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
461   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
462   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
463   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
464   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
465   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
466 }
467 
468 #if defined(PETSC_HAVE_COMPLEX)
469 template <typename Type, PetscInt BS, PetscInt EQ>
470 static void PackInit_ComplexType(PetscSFLink link) {
471   link->d_Pack            = Pack<Type, BS, EQ>;
472   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
473   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
474   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
475   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
476 
477   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
478   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
479   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
480   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
481 
482   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
483   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
484   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
485   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
486 
487   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
488   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
489   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
490   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
491 }
492 #endif
493 
494 template <typename Type>
495 static void PackInit_PairType(PetscSFLink link) {
496   link->d_Pack            = Pack<Type, 1, 1>;
497   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
498   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
499   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
500 
501   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
502   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
503   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
504   /* Atomics for pair types are not implemented yet */
505 }
506 
507 template <typename Type, PetscInt BS, PetscInt EQ>
508 static void PackInit_DumbType(PetscSFLink link) {
509   link->d_Pack             = Pack<Type, BS, EQ>;
510   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
511   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
512   /* Atomics for dumb types are not implemented yet */
513 }
514 
515 /*
516   Kokkos::DefaultExecutionSpace(stream) is a reference counted pointer object. It has a bug
517   that one is not able to repeatedly create and destroy the object. SF's original design was each
518   SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
519   destroying multiple SFLinks with NULL stream and the default execution space object. To avoid
520   memory leaks, SF_Kokkos only supports NULL stream, which is also petsc's default scheme. SF_Kokkos
521   does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singliton
522   object in Kokkos.
523 */
524 /*
525 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
526 {
527   PetscFunctionBegin;
528   PetscFunctionReturn(0);
529 }
530 */
531 
532 /* Some device-specific utilities */
533 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link) {
534   PetscFunctionBegin;
535   Kokkos::fence();
536   PetscFunctionReturn(0);
537 }
538 
539 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link) {
540   DeviceExecutionSpace exec;
541   PetscFunctionBegin;
542   exec.fence();
543   PetscFunctionReturn(0);
544 }
545 
546 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) {
547   DeviceExecutionSpace exec;
548 
549   PetscFunctionBegin;
550   if (!n) PetscFunctionReturn(0);
551   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
552     PetscCall(PetscMemcpy(dst, src, n));
553   } else {
554     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
555       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
556       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
557       Kokkos::deep_copy(exec, dbuf, sbuf);
558       PetscCall(PetscLogCpuToGpu(n));
559     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
560       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
561       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
562       Kokkos::deep_copy(exec, dbuf, sbuf);
563       PetscCall(PetscLogGpuToCpu(n));
564     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
565       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
566       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
567       Kokkos::deep_copy(exec, dbuf, sbuf);
568     }
569   }
570   PetscFunctionReturn(0);
571 }
572 
573 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr) {
574   PetscFunctionBegin;
575   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
576   else if (PetscMemTypeDevice(mtype)) {
577     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
578     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
579   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
580   PetscFunctionReturn(0);
581 }
582 
583 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr) {
584   PetscFunctionBegin;
585   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
586   else if (PetscMemTypeDevice(mtype)) {
587     Kokkos::kokkos_free<DeviceMemorySpace>(ptr);
588   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
589   PetscFunctionReturn(0);
590 }
591 
592 /* Destructor when the link uses MPI for communication */
593 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link) {
594   PetscFunctionBegin;
595   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
596     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
597     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
598   }
599   PetscFunctionReturn(0);
600 }
601 
/* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on the device */
603 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit) {
604   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
605   PetscBool is2Int, is2PetscInt;
606 #if defined(PETSC_HAVE_COMPLEX)
607   PetscInt nPetscComplex = 0;
608 #endif
609 
610   PetscFunctionBegin;
611   if (link->deviceinited) PetscFunctionReturn(0);
612   PetscCall(PetscKokkosInitializeCheck());
613   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
614   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
615   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
616   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
617   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
618   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
619 #if defined(PETSC_HAVE_COMPLEX)
620   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
621 #endif
622   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
623   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
624 
625   if (is2Int) {
626     PackInit_PairType<Kokkos::pair<int, int>>(link);
627   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
628     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
629   } else if (nPetscReal) {
630 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
631     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
632     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
633     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
634     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
635     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
636     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
637     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
638     else if (nPetscReal % 1 == 0)
639 #endif
640       PackInit_RealType<PetscReal, 1, 0>(link);
641   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
642 #if !defined(PETSC_HAVE_DEVICE)
643     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
644     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
645     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
646     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
647     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
648     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
649     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
650     else if (nPetscInt % 1 == 0)
651 #endif
652       PackInit_IntegerType<llint, 1, 0>(link);
653   } else if (nInt) {
654 #if !defined(PETSC_HAVE_DEVICE)
655     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
656     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
657     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
658     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
659     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
660     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
661     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
662     else if (nInt % 1 == 0)
663 #endif
664       PackInit_IntegerType<int, 1, 0>(link);
665   } else if (nSignedChar) {
666 #if !defined(PETSC_HAVE_DEVICE)
667     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
668     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
669     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
670     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
671     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
672     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
673     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
674     else if (nSignedChar % 1 == 0)
675 #endif
676       PackInit_IntegerType<char, 1, 0>(link);
677   } else if (nUnsignedChar) {
678 #if !defined(PETSC_HAVE_DEVICE)
679     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
680     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
681     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
682     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
683     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
684     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
685     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
686     else if (nUnsignedChar % 1 == 0)
687 #endif
688       PackInit_IntegerType<unsigned char, 1, 0>(link);
689 #if defined(PETSC_HAVE_COMPLEX)
690   } else if (nPetscComplex) {
691 #if !defined(PETSC_HAVE_DEVICE)
692     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
693     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
694     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
695     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
696     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
697     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
698     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
699     else if (nPetscComplex % 1 == 0)
700 #endif
701       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
702 #endif
703   } else {
704     MPI_Aint lb, nbyte;
705     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
706     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
707     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
708 #if !defined(PETSC_HAVE_DEVICE)
709       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
710       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
711       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
712       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
713       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
714       else if (nbyte % 1 == 0)
715 #endif
716         PackInit_DumbType<char, 1, 0>(link);
717     } else {
718       nInt = nbyte / sizeof(int);
719 #if !defined(PETSC_HAVE_DEVICE)
720       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
721       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
722       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
723       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
724       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
725       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
726       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
727       else if (nInt % 1 == 0)
728 #endif
729         PackInit_DumbType<int, 1, 0>(link);
730     }
731   }
732 
733   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
734   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
735   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
736   link->Destroy      = PetscSFLinkDestroy_Kokkos;
737   link->deviceinited = PETSC_TRUE;
738   PetscFunctionReturn(0);
739 }
740