xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision f13dfd9ea68e0ddeee984e65c377a1819eab8a8a)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <petsc_kokkos.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
10 typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;
11 
12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template <typename Type>
19 struct Insert {
20   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
21   {
22     Type old = x;
23     x        = y;
24     return old;
25   }
26 };
27 template <typename Type>
28 struct Add {
29   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
30   {
31     Type old = x;
32     x += y;
33     return old;
34   }
35 };
36 template <typename Type>
37 struct Mult {
38   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
39   {
40     Type old = x;
41     x *= y;
42     return old;
43   }
44 };
45 template <typename Type>
46 struct Min {
47   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
48   {
49     Type old = x;
50     x        = PetscMin(x, y);
51     return old;
52   }
53 };
54 template <typename Type>
55 struct Max {
56   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
57   {
58     Type old = x;
59     x        = PetscMax(x, y);
60     return old;
61   }
62 };
63 template <typename Type>
64 struct LAND {
65   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
66   {
67     Type old = x;
68     x        = x && y;
69     return old;
70   }
71 };
72 template <typename Type>
73 struct LOR {
74   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
75   {
76     Type old = x;
77     x        = x || y;
78     return old;
79   }
80 };
81 template <typename Type>
82 struct LXOR {
83   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
84   {
85     Type old = x;
86     x        = !x != !y;
87     return old;
88   }
89 };
90 template <typename Type>
91 struct BAND {
92   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
93   {
94     Type old = x;
95     x        = x & y;
96     return old;
97   }
98 };
99 template <typename Type>
100 struct BOR {
101   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
102   {
103     Type old = x;
104     x        = x | y;
105     return old;
106   }
107 };
108 template <typename Type>
109 struct BXOR {
110   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
111   {
112     Type old = x;
113     x        = x ^ y;
114     return old;
115   }
116 };
117 template <typename PairType>
118 struct Minloc {
119   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
120   {
121     PairType old = x;
122     if (y.first < x.first) x = y;
123     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
124     return old;
125   }
126 };
127 template <typename PairType>
128 struct Maxloc {
129   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
130   {
131     PairType old = x;
132     if (y.first > x.first) x = y;
133     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
134     return old;
135   }
136 };
137 
138 /*====================================================================================*/
139 /*                             Atomic operations                            */
140 /*====================================================================================*/
141 template <typename Type>
142 struct AtomicInsert {
143   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
144 };
145 template <typename Type>
146 struct AtomicAdd {
147   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
148 };
149 template <typename Type>
150 struct AtomicBAND {
151   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
152 };
153 template <typename Type>
154 struct AtomicBOR {
155   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
156 };
157 template <typename Type>
158 struct AtomicBXOR {
159   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
160 };
161 template <typename Type>
162 struct AtomicLAND {
163   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
164   {
165     const Type zero = 0, one = ~0;
166     Kokkos::atomic_and(&x, y ? one : zero);
167   }
168 };
169 template <typename Type>
170 struct AtomicLOR {
171   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
172   {
173     const Type zero = 0, one = 1;
174     Kokkos::atomic_or(&x, y ? one : zero);
175   }
176 };
177 template <typename Type>
178 struct AtomicMult {
179   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
180 };
181 template <typename Type>
182 struct AtomicMin {
183   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
184 };
185 template <typename Type>
186 struct AtomicMax {
187   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
188 };
189 /* TODO: struct AtomicLXOR  */
190 template <typename Type>
191 struct AtomicFetchAdd {
192   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
193 };
194 
195 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
196 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
197 {
198   PetscInt        i, j, k, m, n, r;
199   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
200 
201   n      = opt[0];
202   offset = opt + 1;
203   start  = opt + n + 2;
204   dx     = opt + 2 * n + 2;
205   dy     = opt + 3 * n + 2;
206   X      = opt + 5 * n + 2;
207   Y      = opt + 6 * n + 2;
208   for (r = 0; r < n; r++) {
209     if (tid < offset[r + 1]) break;
210   }
211   m = (tid - offset[r]);
212   k = m / (dx[r] * dy[r]);
213   j = (m - k * dx[r] * dy[r]) / dx[r];
214   i = m - k * dx[r] * dy[r] - j * dx[r];
215 
216   return start[r] + k * X[r] * Y[r] + j * X[r] + i;
217 }
218 
219 /*====================================================================================*/
220 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
221 /*====================================================================================*/
222 
223 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
224    <Type> is PetscReal, which is the primitive type we operate on.
225    <bs>   is 16, which says <unit> contains 16 primitive types.
226    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
227    <EQ>   is 0, which is (bs == BS ? 1 : 0)
228 
229   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
230   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
231 */
232 template <typename Type, PetscInt BS, PetscInt EQ>
233 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
234 {
235   const PetscInt       *iopt = opt ? opt->array : NULL;
236   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
237   const Type           *data = static_cast<const Type *>(data_);
238   Type                 *buf  = static_cast<Type *>(buf_);
239   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
240 
241   PetscFunctionBegin;
242   Kokkos::parallel_for(
243     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
244       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
245        iopt == NULL && idx == NULL ==> the indices are contiguous;
246      */
247       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
248       PetscInt s = tid * MBS;
249       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
250     });
251   PetscFunctionReturn(PETSC_SUCCESS);
252 }
253 
254 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
255 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
256 {
257   Op                    op;
258   const PetscInt       *iopt = opt ? opt->array : NULL;
259   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
260   Type                 *data = static_cast<Type *>(data_);
261   const Type           *buf  = static_cast<const Type *>(buf_);
262   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
263 
264   PetscFunctionBegin;
265   Kokkos::parallel_for(
266     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
267       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
268       PetscInt s = tid * MBS;
269       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
270     });
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
274 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
275 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
276 {
277   Op                    op;
278   const PetscInt       *ropt = opt ? opt->array : NULL;
279   const PetscInt        M = EQ ? 1 : link->bs / BS, MBS = M * BS;
280   Type                 *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
281   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
282 
283   PetscFunctionBegin;
284   Kokkos::parallel_for(
285     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
286       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
287       PetscInt l = tid * MBS;
288       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
289     });
290   PetscFunctionReturn(PETSC_SUCCESS);
291 }
292 
293 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
294 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
295 {
296   PetscInt              srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
297   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
298   const Type           *src  = static_cast<const Type *>(src_);
299   Type                 *dst  = static_cast<Type *>(dst_);
300   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
301 
302   PetscFunctionBegin;
303   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
304   if (srcOpt) {
305     srcx     = srcOpt->dx[0];
306     srcy     = srcOpt->dy[0];
307     srcX     = srcOpt->X[0];
308     srcY     = srcOpt->Y[0];
309     srcStart = srcOpt->start[0];
310     srcIdx   = NULL;
311   } else if (!srcIdx) {
312     srcx = srcX = count;
313     srcy = srcY = 1;
314   }
315 
316   if (dstOpt) {
317     dstx     = dstOpt->dx[0];
318     dsty     = dstOpt->dy[0];
319     dstX     = dstOpt->X[0];
320     dstY     = dstOpt->Y[0];
321     dstStart = dstOpt->start[0];
322     dstIdx   = NULL;
323   } else if (!dstIdx) {
324     dstx = dstX = count;
325     dsty = dstY = 1;
326   }
327 
328   Kokkos::parallel_for(
329     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
330       PetscInt i, j, k, s, t;
331       Op       op;
332       if (!srcIdx) { /* src is in 3D */
333         k = tid / (srcx * srcy);
334         j = (tid - k * srcx * srcy) / srcx;
335         i = tid - k * srcx * srcy - j * srcx;
336         s = srcStart + k * srcX * srcY + j * srcX + i;
337       } else { /* src is contiguous */
338         s = srcIdx[tid];
339       }
340 
341       if (!dstIdx) { /* 3D */
342         k = tid / (dstx * dsty);
343         j = (tid - k * dstx * dsty) / dstx;
344         i = tid - k * dstx * dsty - j * dstx;
345         t = dstStart + k * dstX * dstY + j * dstX + i;
346       } else { /* contiguous */
347         t = dstIdx[tid];
348       }
349 
350       s *= MBS;
351       t *= MBS;
352       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
353     });
354   PetscFunctionReturn(PETSC_SUCCESS);
355 }
356 
357 /* Specialization for Insert since we may use memcpy */
358 template <typename Type, PetscInt BS, PetscInt EQ>
359 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
360 {
361   const Type           *src  = static_cast<const Type *>(src_);
362   Type                 *dst  = static_cast<Type *>(dst_);
363   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
364 
365   PetscFunctionBegin;
366   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
367   /*src and dst are contiguous */
368   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
369     size_t              sz = count * link->unitbytes;
370     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
371     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
372     Kokkos::deep_copy(exec, dbuf, sbuf);
373   } else {
374     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
375   }
376   PetscFunctionReturn(PETSC_SUCCESS);
377 }
378 
379 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
380 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
381 {
382   Op                    op;
383   const PetscInt        M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
384   const PetscInt       *ropt     = rootopt ? rootopt->array : NULL;
385   const PetscInt       *lopt     = leafopt ? leafopt->array : NULL;
386   Type                 *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
387   const Type           *leafdata = static_cast<const Type *>(leafdata_);
388   DeviceExecutionSpace &exec     = PetscGetKokkosExecutionSpace();
389 
390   PetscFunctionBegin;
391   Kokkos::parallel_for(
392     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
393       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
394       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
395       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
396     });
397   PetscFunctionReturn(PETSC_SUCCESS);
398 }
399 
400 /*====================================================================================*/
401 /*  Init various types and instantiate pack/unpack function pointers                  */
402 /*====================================================================================*/
403 template <typename Type, PetscInt BS, PetscInt EQ>
404 static void PackInit_RealType(PetscSFLink link)
405 {
406   /* Pack/unpack for remote communication */
407   link->d_Pack            = Pack<Type, BS, EQ>;
408   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
409   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
410   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
411   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
412   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
413   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
414   /* Scatter for local communication */
415   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
416   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
417   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
418   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
419   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
420   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
421   /* Atomic versions when there are data-race possibilities */
422   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
423   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
424   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
425   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
426   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
427   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
428 
429   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
430   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
431   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
432   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
433   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
434   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
435 }
436 
437 template <typename Type, PetscInt BS, PetscInt EQ>
438 static void PackInit_IntegerType(PetscSFLink link)
439 {
440   link->d_Pack            = Pack<Type, BS, EQ>;
441   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
442   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
443   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
444   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
445   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
446   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
447   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
448   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
449   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
450   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
451   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
452   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
453 
454   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
455   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
456   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
457   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
458   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
459   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
460   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
461   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
462   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
463   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
464   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
465   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
466 
467   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
468   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
469   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
470   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
471   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
472   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
473   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
474   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
475   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
476   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
477   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
478 
479   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
480   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
481   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
482   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
483   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
484   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
485   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
486   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
487   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
488   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
489   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
490 }
491 
#if defined(PETSC_HAVE_COMPLEX)
/* Fill the device function pointers of <link> for a complex primitive <Type>.
   Supported reductions: Insert, Add, Mult, plus fetch-and-add. */
