1 #include <../src/vec/is/sf/impls/basic/sfpack.h> 2 3 #include <petsc_kokkos.hpp> 4 5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace; 6 using DeviceMemorySpace = typename DeviceExecutionSpace::memory_space; 7 using HostMemorySpace = Kokkos::HostSpace; 8 9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t; 10 typedef Kokkos::View<char *, HostMemorySpace> HostBuffer_t; 11 12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t; 13 typedef Kokkos::View<const char *, HostMemorySpace> HostConstBuffer_t; 14 15 /*====================================================================================*/ 16 /* Regular operations */ 17 /*====================================================================================*/ 18 template <typename Type> 19 struct Insert { 20 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 21 { 22 Type old = x; 23 x = y; 24 return old; 25 } 26 }; 27 template <typename Type> 28 struct Add { 29 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 30 { 31 Type old = x; 32 x += y; 33 return old; 34 } 35 }; 36 template <typename Type> 37 struct Mult { 38 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 39 { 40 Type old = x; 41 x *= y; 42 return old; 43 } 44 }; 45 template <typename Type> 46 struct Min { 47 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 48 { 49 Type old = x; 50 x = PetscMin(x, y); 51 return old; 52 } 53 }; 54 template <typename Type> 55 struct Max { 56 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 57 { 58 Type old = x; 59 x = PetscMax(x, y); 60 return old; 61 } 62 }; 63 template <typename Type> 64 struct LAND { 65 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 66 { 67 Type old = x; 68 x = x && y; 69 return old; 70 } 71 }; 72 template <typename Type> 73 struct LOR { 74 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 75 { 76 Type old = x; 77 x = x || y; 78 return old; 79 } 80 }; 81 template <typename Type> 82 struct LXOR { 83 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 84 { 85 Type old = x; 86 x = !x != !y; 87 return old; 88 } 89 }; 90 template <typename Type> 91 struct BAND { 92 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 93 { 94 Type old = x; 95 x = x & y; 96 return old; 97 } 98 }; 99 template <typename Type> 100 struct BOR { 101 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 102 { 103 Type old = x; 104 x = x | y; 105 return old; 106 } 107 }; 108 template <typename Type> 109 struct BXOR { 110 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const 111 { 112 Type old = x; 113 x = x ^ y; 114 return old; 115 } 116 }; 117 template <typename PairType> 118 struct Minloc { 119 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const 120 { 121 PairType old = x; 122 if (y.first < x.first) x = y; 123 else if (y.first == x.first) x.second = PetscMin(x.second, y.second); 124 return old; 125 } 126 }; 127 template <typename PairType> 128 struct Maxloc { 129 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const 130 { 131 PairType old = x; 132 if (y.first > x.first) x = y; 133 else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */ 134 return old; 135 } 136 }; 137 138 /*====================================================================================*/ 139 /* Atomic operations */ 140 /*====================================================================================*/ 141 template <typename Type> 142 struct AtomicInsert { 143 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); } 144 }; 145 template <typename Type> 146 struct AtomicAdd { 147 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); } 148 }; 149 template <typename Type> 150 struct AtomicBAND { 151 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); } 152 }; 153 template <typename Type> 154 struct AtomicBOR { 155 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); } 156 }; 157 template <typename Type> 158 struct AtomicBXOR { 159 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); } 160 }; 161 template <typename Type> 162 struct AtomicLAND { 163 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const 164 { 165 const Type zero = 0, one = ~0; 166 Kokkos::atomic_and(&x, y ? one : zero); 167 } 168 }; 169 template <typename Type> 170 struct AtomicLOR { 171 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const 172 { 173 const Type zero = 0, one = 1; 174 Kokkos::atomic_or(&x, y ? one : zero); 175 } 176 }; 177 template <typename Type> 178 struct AtomicMult { 179 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); } 180 }; 181 template <typename Type> 182 struct AtomicMin { 183 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); } 184 }; 185 template <typename Type> 186 struct AtomicMax { 187 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); } 188 }; 189 /* TODO: struct AtomicLXOR */ 190 template <typename Type> 191 struct AtomicFetchAdd { 192 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); } 193 }; 194 195 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */ 196 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) 197 { 198 PetscInt i, j, k, m, n, r; 199 const PetscInt *offset, *start, *dx, *dy, *X, *Y; 200 201 n = opt[0]; 202 offset = opt + 1; 203 start = opt + n + 2; 204 dx = opt + 2 * n + 2; 205 dy = opt + 3 * n + 2; 206 X = opt + 5 * n + 2; 207 Y = opt + 6 * n + 2; 208 for (r = 0; r < n; r++) { 209 if (tid < offset[r + 1]) break; 210 } 211 m = (tid - offset[r]); 212 k = m / (dx[r] * dy[r]); 213 j = (m - k * dx[r] * dy[r]) / dx[r]; 214 i = m - k * dx[r] * dy[r] - j * dx[r]; 215 216 return start[r] + k * X[r] * Y[r] + j * X[r] + i; 217 } 218 219 /*====================================================================================*/ 220 /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */ 221 /*====================================================================================*/ 222 223 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then 224 <Type> is PetscReal, which is the primitive type we operate on. 225 <bs> is 16, which says <unit> contains 16 primitive types. 226 <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 227 <EQ> is 0, which is (bs == BS ? 1 : 0) 228 229 If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 230 For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 231 */ 232 template <typename Type, PetscInt BS, PetscInt EQ> 233 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_) 234 { 235 const PetscInt *iopt = opt ? opt->array : NULL; 236 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */ 237 const Type *data = static_cast<const Type *>(data_); 238 Type *buf = static_cast<Type *>(buf_); 239 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 240 241 PetscFunctionBegin; 242 Kokkos::parallel_for( 243 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 244 /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous; 245 iopt == NULL && idx == NULL ==> the indices are contiguous; 246 */ 247 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 248 PetscInt s = tid * MBS; 249 for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i]; 250 }); 251 PetscFunctionReturn(PETSC_SUCCESS); 252 } 253 254 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 255 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_) 256 { 257 Op op; 258 const PetscInt *iopt = opt ? opt->array : NULL; 259 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 260 Type *data = static_cast<Type *>(data_); 261 const Type *buf = static_cast<const Type *>(buf_); 262 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 263 264 PetscFunctionBegin; 265 Kokkos::parallel_for( 266 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 267 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 268 PetscInt s = tid * MBS; 269 for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]); 270 }); 271 PetscFunctionReturn(PETSC_SUCCESS); 272 } 273 274 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 275 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) 276 { 277 Op op; 278 const PetscInt *ropt = opt ? opt->array : NULL; 279 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 280 Type *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf); 281 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 282 283 PetscFunctionBegin; 284 Kokkos::parallel_for( 285 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 286 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 287 PetscInt l = tid * MBS; 288 for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]); 289 }); 290 PetscFunctionReturn(PETSC_SUCCESS); 291 } 292 293 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 294 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) 295 { 296 PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0; 297 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 298 const Type *src = static_cast<const Type *>(src_); 299 Type *dst = static_cast<Type *>(dst_); 300 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 301 302 PetscFunctionBegin; 303 /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */ 304 if (srcOpt) { 305 srcx = srcOpt->dx[0]; 306 srcy = srcOpt->dy[0]; 307 srcX = srcOpt->X[0]; 308 srcY = srcOpt->Y[0]; 309 srcStart = srcOpt->start[0]; 310 srcIdx = NULL; 311 } else if (!srcIdx) { 312 srcx = srcX = count; 313 srcy = srcY = 1; 314 } 315 316 if (dstOpt) { 317 dstx = dstOpt->dx[0]; 318 dsty = dstOpt->dy[0]; 319 dstX = dstOpt->X[0]; 320 dstY = dstOpt->Y[0]; 321 dstStart = dstOpt->start[0]; 322 dstIdx = NULL; 323 } else if (!dstIdx) { 324 dstx = dstX = count; 325 dsty = dstY = 1; 326 } 327 328 Kokkos::parallel_for( 329 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 330 PetscInt i, j, k, s, t; 331 Op op; 332 if (!srcIdx) { /* src is in 3D */ 333 k = tid / (srcx * srcy); 334 j = (tid - k * srcx * srcy) / srcx; 335 i = tid - k * srcx * srcy - j * srcx; 336 s = srcStart + k * srcX * srcY + j * srcX + i; 337 } else { /* src is contiguous */ 338 s = srcIdx[tid]; 339 } 340 341 if (!dstIdx) { /* 3D */ 342 k = tid / (dstx * dsty); 343 j = (tid - k * dstx * dsty) / dstx; 344 i = tid - k * dstx * dsty - j * dstx; 345 t = dstStart + k * dstX * dstY + j * dstX + i; 346 } else { /* contiguous */ 347 t = dstIdx[tid]; 348 } 349 350 s *= MBS; 351 t *= MBS; 352 for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]); 353 }); 354 PetscFunctionReturn(PETSC_SUCCESS); 355 } 356 357 /* Specialization for Insert since we may use memcpy */ 358 template <typename Type, PetscInt BS, PetscInt EQ> 359 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) 360 { 361 const Type *src = static_cast<const Type *>(src_); 362 Type *dst = static_cast<Type *>(dst_); 363 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 364 365 PetscFunctionBegin; 366 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 367 /*src and dst are contiguous */ 368 if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 369 size_t sz = count * link->unitbytes; 370 deviceBuffer_t dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz); 371 deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz); 372 Kokkos::deep_copy(exec, dbuf, sbuf); 373 } else { 374 PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst)); 375 } 376 PetscFunctionReturn(PETSC_SUCCESS); 377 } 378 379 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 380 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_) 381 { 382 Op op; 383 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 384 const PetscInt *ropt = rootopt ? rootopt->array : NULL; 385 const PetscInt *lopt = leafopt ? leafopt->array : NULL; 386 Type *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_); 387 const Type *leafdata = static_cast<const Type *>(leafdata_); 388 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 389 390 PetscFunctionBegin; 391 Kokkos::parallel_for( 392 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 393 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS; 394 PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS; 395 for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]); 396 }); 397 PetscFunctionReturn(PETSC_SUCCESS); 398 } 399 400 /*====================================================================================*/ 401 /* Init various types and instantiate pack/unpack function pointers */ 402 /*====================================================================================*/ 403 template <typename Type, PetscInt BS, PetscInt EQ> 404 static void PackInit_RealType(PetscSFLink link) 405 { 406 /* Pack/unpack for remote communication */ 407 link->d_Pack = Pack<Type, BS, EQ>; 408 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 409 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 410 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 411 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 412 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 413 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 414 /* Scatter for local communication */ 415 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */ 416 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 417 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 418 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 419 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 420 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 421 /* Atomic versions when there are data-race possibilities */ 422 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 423 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 424 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 425 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 426 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 427 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 428 429 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 430 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 431 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 432 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 433 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 434 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 435 } 436 437 template <typename Type, PetscInt BS, PetscInt EQ> 438 static void PackInit_IntegerType(PetscSFLink link) 439 { 440 link->d_Pack = Pack<Type, BS, EQ>; 441 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 442 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 443 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 444 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 445 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 446 link->d_UnpackAndLAND = UnpackAndOp<Type, LAND<Type>, BS, EQ>; 447 link->d_UnpackAndLOR = UnpackAndOp<Type, LOR<Type>, BS, EQ>; 448 link->d_UnpackAndLXOR = UnpackAndOp<Type, LXOR<Type>, BS, EQ>; 449 link->d_UnpackAndBAND = UnpackAndOp<Type, BAND<Type>, BS, EQ>; 450 link->d_UnpackAndBOR = UnpackAndOp<Type, BOR<Type>, BS, EQ>; 451 link->d_UnpackAndBXOR = UnpackAndOp<Type, BXOR<Type>, BS, EQ>; 452 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 453 454 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 455 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 456 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 457 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 458 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 459 link->d_ScatterAndLAND = ScatterAndOp<Type, LAND<Type>, BS, EQ>; 460 link->d_ScatterAndLOR = ScatterAndOp<Type, LOR<Type>, BS, EQ>; 461 link->d_ScatterAndLXOR = ScatterAndOp<Type, LXOR<Type>, BS, EQ>; 462 link->d_ScatterAndBAND = ScatterAndOp<Type, BAND<Type>, BS, EQ>; 463 link->d_ScatterAndBOR = ScatterAndOp<Type, BOR<Type>, BS, EQ>; 464 link->d_ScatterAndBXOR = ScatterAndOp<Type, BXOR<Type>, BS, EQ>; 465 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 466 467 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 468 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 469 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 470 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 471 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 472 link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>; 473 link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>; 474 link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>; 475 link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>; 476 link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 477 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 478 479 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 480 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 481 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 482 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 483 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 484 link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>; 485 link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>; 486 link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>; 487 link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>; 488 link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 489 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 490 } 491 492 #if defined(PETSC_HAVE_COMPLEX) 493 template <typename Type, PetscInt BS, PetscInt EQ> 494 static void PackInit_ComplexType(PetscSFLink link) 495 { 496 link->d_Pack = Pack<Type, BS, EQ>; 497 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 498 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 499 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 500 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 501 502 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 503 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 504 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 505 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 506 507 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 508 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 509 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 510 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 511 512 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 513 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 514 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 515 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 516 } 517 #endif 518 519 template <typename Type> 520 static void PackInit_PairType(PetscSFLink link) 521 { 522 link->d_Pack = Pack<Type, 1, 1>; 523 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>; 524 link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>; 525 link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>; 526 527 link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>; 528 link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>; 529 link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>; 530 /* Atomics for pair types are not implemented yet */ 531 } 532 533 template <typename Type, PetscInt BS, PetscInt EQ> 534 static void PackInit_DumbType(PetscSFLink link) 535 { 536 link->d_Pack = Pack<Type, BS, EQ>; 537 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 538 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 539 /* Atomics for dumb types are not implemented yet */ 540 } 541 542 /* 543 Kokkos::DefaultExecutionSpace(stream) is a reference counted pointer object. It has a bug 544 that one is not able to repeatedly create and destroy the object. SF's original design was each 545 SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from 546 destroying multiple SFLinks with NULL stream and the default execution space object. To avoid 547 memory leaks, SF_Kokkos only supports NULL stream, which is also petsc's default scheme. SF_Kokkos 548 does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singliton 549 object in Kokkos. 550 */ 551 /* 552 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link) 553 { 554 PetscFunctionBegin; 555 PetscFunctionReturn(PETSC_SUCCESS); 556 } 557 */ 558 559 /* Some device-specific utilities */ 560 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link) 561 { 562 PetscFunctionBegin; 563 Kokkos::fence(); 564 PetscFunctionReturn(PETSC_SUCCESS); 565 } 566 567 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link) 568 { 569 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 570 571 PetscFunctionBegin; 572 exec.fence(); 573 PetscFunctionReturn(PETSC_SUCCESS); 574 } 575 576 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) 577 { 578 DeviceExecutionSpace &exec = PetscGetKokkosExecutionSpace(); 579 580 PetscFunctionBegin; 581 if (!n) PetscFunctionReturn(PETSC_SUCCESS); 582 if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { 583 PetscCall(PetscMemcpy(dst, src, n)); 584 } else { 585 if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D 586 deviceBuffer_t dbuf(static_cast<char *>(dst), n); 587 HostConstBuffer_t sbuf(static_cast<const char *>(src), n); 588 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf)); 589 PetscCall(PetscLogCpuToGpu(n)); 590 } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H 591 HostBuffer_t dbuf(static_cast<char *>(dst), n); 592 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 593 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf)); 594 PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host 595 PetscCall(PetscLogGpuToCpu(n)); 596 } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { 597 deviceBuffer_t dbuf(static_cast<char *>(dst), n); 598 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 599 PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf)); 600 } 601 } 602 PetscFunctionReturn(PETSC_SUCCESS); 603 } 604 605 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr) 606 { 607 PetscFunctionBegin; 608 if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr)); 609 else if (PetscMemTypeDevice(mtype)) { 610 if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck()); 611 PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size)); 612 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 613 PetscFunctionReturn(PETSC_SUCCESS); 614 } 615 616 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr) 617 { 618 PetscFunctionBegin; 619 if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr)); 620 else if (PetscMemTypeDevice(mtype)) { 621 PetscCallCXX(Kokkos::kokkos_free<DeviceMemorySpace>(ptr)); 622 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 623 PetscFunctionReturn(PETSC_SUCCESS); 624 } 625 626 /* Destructor when the link uses MPI for communication */ 627 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link) 628 { 629 PetscFunctionBegin; 630 for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) { 631 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 632 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 633 } 634 PetscFunctionReturn(PETSC_SUCCESS); 635 } 636 637 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */ 638 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit) 639 { 640 PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0; 641 PetscBool is2Int, is2PetscInt; 642 #if defined(PETSC_HAVE_COMPLEX) 643 PetscInt nPetscComplex = 0; 644 #endif 645 646 PetscFunctionBegin; 647 if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS); 648 PetscCall(PetscKokkosInitializeCheck()); 649 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar)); 650 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar)); 651 /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 652 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt)); 653 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt)); 654 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal)); 655 #if defined(PETSC_HAVE_COMPLEX) 656 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex)); 657 #endif 658 PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int)); 659 PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt)); 660 661 if (is2Int) { 662 PackInit_PairType<Kokkos::pair<int, int>>(link); 663 } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */ 664 PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link); 665 } else if (nPetscReal) { 666 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */ 667 if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link); 668 else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link); 669 else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link); 670 else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link); 671 else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link); 672 else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link); 673 else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link); 674 else if (nPetscReal % 1 == 0) 675 #endif 676 PackInit_RealType<PetscReal, 1, 0>(link); 677 } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) { 678 #if !defined(PETSC_HAVE_DEVICE) 679 if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link); 680 else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link); 681 else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link); 682 else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link); 683 else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link); 684 else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link); 685 else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link); 686 else if (nPetscInt % 1 == 0) 687 #endif 688 PackInit_IntegerType<llint, 1, 0>(link); 689 } else if (nInt) { 690 #if !defined(PETSC_HAVE_DEVICE) 691 if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link); 692 else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link); 693 else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link); 694 else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link); 695 else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link); 696 else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link); 697 else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link); 698 else if (nInt % 1 == 0) 699 #endif 700 PackInit_IntegerType<int, 1, 0>(link); 701 } else if (nSignedChar) { 702 #if !defined(PETSC_HAVE_DEVICE) 703 if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link); 704 else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link); 705 else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link); 706 else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link); 707 else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link); 708 else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link); 709 else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link); 710 else if (nSignedChar % 1 == 0) 711 #endif 712 PackInit_IntegerType<char, 1, 0>(link); 713 } else if (nUnsignedChar) { 714 #if !defined(PETSC_HAVE_DEVICE) 715 if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link); 716 else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link); 717 else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link); 718 else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link); 719 else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link); 720 else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link); 721 else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link); 722 else if (nUnsignedChar % 1 == 0) 723 #endif 724 PackInit_IntegerType<unsigned char, 1, 0>(link); 725 #if defined(PETSC_HAVE_COMPLEX) 726 } else if (nPetscComplex) { 727 #if !defined(PETSC_HAVE_DEVICE) 728 if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link); 729 else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link); 730 else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link); 731 else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link); 732 else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link); 733 else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link); 734 else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link); 735 else if (nPetscComplex % 1 == 0) 736 #endif 737 PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link); 738 #endif 739 } else { 740 MPI_Aint lb, nbyte; 741 PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte)); 742 PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb); 743 if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 744 #if !defined(PETSC_HAVE_DEVICE) 745 if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link); 746 else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link); 747 else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link); 748 else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link); 749 else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link); 750 else if (nbyte % 1 == 0) 751 #endif 752 PackInit_DumbType<char, 1, 0>(link); 753 } else { 754 nInt = nbyte / sizeof(int); 755 #if !defined(PETSC_HAVE_DEVICE) 756 if (nInt == 8) PackInit_DumbType<int, 8, 1>(link); 757 else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link); 758 else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link); 759 else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link); 760 else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link); 761 else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link); 762 else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link); 763 else if (nInt % 1 == 0) 764 #endif 765 PackInit_DumbType<int, 1, 0>(link); 766 } 767 } 768 769 link->SyncDevice = PetscSFLinkSyncDevice_Kokkos; 770 link->SyncStream = PetscSFLinkSyncStream_Kokkos; 771 link->Memcpy = PetscSFLinkMemcpy_Kokkos; 772 link->Destroy = PetscSFLinkDestroy_Kokkos; 773 link->deviceinited = PETSC_TRUE; 774 PetscFunctionReturn(PETSC_SUCCESS); 775 } 776