1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2
3 #include <petsc_kokkos.hpp>
4 #include <petsc/private/kokkosimpl.hpp>
5
6 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
7
8 typedef Kokkos::View<char *, DefaultMemorySpace> deviceBuffer_t;
9 typedef Kokkos::View<char *, HostMirrorMemorySpace> HostBuffer_t;
10
11 typedef Kokkos::View<const char *, DefaultMemorySpace> deviceConstBuffer_t;
12 typedef Kokkos::View<const char *, HostMirrorMemorySpace> HostConstBuffer_t;
13
14 /*====================================================================================*/
15 /* Regular operations */
16 /*====================================================================================*/
17 template <typename Type>
18 struct Insert {
operator ()Insert19 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
20 {
21 Type old = x;
22 x = y;
23 return old;
24 }
25 };
26 template <typename Type>
27 struct Add {
operator ()Add28 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
29 {
30 Type old = x;
31 x += y;
32 return old;
33 }
34 };
35 template <typename Type>
36 struct Mult {
operator ()Mult37 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
38 {
39 Type old = x;
40 x *= y;
41 return old;
42 }
43 };
44 template <typename Type>
45 struct Min {
operator ()Min46 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
47 {
48 Type old = x;
49 x = PetscMin(x, y);
50 return old;
51 }
52 };
53 template <typename Type>
54 struct Max {
operator ()Max55 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
56 {
57 Type old = x;
58 x = PetscMax(x, y);
59 return old;
60 }
61 };
62 template <typename Type>
63 struct LAND {
operator ()LAND64 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
65 {
66 Type old = x;
67 x = x && y;
68 return old;
69 }
70 };
71 template <typename Type>
72 struct LOR {
operator ()LOR73 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
74 {
75 Type old = x;
76 x = x || y;
77 return old;
78 }
79 };
80 template <typename Type>
81 struct LXOR {
operator ()LXOR82 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
83 {
84 Type old = x;
85 x = !x != !y;
86 return old;
87 }
88 };
89 template <typename Type>
90 struct BAND {
operator ()BAND91 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
92 {
93 Type old = x;
94 x = x & y;
95 return old;
96 }
97 };
98 template <typename Type>
99 struct BOR {
operator ()BOR100 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
101 {
102 Type old = x;
103 x = x | y;
104 return old;
105 }
106 };
107 template <typename Type>
108 struct BXOR {
operator ()BXOR109 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
110 {
111 Type old = x;
112 x = x ^ y;
113 return old;
114 }
115 };
116 template <typename PairType>
117 struct Minloc {
operator ()Minloc118 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
119 {
120 PairType old = x;
121 if (y.first < x.first) x = y;
122 else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
123 return old;
124 }
125 };
126 template <typename PairType>
127 struct Maxloc {
operator ()Maxloc128 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
129 {
130 PairType old = x;
131 if (y.first > x.first) x = y;
132 else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
133 return old;
134 }
135 };
136
137 /*====================================================================================*/
138 /* Atomic operations */
139 /*====================================================================================*/
140 template <typename Type>
141 struct AtomicInsert {
operator ()AtomicInsert142 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_store(&x, y); }
143 };
144 template <typename Type>
145 struct AtomicAdd {
operator ()AtomicAdd146 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
147 };
148 template <typename Type>
149 struct AtomicBAND {
operator ()AtomicBAND150 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
151 };
152 template <typename Type>
153 struct AtomicBOR {
operator ()AtomicBOR154 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
155 };
156 template <typename Type>
157 struct AtomicBXOR {
operator ()AtomicBXOR158 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
159 };
160 template <typename Type>
161 struct AtomicLAND {
operator ()AtomicLAND162 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
163 {
164 const Type zero = 0, one = ~0;
165 Kokkos::atomic_and(&x, y ? one : zero);
166 }
167 };
168 template <typename Type>
169 struct AtomicLOR {
operator ()AtomicLOR170 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
171 {
172 const Type zero = 0, one = 1;
173 Kokkos::atomic_or(&x, y ? one : zero);
174 }
175 };
176 template <typename Type>
177 struct AtomicMult {
operator ()AtomicMult178 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
179 };
180 template <typename Type>
181 struct AtomicMin {
operator ()AtomicMin182 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
183 };
184 template <typename Type>
185 struct AtomicMax {
operator ()AtomicMax186 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
187 };
188 /* TODO: struct AtomicLXOR */
189 template <typename Type>
190 struct AtomicFetchAdd {
operator ()AtomicFetchAdd191 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
192 };
193
194 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
MapTidToIndex(const PetscInt * opt,PetscInt tid)195 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
196 {
197 PetscInt i, j, k, m, n, r;
198 const PetscInt *offset, *start, *dx, *dy, *X, *Y;
199
200 n = opt[0];
201 offset = opt + 1;
202 start = opt + n + 2;
203 dx = opt + 2 * n + 2;
204 dy = opt + 3 * n + 2;
205 X = opt + 5 * n + 2;
206 Y = opt + 6 * n + 2;
207 for (r = 0; r < n; r++) {
208 if (tid < offset[r + 1]) break;
209 }
210 m = (tid - offset[r]);
211 k = m / (dx[r] * dy[r]);
212 j = (m - k * dx[r] * dy[r]) / dx[r];
213 i = m - k * dx[r] * dy[r] - j * dx[r];
214
215 return start[r] + k * X[r] * Y[r] + j * X[r] + i;
216 }
217
218 /*====================================================================================*/
219 /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */
220 /*====================================================================================*/
221
222 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
223 <Type> is PetscReal, which is the primitive type we operate on.
224 <bs> is 16, which says <unit> contains 16 primitive types.
225 <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
226 <EQ> is 0, which is (bs == BS ? 1 : 0)
227
228 If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
229 For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
230 */
231 template <typename Type, PetscInt BS, PetscInt EQ>
Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt * idx,const void * data_,void * buf_)232 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
233 {
234 const PetscInt *iopt = opt ? opt->array : NULL;
235 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
236 const Type *data = static_cast<const Type *>(data_);
237 Type *buf = static_cast<Type *>(buf_);
238 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
239
240 PetscFunctionBegin;
241 Kokkos::parallel_for(
242 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
243 /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
244 iopt == NULL && idx == NULL ==> the indices are contiguous;
245 */
246 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
247 PetscInt s = tid * MBS;
248 for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
249 });
250 PetscFunctionReturn(PETSC_SUCCESS);
251 }
252
253 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt * idx,void * data_,const void * buf_)254 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
255 {
256 Op op;
257 const PetscInt *iopt = opt ? opt->array : NULL;
258 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS;
259 Type *data = static_cast<Type *>(data_);
260 const Type *buf = static_cast<const Type *>(buf_);
261 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
262
263 PetscFunctionBegin;
264 Kokkos::parallel_for(
265 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
266 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
267 PetscInt s = tid * MBS;
268 for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
269 });
270 PetscFunctionReturn(PETSC_SUCCESS);
271 }
272
273 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt * idx,void * data,void * buf)274 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
275 {
276 Op op;
277 const PetscInt *ropt = opt ? opt->array : NULL;
278 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS;
279 Type *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
280 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
281
282 PetscFunctionBegin;
283 Kokkos::parallel_for(
284 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
285 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
286 PetscInt l = tid * MBS;
287 for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
288 });
289 PetscFunctionReturn(PETSC_SUCCESS);
290 }
291
292 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt * srcIdx,const void * src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt * dstIdx,void * dst_)293 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
294 {
295 PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
296 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
297 const Type *src = static_cast<const Type *>(src_);
298 Type *dst = static_cast<Type *>(dst_);
299 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
300
301 PetscFunctionBegin;
302 /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
303 if (srcOpt) {
304 srcx = srcOpt->dx[0];
305 srcy = srcOpt->dy[0];
306 srcX = srcOpt->X[0];
307 srcY = srcOpt->Y[0];
308 srcStart = srcOpt->start[0];
309 srcIdx = NULL;
310 } else if (!srcIdx) {
311 srcx = srcX = count;
312 srcy = srcY = 1;
313 }
314
315 if (dstOpt) {
316 dstx = dstOpt->dx[0];
317 dsty = dstOpt->dy[0];
318 dstX = dstOpt->X[0];
319 dstY = dstOpt->Y[0];
320 dstStart = dstOpt->start[0];
321 dstIdx = NULL;
322 } else if (!dstIdx) {
323 dstx = dstX = count;
324 dsty = dstY = 1;
325 }
326
327 Kokkos::parallel_for(
328 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
329 PetscInt i, j, k, s, t;
330 Op op;
331 if (!srcIdx) { /* src is in 3D */
332 k = tid / (srcx * srcy);
333 j = (tid - k * srcx * srcy) / srcx;
334 i = tid - k * srcx * srcy - j * srcx;
335 s = srcStart + k * srcX * srcY + j * srcX + i;
336 } else { /* src is contiguous */
337 s = srcIdx[tid];
338 }
339
340 if (!dstIdx) { /* 3D */
341 k = tid / (dstx * dsty);
342 j = (tid - k * dstx * dsty) / dstx;
343 i = tid - k * dstx * dsty - j * dstx;
344 t = dstStart + k * dstX * dstY + j * dstX + i;
345 } else { /* contiguous */
346 t = dstIdx[tid];
347 }
348
349 s *= MBS;
350 t *= MBS;
351 for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
352 });
353 PetscFunctionReturn(PETSC_SUCCESS);
354 }
355
356 /* Specialization for Insert since we may use memcpy */
357 template <typename Type, PetscInt BS, PetscInt EQ>
ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt * srcIdx,const void * src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt * dstIdx,void * dst_)358 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
359 {
360 const Type *src = static_cast<const Type *>(src_);
361 Type *dst = static_cast<Type *>(dst_);
362 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
363
364 PetscFunctionBegin;
365 if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366 /*src and dst are contiguous */
367 if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368 size_t sz = count * link->unitbytes;
369 deviceBuffer_t dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
370 deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
371 Kokkos::deep_copy(exec, dbuf, sbuf);
372 } else {
373 PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
374 }
375 PetscFunctionReturn(PETSC_SUCCESS);
376 }
377
378 template <typename Type, class Op, PetscInt BS, PetscInt EQ>
FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt * rootidx,void * rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt * leafidx,const void * leafdata_,void * leafupdate_)379 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
380 {
381 Op op;
382 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
383 const PetscInt *ropt = rootopt ? rootopt->array : NULL;
384 const PetscInt *lopt = leafopt ? leafopt->array : NULL;
385 Type *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
386 const Type *leafdata = static_cast<const Type *>(leafdata_);
387 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
388
389 PetscFunctionBegin;
390 Kokkos::parallel_for(
391 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
392 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
393 PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
394 for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
395 });
396 PetscFunctionReturn(PETSC_SUCCESS);
397 }
398
399 /*====================================================================================*/
400 /* Init various types and instantiate pack/unpack function pointers */
401 /*====================================================================================*/
402 template <typename Type, PetscInt BS, PetscInt EQ>
PackInit_RealType(PetscSFLink link)403 static void PackInit_RealType(PetscSFLink link)
404 {
405 /* Pack/unpack for remote communication */
406 link->d_Pack = Pack<Type, BS, EQ>;
407 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
408 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>;
409 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
410 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>;
411 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>;
412 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>;
413 /* Scatter for local communication */
414 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
415 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>;
416 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
417 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>;
418 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>;
419 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
420 /* Atomic versions when there are data-race possibilities */
421 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
422 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
423 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
424 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
425 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
426 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
427
428 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
429 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
430 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
431 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
432 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
433 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
434 }
435
436 template <typename Type, PetscInt BS, PetscInt EQ>
PackInit_IntegerType(PetscSFLink link)437 static void PackInit_IntegerType(PetscSFLink link)
438 {
439 link->d_Pack = Pack<Type, BS, EQ>;
440 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
441 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>;
442 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
443 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>;
444 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>;
445 link->d_UnpackAndLAND = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
446 link->d_UnpackAndLOR = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
447 link->d_UnpackAndLXOR = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
448 link->d_UnpackAndBAND = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
449 link->d_UnpackAndBOR = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
450 link->d_UnpackAndBXOR = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
451 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>;
452
453 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
454 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>;
455 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
456 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>;
457 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>;
458 link->d_ScatterAndLAND = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
459 link->d_ScatterAndLOR = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
460 link->d_ScatterAndLXOR = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
461 link->d_ScatterAndBAND = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
462 link->d_ScatterAndBOR = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
463 link->d_ScatterAndBXOR = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
464 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
465
466 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
467 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
468 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
469 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
470 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
471 link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
472 link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
473 link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
474 link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
475 link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
476 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
477
478 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
479 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
480 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
481 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
482 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
483 link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
484 link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
485 link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
486 link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
487 link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
488 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
489 }
490
491 #if defined(PETSC_HAVE_COMPLEX)
492 template <typename Type, PetscInt BS, PetscInt EQ>
PackInit_ComplexType(PetscSFLink link)493 static void PackInit_ComplexType(PetscSFLink link)
494 {
495 link->d_Pack = Pack<Type, BS, EQ>;
496 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
497 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>;
498 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
499 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>;
500
501 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
502 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>;
503 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
504 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
505
506 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
507 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
508 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
509 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
510
511 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
512 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
513 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
514 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
515 }
516 #endif
517
518 template <typename Type>
PackInit_PairType(PetscSFLink link)519 static void PackInit_PairType(PetscSFLink link)
520 {
521 link->d_Pack = Pack<Type, 1, 1>;
522 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
523 link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
524 link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
525
526 link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
527 link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
528 link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
529 /* Atomics for pair types are not implemented yet */
530 }
531
532 template <typename Type, PetscInt BS, PetscInt EQ>
PackInit_DumbType(PetscSFLink link)533 static void PackInit_DumbType(PetscSFLink link)
534 {
535 link->d_Pack = Pack<Type, BS, EQ>;
536 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
537 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
538 /* Atomics for dumb types are not implemented yet */
539 }
540
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug
  that prevents one from repeatedly creating and destroying the object. SF's original design was
  that each SFLink has a stream (NULL or not) and hence an execution space object. The bug
  prevents us from destroying multiple SFLinks with a NULL stream and the default execution space
  object. To avoid memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's
  default scheme. SF_Kokkos does not do its own new/delete. It just uses
  Kokkos::DefaultExecutionSpace(), which is a singleton object in Kokkos.