template <typename Type, PetscInt BS, PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  /* Packing for remote communication */
  link->d_Pack = Pack<Type, BS, EQ>;

  /* Non-atomic kernels, one Unpack (remote) / Scatter (local) pair per reduction */
  link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
  link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
  link->d_UnpackAndAdd     = UnpackAndOp<Type, Add<Type>, BS, EQ>;
  link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
  link->d_UnpackAndMult    = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_FetchAndAdd      = FetchAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;

  /* Atomic kernels, paired the same way */
  link->da_UnpackAndInsert  = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_UnpackAndMult    = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
}
#endif
518 
519 template <typename Type>
520 static void PackInit_PairType(PetscSFLink link)
521 {
522   link->d_Pack            = Pack<Type, 1, 1>;
523   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
524   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
525   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
526 
527   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
528   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
529   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
530   /* Atomics for pair types are not implemented yet */
531 }
532 
533 template <typename Type, PetscInt BS, PetscInt EQ>
534 static void PackInit_DumbType(PetscSFLink link)
535 {
536   link->d_Pack             = Pack<Type, BS, EQ>;
537   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
538   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
539   /* Atomics for dumb types are not implemented yet */
540 }
541 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug
  that prevents one from repeatedly creating and destroying the object. SF's original design was that
  each SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks with NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
