xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision bcee047adeeb73090d7e36cc71e39fc287cdbb97)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <petsc_kokkos.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
10 typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;
11 
12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template <typename Type>
19 struct Insert {
20   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
21   {
22     Type old = x;
23     x        = y;
24     return old;
25   }
26 };
27 template <typename Type>
28 struct Add {
29   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
30   {
31     Type old = x;
32     x += y;
33     return old;
34   }
35 };
36 template <typename Type>
37 struct Mult {
38   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
39   {
40     Type old = x;
41     x *= y;
42     return old;
43   }
44 };
45 template <typename Type>
46 struct Min {
47   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
48   {
49     Type old = x;
50     x        = PetscMin(x, y);
51     return old;
52   }
53 };
54 template <typename Type>
55 struct Max {
56   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
57   {
58     Type old = x;
59     x        = PetscMax(x, y);
60     return old;
61   }
62 };
63 template <typename Type>
64 struct LAND {
65   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
66   {
67     Type old = x;
68     x        = x && y;
69     return old;
70   }
71 };
72 template <typename Type>
73 struct LOR {
74   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
75   {
76     Type old = x;
77     x        = x || y;
78     return old;
79   }
80 };
81 template <typename Type>
82 struct LXOR {
83   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
84   {
85     Type old = x;
86     x        = !x != !y;
87     return old;
88   }
89 };
90 template <typename Type>
91 struct BAND {
92   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
93   {
94     Type old = x;
95     x        = x & y;
96     return old;
97   }
98 };
99 template <typename Type>
100 struct BOR {
101   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
102   {
103     Type old = x;
104     x        = x | y;
105     return old;
106   }
107 };
108 template <typename Type>
109 struct BXOR {
110   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
111   {
112     Type old = x;
113     x        = x ^ y;
114     return old;
115   }
116 };
117 template <typename PairType>
118 struct Minloc {
119   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
120   {
121     PairType old = x;
122     if (y.first < x.first) x = y;
123     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
124     return old;
125   }
126 };
127 template <typename PairType>
128 struct Maxloc {
129   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
130   {
131     PairType old = x;
132     if (y.first > x.first) x = y;
133     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
134     return old;
135   }
136 };
137 
138 /*====================================================================================*/
139 /*                             Atomic operations                            */
140 /*====================================================================================*/
141 template <typename Type>
142 struct AtomicInsert {
143   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
144 };
145 template <typename Type>
146 struct AtomicAdd {
147   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
148 };
149 template <typename Type>
150 struct AtomicBAND {
151   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
152 };
153 template <typename Type>
154 struct AtomicBOR {
155   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
156 };
157 template <typename Type>
158 struct AtomicBXOR {
159   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
160 };
161 template <typename Type>
162 struct AtomicLAND {
163   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
164   {
165     const Type zero = 0, one = ~0;
166     Kokkos::atomic_and(&x, y ? one : zero);
167   }
168 };
169 template <typename Type>
170 struct AtomicLOR {
171   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
172   {
173     const Type zero = 0, one = 1;
174     Kokkos::atomic_or(&x, y ? one : zero);
175   }
176 };
177 template <typename Type>
178 struct AtomicMult {
179   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
180 };
181 template <typename Type>
182 struct AtomicMin {
183   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
184 };
185 template <typename Type>
186 struct AtomicMax {
187   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
188 };
189 /* TODO: struct AtomicLXOR  */
190 template <typename Type>
191 struct AtomicFetchAdd {
192   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
193 };
194 
195 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
196 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
197 {
198   PetscInt        i, j, k, m, n, r;
199   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
200 
201   n      = opt[0];
202   offset = opt + 1;
203   start  = opt + n + 2;
204   dx     = opt + 2 * n + 2;
205   dy     = opt + 3 * n + 2;
206   X      = opt + 5 * n + 2;
207   Y      = opt + 6 * n + 2;
208   for (r = 0; r < n; r++) {
209     if (tid < offset[r + 1]) break;
210   }
211   m = (tid - offset[r]);
212   k = m / (dx[r] * dy[r]);
213   j = (m - k * dx[r] * dy[r]) / dx[r];
214   i = m - k * dx[r] * dy[r] - j * dx[r];
215 
216   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
217 }
218 
219 /*====================================================================================*/
220 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
221 /*====================================================================================*/
222 
223 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
224    <Type> is PetscReal, which is the primitive type we operate on.
225    <bs>   is 16, which says <unit> contains 16 primitive types.
226    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
227    <EQ>   is 0, which is (bs == BS ? 1 : 0)
228 
229   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
230   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
231 */
232 template <typename Type, PetscInt BS, PetscInt EQ>
233 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
234 {
235   const PetscInt       *iopt = opt ? opt->array : NULL;
236   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
237   const Type           *data = static_cast<const Type *>(data_);
238   Type                 *buf  = static_cast<Type *>(buf_);
239   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
240 
241   PetscFunctionBegin;
242   Kokkos::parallel_for(
243     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
244       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
245        iopt == NULL && idx == NULL ==> the indices are contiguous;
246      */
247       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
248       PetscInt s = tid * MBS;
249       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
250     });
251   PetscFunctionReturn(PETSC_SUCCESS);
252 }
253 
254 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
255 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
256 {
257   Op                    op;
258   const PetscInt       *iopt = opt ? opt->array : NULL;
259   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
260   Type                 *data = static_cast<Type *>(data_);
261   const Type           *buf  = static_cast<const Type *>(buf_);
262   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
263 
264   PetscFunctionBegin;
265   Kokkos::parallel_for(
266     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
267       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
268       PetscInt s = tid * MBS;
269       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
270     });
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
274 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
275 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
276 {
277   Op                    op;
278   const PetscInt       *ropt = opt ? opt->array : NULL;
279   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
280   Type                 *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
281   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
282 
283   PetscFunctionBegin;
284   Kokkos::parallel_for(
285     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
286       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
287       PetscInt l = tid * MBS;
288       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
289     });
290   PetscFunctionReturn(PETSC_SUCCESS);
291 }
292 
293 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
294 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
295 {
296   PetscInt              srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
297   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
298   const Type           *src  = static_cast<const Type *>(src_);
299   Type                 *dst  = static_cast<Type *>(dst_);
300   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
301 
302   PetscFunctionBegin;
303   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
304   if (srcOpt) {
305     srcx     = srcOpt->dx[0];
306     srcy     = srcOpt->dy[0];
307     srcX     = srcOpt->X[0];
308     srcY     = srcOpt->Y[0];
309     srcStart = srcOpt->start[0];
310     srcIdx   = NULL;
311   } else if (!srcIdx) {
312     srcx = srcX = count;
313     srcy = srcY = 1;
314   }
315 
316   if (dstOpt) {
317     dstx     = dstOpt->dx[0];
318     dsty     = dstOpt->dy[0];
319     dstX     = dstOpt->X[0];
320     dstY     = dstOpt->Y[0];
321     dstStart = dstOpt->start[0];
322     dstIdx   = NULL;
323   } else if (!dstIdx) {
324     dstx = dstX = count;
325     dsty = dstY = 1;
326   }
327 
328   Kokkos::parallel_for(
329     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
330       PetscInt i, j, k, s, t;
331       Op       op;
332       if (!srcIdx) { /* src is in 3D */
333         k = tid / (srcx * srcy);
334         j = (tid - k * srcx * srcy) / srcx;
335         i = tid - k * srcx * srcy - j * srcx;
336         s = srcStart + k * srcX * srcY + j * srcX + i;
337       } else { /* src is contiguous */
338         s = srcIdx[tid];
339       }
340 
341       if (!dstIdx) { /* 3D */
342         k = tid / (dstx * dsty);
343         j = (tid - k * dstx * dsty) / dstx;
344         i = tid - k * dstx * dsty - j * dstx;
345         t = dstStart + k * dstX * dstY + j * dstX + i;
346       } else { /* contiguous */
347         t = dstIdx[tid];
348       }
349 
350       s *= MBS;
351       t *= MBS;
352       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
353     });
354   PetscFunctionReturn(PETSC_SUCCESS);
355 }
356 
357 /* Specialization for Insert since we may use memcpy */
358 template <typename Type, PetscInt BS, PetscInt EQ>
359 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
360 {
361   const Type           *src  = static_cast<const Type *>(src_);
362   Type                 *dst  = static_cast<Type *>(dst_);
363   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
364 
365   PetscFunctionBegin;
366   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
367   /*src and dst are contiguous */
368   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
369     size_t              sz = count * link->unitbytes;
370     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
371     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
372     Kokkos::deep_copy(exec, dbuf, sbuf);
373   } else {
374     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
375   }
376   PetscFunctionReturn(PETSC_SUCCESS);
377 }
378 
379 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
380 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
381 {
382   Op                    op;
383   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
384   const PetscInt       *ropt     = rootopt ? rootopt->array : NULL;
385   const PetscInt       *lopt     = leafopt ? leafopt->array : NULL;
386   Type                 *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
387   const Type           *leafdata = static_cast<const Type *>(leafdata_);
388   DeviceExecutionSpace &exec     = PetscGetKokkosExecutionSpace();
389 
390   PetscFunctionBegin;
391   Kokkos::parallel_for(
392     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
393       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
394       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
395       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
396     });
397   PetscFunctionReturn(PETSC_SUCCESS);
398 }
399 
400 /*====================================================================================*/
401 /*  Init various types and instantiate pack/unpack function pointers                  */
402 /*====================================================================================*/
403 template <typename Type, PetscInt BS, PetscInt EQ>
404 static void PackInit_RealType(PetscSFLink link)
405 {
406   /* Pack/unpack for remote communication */
407   link->d_Pack            = Pack<Type, BS, EQ>;
408   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
409   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
410   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
411   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
412   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
413   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
414   /* Scatter for local communication */
415   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
416   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
417   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
418   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
419   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
420   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
421   /* Atomic versions when there are data-race possibilities */
422   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
423   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
424   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
425   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
426   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
427   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
428 
429   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
430   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
431   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
432   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
433   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
434   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
435 }
436 
437 template <typename Type, PetscInt BS, PetscInt EQ>
438 static void PackInit_IntegerType(PetscSFLink link)
439 {
440   link->d_Pack            = Pack<Type, BS, EQ>;
441   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
442   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
443   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
444   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
445   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
446   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
447   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
448   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
449   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
450   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
451   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
452   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
453 
454   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
455   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
456   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
457   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
458   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
459   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
460   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
461   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
462   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
463   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
464   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
465   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
466 
467   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
468   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
469   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
470   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
471   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
472   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
473   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
474   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
475   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
476   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
477   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
478 
479   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
480   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
481   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
482   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
483   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
484   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
485   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
486   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
487   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
488   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
489   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
490 }
491 
492 #if defined(PETSC_HAVE_COMPLEX)
493 template <typename Type, PetscInt BS, PetscInt EQ>
494 static void PackInit_ComplexType(PetscSFLink link)
495 {
496   link->d_Pack            = Pack<Type, BS, EQ>;
497   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
498   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
499   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
500   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
501 
502   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
503   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
504   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
505   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
506 
507   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
508   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
509   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
510   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
511 
512   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
513   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
514   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
515   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
516 }
517 #endif
518 
519 template <typename Type>
520 static void PackInit_PairType(PetscSFLink link)
521 {
522   link->d_Pack            = Pack<Type, 1, 1>;
523   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
524   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
525   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
526 
527   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
528   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
529   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
530   /* Atomics for pair types are not implemented yet */
531 }
532 
533 template <typename Type, PetscInt BS, PetscInt EQ>
534 static void PackInit_DumbType(PetscSFLink link)
535 {
536   link->d_Pack             = Pack<Type, BS, EQ>;
537   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
538   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
539   /* Atomics for dumb types are not implemented yet */
540 }
541 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug:
  one cannot repeatedly create and destroy the object. SF's original design was that each
  SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks that use the NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