*/
550 /*
551 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
552 {
553 PetscFunctionBegin;
554 PetscFunctionReturn(PETSC_SUCCESS);
555 }
556 */
557
558 /* Some device-specific utilities */
PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)559 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
560 {
561 PetscFunctionBegin;
562 Kokkos::fence();
563 PetscFunctionReturn(PETSC_SUCCESS);
564 }
565
PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)566 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
567 {
568 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
569
570 PetscFunctionBegin;
571 exec.fence();
572 PetscFunctionReturn(PETSC_SUCCESS);
573 }
574
PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link,PetscMemType dstmtype,void * dst,PetscMemType srcmtype,const void * src,size_t n)575 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
576 {
577 DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
578
579 PetscFunctionBegin;
580 if (!n) PetscFunctionReturn(PETSC_SUCCESS);
581 if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2H
582 PetscCallCXX(exec.fence()); // make sure async kernels on src are finished, in case of unified memory as on AMD MI300A.
583 PetscCall(PetscMemcpy(dst, src, n));
584 } else {
585 if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
586 deviceBuffer_t dbuf(static_cast<char *>(dst), n);
587 HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
588 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
589 PetscCall(PetscLogCpuToGpu(n));
590 } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
591 HostBuffer_t dbuf(static_cast<char *>(dst), n);
592 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
593 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
594 PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
595 PetscCall(PetscLogGpuToCpu(n));
596 } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2D
597 deviceBuffer_t dbuf(static_cast<char *>(dst), n);
598 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
599 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
600 }
601 }
602 PetscFunctionReturn(PETSC_SUCCESS);
603 }
604
PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void ** ptr)605 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
606 {
607 PetscFunctionBegin;
608 if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609 else if (PetscMemTypeDevice(mtype)) {
610 if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
611 PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DefaultMemorySpace>(size));
612 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613 PetscFunctionReturn(PETSC_SUCCESS);
614 }
615
PetscSFFree_Kokkos(PetscMemType mtype,void * ptr)616 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
617 {
618 PetscFunctionBegin;
619 if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
620 else if (PetscMemTypeDevice(mtype)) {
621 PetscCallCXX(Kokkos::kokkos_free<DefaultMemorySpace>(ptr));
622 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623 PetscFunctionReturn(PETSC_SUCCESS);
624 }
625
626 /* Destructor when the link uses MPI for communication */
PetscSFLinkDestroy_Kokkos(PetscSF sf,PetscSFLink link)627 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
628 {
629 PetscFunctionBegin;
630 for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
631 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633 }
634 PetscFunctionReturn(PETSC_SUCCESS);
635 }
636
/* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on the device */
PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf,PetscSFLink link,MPI_Datatype unit)638 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
639 {
640 PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
641 PetscBool is2Int, is2PetscInt;
642 #if defined(PETSC_HAVE_COMPLEX)
643 PetscInt nPetscComplex = 0;
644 #endif
645
646 PetscFunctionBegin;
647 if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
648 PetscCall(PetscKokkosInitializeCheck());
649 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
650 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
651 /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
652 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
653 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
654 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
655 #if defined(PETSC_HAVE_COMPLEX)
656 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
657 #endif
658 PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
659 PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
660
661 if (is2Int) {
662 PackInit_PairType<Kokkos::pair<int, int>>(link);
663 } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
664 PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
665 } else if (nPetscReal) {
666 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
667 if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
668 else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
669 else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
670 else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
671 else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
672 else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
673 else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
674 else if (nPetscReal % 1 == 0)
675 #endif
676 PackInit_RealType<PetscReal, 1, 0>(link);
677 } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
678 #if !defined(PETSC_HAVE_DEVICE)
679 if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
680 else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
681 else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
682 else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
683 else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
684 else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
685 else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
686 else if (nPetscInt % 1 == 0)
687 #endif
688 PackInit_IntegerType<llint, 1, 0>(link);
689 } else if (nInt) {
690 #if !defined(PETSC_HAVE_DEVICE)
691 if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
692 else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
693 else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
694 else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
695 else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
696 else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
697 else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
698 else if (nInt % 1 == 0)
699 #endif
700 PackInit_IntegerType<int, 1, 0>(link);
701 } else if (nSignedChar) {
702 #if !defined(PETSC_HAVE_DEVICE)
703 if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
704 else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
705 else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
706 else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
707 else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
708 else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
709 else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
710 else if (nSignedChar % 1 == 0)
711 #endif
712 PackInit_IntegerType<char, 1, 0>(link);
713 } else if (nUnsignedChar) {
714 #if !defined(PETSC_HAVE_DEVICE)
715 if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
716 else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
717 else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
718 else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
719 else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
720 else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
721 else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
722 else if (nUnsignedChar % 1 == 0)
723 #endif
724 PackInit_IntegerType<unsigned char, 1, 0>(link);
725 #if defined(PETSC_HAVE_COMPLEX)
726 } else if (nPetscComplex) {
727 #if !defined(PETSC_HAVE_DEVICE)
728 if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
729 else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
730 else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
731 else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
732 else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
733 else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
734 else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
735 else if (nPetscComplex % 1 == 0)
736 #endif
737 PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
738 #endif
739 } else {
740 MPI_Aint nbyte;
741
742 PetscCall(PetscSFGetDatatypeSize_Internal(PETSC_COMM_SELF, unit, &nbyte));
743 if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
744 #if !defined(PETSC_HAVE_DEVICE)
745 if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
746 else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
747 else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
748 else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
749 else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
750 else if (nbyte % 1 == 0)
751 #endif
752 PackInit_DumbType<char, 1, 0>(link);
753 } else {
754 PetscCall(PetscIntCast(nbyte / sizeof(int), &nInt));
755 #if !defined(PETSC_HAVE_DEVICE)
756 if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
757 else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
758 else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
759 else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
760 else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
761 else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
762 else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
763 else if (nInt % 1 == 0)
764 #endif
765 PackInit_DumbType<int, 1, 0>(link);
766 }
767 }
768
769 link->SyncDevice = PetscSFLinkSyncDevice_Kokkos;
770 link->SyncStream = PetscSFLinkSyncStream_Kokkos;
771 link->Memcpy = PetscSFLinkMemcpy_Kokkos;
772 link->Destroy = PetscSFLinkDestroy_Kokkos;
773 link->deviceinited = PETSC_TRUE;
774 PetscFunctionReturn(PETSC_SUCCESS);
775 }
776