551 /*
552 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
553 {
554   PetscFunctionBegin;
555   PetscFunctionReturn(PETSC_SUCCESS);
556 }
557 */
558 
559 /* Some device-specific utilities */
560 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
561 {
562   PetscFunctionBegin;
563   Kokkos::fence();
564   PetscFunctionReturn(PETSC_SUCCESS);
565 }
566 
567 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
568 {
569   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
570 
571   PetscFunctionBegin;
572   exec.fence();
573   PetscFunctionReturn(PETSC_SUCCESS);
574 }
575 
576 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
577 {
578   DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace();
579 
580   PetscFunctionBegin;
581   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
582   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
583     PetscCall(PetscMemcpy(dst, src, n));
584   } else {
585     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
586       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
587       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
588       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
589       PetscCall(PetscLogCpuToGpu(n));
590     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
591       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
592       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
593       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
594       PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
595       PetscCall(PetscLogGpuToCpu(n));
596     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
597       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
598       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
599       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
600     }
601   }
602   PetscFunctionReturn(PETSC_SUCCESS);
603 }
604 
605 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
606 {
607   PetscFunctionBegin;
608   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609   else if (PetscMemTypeDevice(mtype)) {
610     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
611     PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size));
612   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613   PetscFunctionReturn(PETSC_SUCCESS);
614 }
615 
616 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
617 {
618   PetscFunctionBegin;
619   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
620   else if (PetscMemTypeDevice(mtype)) {
621     PetscCallCXX(Kokkos::kokkos_free<DeviceMemorySpace>(ptr));
622   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623   PetscFunctionReturn(PETSC_SUCCESS);
624 }
625 
626 /* Destructor when the link uses MPI for communication */
627 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
628 {
629   PetscFunctionBegin;
630   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
631     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633   }
634   PetscFunctionReturn(PETSC_SUCCESS);
635 }
636 
637 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
638 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
639 {
640   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
641   PetscBool is2Int, is2PetscInt;
642 #if defined(PETSC_HAVE_COMPLEX)
643   PetscInt nPetscComplex = 0;
644 #endif
645 
646   PetscFunctionBegin;
647   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
648   PetscCall(PetscKokkosInitializeCheck());
649   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
650   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
651   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
652   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
653   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
654   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
655 #if defined(PETSC_HAVE_COMPLEX)
656   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
657 #endif
658   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
659   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
660 
661   if (is2Int) {
662     PackInit_PairType<Kokkos::pair<int, int>>(link);
663   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
664     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
665   } else if (nPetscReal) {
666 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
667     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
668     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
669     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
670     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
671     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
672     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
673     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
674     else if (nPetscReal % 1 == 0)
675 #endif
676       PackInit_RealType<PetscReal, 1, 0>(link);
677   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
678 #if !defined(PETSC_HAVE_DEVICE)
679     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
680     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
681     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
682     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
683     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
684     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
685     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
686     else if (nPetscInt % 1 == 0)
687 #endif
688       PackInit_IntegerType<llint, 1, 0>(link);
689   } else if (nInt) {
690 #if !defined(PETSC_HAVE_DEVICE)
691     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
692     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
693     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
694     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
695     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
696     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
697     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
698     else if (nInt % 1 == 0)
699 #endif
700       PackInit_IntegerType<int, 1, 0>(link);
701   } else if (nSignedChar) {
702 #if !defined(PETSC_HAVE_DEVICE)
703     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
704     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
705     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
706     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
707     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
708     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
709     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
710     else if (nSignedChar % 1 == 0)
711 #endif
712       PackInit_IntegerType<char, 1, 0>(link);
713   } else if (nUnsignedChar) {
714 #if !defined(PETSC_HAVE_DEVICE)
715     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
716     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
717     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
718     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
719     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
720     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
721     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
722     else if (nUnsignedChar % 1 == 0)
723 #endif
724       PackInit_IntegerType<unsigned char, 1, 0>(link);
725 #if defined(PETSC_HAVE_COMPLEX)
726   } else if (nPetscComplex) {
727   #if !defined(PETSC_HAVE_DEVICE)
728     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
729     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
730     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
731     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
732     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
733     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
734     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
735     else if (nPetscComplex % 1 == 0)
736   #endif
737       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
738 #endif
739   } else {
740     MPI_Aint lb, nbyte;
741     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
742     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
743     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
744 #if !defined(PETSC_HAVE_DEVICE)
745       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
746       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
747       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
748       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
749       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
750       else if (nbyte % 1 == 0)
751 #endif
752         PackInit_DumbType<char, 1, 0>(link);
753     } else {
754       nInt = nbyte / sizeof(int);
755 #if !defined(PETSC_HAVE_DEVICE)
756       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
757       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
758       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
759       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
760       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
761       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
762       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
763       else if (nInt % 1 == 0)
764 #endif
765         PackInit_DumbType<int, 1, 0>(link);
766     }
767   }
768 
769   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
770   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
771   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
772   link->Destroy      = PetscSFLinkDestroy_Kokkos;
773   link->deviceinited = PETSC_TRUE;
774   PetscFunctionReturn(PETSC_SUCCESS);
775 }
776