551 /*
552 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
553 {
554   PetscFunctionBegin;
555   PetscFunctionReturn(PETSC_SUCCESS);
556 }
557 */
558 
559 /* Some device-specific utilities */
/* Device-wide synchronization: calls the global Kokkos::fence() (no execution space instance is named).
   The 'link' argument is not used. Always returns PETSC_SUCCESS. */
static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
{
  PetscFunctionBegin;
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
566 
567 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
568 {
569   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
570   PetscFunctionBegin;
571   exec.fence();
572   PetscFunctionReturn(PETSC_SUCCESS);
573 }
574 
575 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
576 {
577   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
578 
579   PetscFunctionBegin;
580   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
581   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
582     PetscCall(PetscMemcpy(dst, src, n));
583   } else {
584     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
585       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
586       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
587       Kokkos::deep_copy(exec, dbuf, sbuf);
588       PetscCall(PetscLogCpuToGpu(n));
589     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
590       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
591       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
592       Kokkos::deep_copy(exec, dbuf, sbuf);
593       PetscCall(PetscLogGpuToCpu(n));
594     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
595       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
596       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
597       Kokkos::deep_copy(exec, dbuf, sbuf);
598     }
599   }
600   PetscFunctionReturn(PETSC_SUCCESS);
601 }
602 
603 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
604 {
605   PetscFunctionBegin;
606   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
607   else if (PetscMemTypeDevice(mtype)) {
608     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
609     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
610   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
611   PetscFunctionReturn(PETSC_SUCCESS);
612 }
613 
614 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
615 {
616   PetscFunctionBegin;
617   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
618   else if (PetscMemTypeDevice(mtype)) {
619     Kokkos::kokkos_free<DeviceMemorySpace>(ptr);
620   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
621   PetscFunctionReturn(PETSC_SUCCESS);
622 }
623 
624 /* Destructor when the link uses MPI for communication */
625 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
626 {
627   PetscFunctionBegin;
628   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
629     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
630     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
631   }
632   PetscFunctionReturn(PETSC_SUCCESS);
633 }
634 
635 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
636 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
637 {
638   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
639   PetscBool is2Int, is2PetscInt;
640 #if defined(PETSC_HAVE_COMPLEX)
641   PetscInt nPetscComplex = 0;
642 #endif
643 
644   PetscFunctionBegin;
645   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
646   PetscCall(PetscKokkosInitializeCheck());
647   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
648   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
649   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
650   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
651   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
652   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
653 #if defined(PETSC_HAVE_COMPLEX)
654   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
655 #endif
656   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
657   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
658 
659   if (is2Int) {
660     PackInit_PairType<Kokkos::pair<int, int>>(link);
661   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
662     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
663   } else if (nPetscReal) {
664 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
665     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
666     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
667     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
668     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
669     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
670     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
671     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
672     else if (nPetscReal % 1 == 0)
673 #endif
674       PackInit_RealType<PetscReal, 1, 0>(link);
675   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
676 #if !defined(PETSC_HAVE_DEVICE)
677     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
678     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
679     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
680     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
681     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
682     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
683     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
684     else if (nPetscInt % 1 == 0)
685 #endif
686       PackInit_IntegerType<llint, 1, 0>(link);
687   } else if (nInt) {
688 #if !defined(PETSC_HAVE_DEVICE)
689     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
690     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
691     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
692     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
693     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
694     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
695     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
696     else if (nInt % 1 == 0)
697 #endif
698       PackInit_IntegerType<int, 1, 0>(link);
699   } else if (nSignedChar) {
700 #if !defined(PETSC_HAVE_DEVICE)
701     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
702     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
703     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
704     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
705     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
706     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
707     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
708     else if (nSignedChar % 1 == 0)
709 #endif
710       PackInit_IntegerType<char, 1, 0>(link);
711   } else if (nUnsignedChar) {
712 #if !defined(PETSC_HAVE_DEVICE)
713     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
714     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
715     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
716     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
717     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
718     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
719     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
720     else if (nUnsignedChar % 1 == 0)
721 #endif
722       PackInit_IntegerType<unsigned char, 1, 0>(link);
723 #if defined(PETSC_HAVE_COMPLEX)
724   } else if (nPetscComplex) {
725   #if !defined(PETSC_HAVE_DEVICE)
726     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
727     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
728     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
729     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
730     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
731     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
732     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
733     else if (nPetscComplex % 1 == 0)
734   #endif
735       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
736 #endif
737   } else {
738     MPI_Aint lb, nbyte;
739     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
740     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
741     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
742 #if !defined(PETSC_HAVE_DEVICE)
743       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
744       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
745       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
746       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
747       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
748       else if (nbyte % 1 == 0)
749 #endif
750         PackInit_DumbType<char, 1, 0>(link);
751     } else {
752       nInt = nbyte / sizeof(int);
753 #if !defined(PETSC_HAVE_DEVICE)
754       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
755       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
756       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
757       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
758       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
759       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
760       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
761       else if (nInt % 1 == 0)
762 #endif
763         PackInit_DumbType<int, 1, 0>(link);
764     }
765   }
766 
767   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
768   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
769   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
770   link->Destroy      = PetscSFLinkDestroy_Kokkos;
771   link->deviceinited = PETSC_TRUE;
772   PetscFunctionReturn(PETSC_SUCCESS);
773 }
774