1 #include <../src/vec/is/sf/impls/basic/sfpack.h> 2 3 #include <Kokkos_Core.hpp> 4 5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace; 6 using DeviceMemorySpace = typename DeviceExecutionSpace::memory_space; 7 using HostMemorySpace = Kokkos::HostSpace; 8 9 typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t; 10 typedef Kokkos::View<char *, HostMemorySpace> HostBuffer_t; 11 12 typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t; 13 typedef Kokkos::View<const char *, HostMemorySpace> HostConstBuffer_t; 14 15 /*====================================================================================*/ 16 /* Regular operations */ 17 /*====================================================================================*/ 18 template <typename Type> 19 struct Insert { 20 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 21 Type old = x; 22 x = y; 23 return old; 24 } 25 }; 26 template <typename Type> 27 struct Add { 28 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 29 Type old = x; 30 x += y; 31 return old; 32 } 33 }; 34 template <typename Type> 35 struct Mult { 36 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 37 Type old = x; 38 x *= y; 39 return old; 40 } 41 }; 42 template <typename Type> 43 struct Min { 44 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 45 Type old = x; 46 x = PetscMin(x, y); 47 return old; 48 } 49 }; 50 template <typename Type> 51 struct Max { 52 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 53 Type old = x; 54 x = PetscMax(x, y); 55 return old; 56 } 57 }; 58 template <typename Type> 59 struct LAND { 60 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 61 Type old = x; 62 x = x && y; 63 return old; 64 } 65 }; 66 template <typename Type> 67 struct LOR { 68 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 69 Type old = x; 70 x = x || y; 71 return old; 72 } 73 }; 74 template <typename Type> 75 struct LXOR { 76 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 77 Type old = x; 78 x = !x != !y; 79 return old; 80 } 81 }; 82 template <typename Type> 83 struct BAND { 84 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 85 Type old = x; 86 x = x & y; 87 return old; 88 } 89 }; 90 template <typename Type> 91 struct BOR { 92 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 93 Type old = x; 94 x = x | y; 95 return old; 96 } 97 }; 98 template <typename Type> 99 struct BXOR { 100 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 101 Type old = x; 102 x = x ^ y; 103 return old; 104 } 105 }; 106 template <typename PairType> 107 struct Minloc { 108 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const { 109 PairType old = x; 110 if (y.first < x.first) x = y; 111 else if (y.first == x.first) x.second = PetscMin(x.second, y.second); 112 return old; 113 } 114 }; 115 template <typename PairType> 116 struct Maxloc { 117 KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const { 118 PairType old = x; 119 if (y.first > x.first) x = y; 120 else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */ 121 return old; 122 } 123 }; 124 125 /*====================================================================================*/ 126 /* Atomic operations */ 127 /*====================================================================================*/ 128 template <typename Type> 129 struct AtomicInsert { 130 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); } 131 }; 132 template <typename Type> 133 struct AtomicAdd { 134 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); } 135 }; 136 template <typename Type> 137 struct AtomicBAND { 138 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); } 139 }; 140 template <typename Type> 141 struct AtomicBOR { 142 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); } 143 }; 144 template <typename Type> 145 struct AtomicBXOR { 146 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); } 147 }; 148 template <typename Type> 149 struct AtomicLAND { 150 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { 151 const Type zero = 0, one = ~0; 152 Kokkos::atomic_and(&x, y ? one : zero); 153 } 154 }; 155 template <typename Type> 156 struct AtomicLOR { 157 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { 158 const Type zero = 0, one = 1; 159 Kokkos::atomic_or(&x, y ? one : zero); 160 } 161 }; 162 template <typename Type> 163 struct AtomicMult { 164 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); } 165 }; 166 template <typename Type> 167 struct AtomicMin { 168 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); } 169 }; 170 template <typename Type> 171 struct AtomicMax { 172 KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); } 173 }; 174 /* TODO: struct AtomicLXOR */ 175 template <typename Type> 176 struct AtomicFetchAdd { 177 KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); } 178 }; 179 180 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */ 181 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) { 182 PetscInt i, j, k, m, n, r; 183 const PetscInt *offset, *start, *dx, *dy, *X, *Y; 184 185 n = opt[0]; 186 offset = opt + 1; 187 start = opt + n + 2; 188 dx = opt + 2 * n + 2; 189 dy = opt + 3 * n + 2; 190 X = opt + 5 * n + 2; 191 Y = opt + 6 * n + 2; 192 for (r = 0; r < n; r++) { 193 if (tid < offset[r + 1]) break; 194 } 195 m = (tid - offset[r]); 196 k = m / (dx[r] * dy[r]); 197 j = (m - k * dx[r] * dy[r]) / dx[r]; 198 i = m - k * dx[r] * dy[r] - j * dx[r]; 199 200 return (start[r] + k * X[r] * Y[r] + j * X[r] + i); 201 } 202 203 /*====================================================================================*/ 204 /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */ 205 /*====================================================================================*/ 206 207 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then 208 <Type> is PetscReal, which is the primitive type we operate on. 209 <bs> is 16, which says <unit> contains 16 primitive types. 210 <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 211 <EQ> is 0, which is (bs == BS ? 1 : 0) 212 213 If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 214 For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 215 */ 216 template <typename Type, PetscInt BS, PetscInt EQ> 217 static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_) { 218 const PetscInt *iopt = opt ? opt->array : NULL; 219 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */ 220 const Type *data = static_cast<const Type *>(data_); 221 Type *buf = static_cast<Type *>(buf_); 222 DeviceExecutionSpace exec; 223 224 PetscFunctionBegin; 225 Kokkos::parallel_for( 226 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 227 /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous; 228 iopt == NULL && idx == NULL ==> the indices are contiguous; 229 */ 230 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 231 PetscInt s = tid * MBS; 232 for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i]; 233 }); 234 PetscFunctionReturn(0); 235 } 236 237 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 238 static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_) { 239 Op op; 240 const PetscInt *iopt = opt ? opt->array : NULL; 241 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 242 Type *data = static_cast<Type *>(data_); 243 const Type *buf = static_cast<const Type *>(buf_); 244 DeviceExecutionSpace exec; 245 246 PetscFunctionBegin; 247 Kokkos::parallel_for( 248 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 249 PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 250 PetscInt s = tid * MBS; 251 for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]); 252 }); 253 PetscFunctionReturn(0); 254 } 255 256 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 257 static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) { 258 Op op; 259 const PetscInt *ropt = opt ? opt->array : NULL; 260 const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 261 Type *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf); 262 DeviceExecutionSpace exec; 263 264 PetscFunctionBegin; 265 Kokkos::parallel_for( 266 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 267 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 268 PetscInt l = tid * MBS; 269 for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]); 270 }); 271 PetscFunctionReturn(0); 272 } 273 274 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 275 static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) { 276 PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0; 277 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 278 const Type *src = static_cast<const Type *>(src_); 279 Type *dst = static_cast<Type *>(dst_); 280 DeviceExecutionSpace exec; 281 282 PetscFunctionBegin; 283 /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */ 284 if (srcOpt) { 285 srcx = srcOpt->dx[0]; 286 srcy = srcOpt->dy[0]; 287 srcX = srcOpt->X[0]; 288 srcY = srcOpt->Y[0]; 289 srcStart = srcOpt->start[0]; 290 srcIdx = NULL; 291 } else if (!srcIdx) { 292 srcx = srcX = count; 293 srcy = srcY = 1; 294 } 295 296 if (dstOpt) { 297 dstx = dstOpt->dx[0]; 298 dsty = dstOpt->dy[0]; 299 dstX = dstOpt->X[0]; 300 dstY = dstOpt->Y[0]; 301 dstStart = dstOpt->start[0]; 302 dstIdx = NULL; 303 } else if (!dstIdx) { 304 dstx = dstX = count; 305 dsty = dstY = 1; 306 } 307 308 Kokkos::parallel_for( 309 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 310 PetscInt i, j, k, s, t; 311 Op op; 312 if (!srcIdx) { /* src is in 3D */ 313 k = tid / (srcx * srcy); 314 j = (tid - k * srcx * srcy) / srcx; 315 i = tid - k * srcx * srcy - j * srcx; 316 s = srcStart + k * srcX * srcY + j * srcX + i; 317 } else { /* src is contiguous */ 318 s = srcIdx[tid]; 319 } 320 321 if (!dstIdx) { /* 3D */ 322 k = tid / (dstx * dsty); 323 j = (tid - k * dstx * dsty) / dstx; 324 i = tid - k * dstx * dsty - j * dstx; 325 t = dstStart + k * dstX * dstY + j * dstX + i; 326 } else { /* contiguous */ 327 t = dstIdx[tid]; 328 } 329 330 s *= MBS; 331 t *= MBS; 332 for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]); 333 }); 334 PetscFunctionReturn(0); 335 } 336 337 /* Specialization for Insert since we may use memcpy */ 338 template <typename Type, PetscInt BS, PetscInt EQ> 339 static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) { 340 const Type *src = static_cast<const Type *>(src_); 341 Type *dst = static_cast<Type *>(dst_); 342 DeviceExecutionSpace exec; 343 344 PetscFunctionBegin; 345 if (!count) PetscFunctionReturn(0); 346 /*src and dst are contiguous */ 347 if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 348 size_t sz = count * link->unitbytes; 349 deviceBuffer_t dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz); 350 deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz); 351 Kokkos::deep_copy(exec, dbuf, sbuf); 352 } else { 353 PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst)); 354 } 355 PetscFunctionReturn(0); 356 } 357 358 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 359 static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_) { 360 Op op; 361 const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 362 const PetscInt *ropt = rootopt ? rootopt->array : NULL; 363 const PetscInt *lopt = leafopt ? leafopt->array : NULL; 364 Type *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_); 365 const Type *leafdata = static_cast<const Type *>(leafdata_); 366 DeviceExecutionSpace exec; 367 368 PetscFunctionBegin; 369 Kokkos::parallel_for( 370 Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 371 PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS; 372 PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS; 373 for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]); 374 }); 375 PetscFunctionReturn(0); 376 } 377 378 /*====================================================================================*/ 379 /* Init various types and instantiate pack/unpack function pointers */ 380 /*====================================================================================*/ 381 template <typename Type, PetscInt BS, PetscInt EQ> 382 static void PackInit_RealType(PetscSFLink link) { 383 /* Pack/unpack for remote communication */ 384 link->d_Pack = Pack<Type, BS, EQ>; 385 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 386 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 387 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 388 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 389 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 390 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 391 /* Scatter for local communication */ 392 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */ 393 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 394 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 395 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 396 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 397 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 398 /* Atomic versions when there are data-race possibilities */ 399 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 400 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 401 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 402 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 403 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 404 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 405 406 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 407 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 408 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 409 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 410 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 411 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 412 } 413 414 template <typename Type, PetscInt BS, PetscInt EQ> 415 static void PackInit_IntegerType(PetscSFLink link) { 416 link->d_Pack = Pack<Type, BS, EQ>; 417 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 418 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 419 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 420 link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 421 link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 422 link->d_UnpackAndLAND = UnpackAndOp<Type, LAND<Type>, BS, EQ>; 423 link->d_UnpackAndLOR = UnpackAndOp<Type, LOR<Type>, BS, EQ>; 424 link->d_UnpackAndLXOR = UnpackAndOp<Type, LXOR<Type>, BS, EQ>; 425 link->d_UnpackAndBAND = UnpackAndOp<Type, BAND<Type>, BS, EQ>; 426 link->d_UnpackAndBOR = UnpackAndOp<Type, BOR<Type>, BS, EQ>; 427 link->d_UnpackAndBXOR = UnpackAndOp<Type, BXOR<Type>, BS, EQ>; 428 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 429 430 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 431 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 432 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 433 link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 434 link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 435 link->d_ScatterAndLAND = ScatterAndOp<Type, LAND<Type>, BS, EQ>; 436 link->d_ScatterAndLOR = ScatterAndOp<Type, LOR<Type>, BS, EQ>; 437 link->d_ScatterAndLXOR = ScatterAndOp<Type, LXOR<Type>, BS, EQ>; 438 link->d_ScatterAndBAND = ScatterAndOp<Type, BAND<Type>, BS, EQ>; 439 link->d_ScatterAndBOR = ScatterAndOp<Type, BOR<Type>, BS, EQ>; 440 link->d_ScatterAndBXOR = ScatterAndOp<Type, BXOR<Type>, BS, EQ>; 441 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 442 443 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 444 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 445 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 446 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 447 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 448 link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>; 449 link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>; 450 link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>; 451 link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>; 452 link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 453 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 454 455 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 456 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 457 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 458 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 459 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 460 link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>; 461 link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>; 462 link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>; 463 link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>; 464 link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 465 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 466 } 467 468 #if defined(PETSC_HAVE_COMPLEX) 469 template <typename Type, PetscInt BS, PetscInt EQ> 470 static void PackInit_ComplexType(PetscSFLink link) { 471 link->d_Pack = Pack<Type, BS, EQ>; 472 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 473 link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 474 link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 475 link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 476 477 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 478 link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 479 link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 480 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 481 482 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 483 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 484 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 485 link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 486 487 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 488 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 489 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 490 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 491 } 492 #endif 493 494 template <typename Type> 495 static void PackInit_PairType(PetscSFLink link) { 496 link->d_Pack = Pack<Type, 1, 1>; 497 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>; 498 link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>; 499 link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>; 500 501 link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>; 502 link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>; 503 link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>; 504 /* Atomics for pair types are not implemented yet */ 505 } 506 507 template <typename Type, PetscInt BS, PetscInt EQ> 508 static void PackInit_DumbType(PetscSFLink link) { 509 link->d_Pack = Pack<Type, BS, EQ>; 510 link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 511 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 512 /* Atomics for dumb types are not implemented yet */ 513 } 514 515 /* 516 Kokkos::DefaultExecutionSpace(stream) is a reference counted pointer object. It has a bug 517 that one is not able to repeatedly create and destroy the object. SF's original design was each 518 SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from 519 destroying multiple SFLinks with NULL stream and the default execution space object. To avoid 520 memory leaks, SF_Kokkos only supports NULL stream, which is also petsc's default scheme. SF_Kokkos 521 does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singliton 522 object in Kokkos. 523 */ 524 /* 525 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link) 526 { 527 PetscFunctionBegin; 528 PetscFunctionReturn(0); 529 } 530 */ 531 532 /* Some device-specific utilities */ 533 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link) { 534 PetscFunctionBegin; 535 Kokkos::fence(); 536 PetscFunctionReturn(0); 537 } 538 539 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link) { 540 DeviceExecutionSpace exec; 541 PetscFunctionBegin; 542 exec.fence(); 543 PetscFunctionReturn(0); 544 } 545 546 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) { 547 DeviceExecutionSpace exec; 548 549 PetscFunctionBegin; 550 if (!n) PetscFunctionReturn(0); 551 if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { 552 PetscCall(PetscMemcpy(dst, src, n)); 553 } else { 554 if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { 555 deviceBuffer_t dbuf(static_cast<char *>(dst), n); 556 HostConstBuffer_t sbuf(static_cast<const char *>(src), n); 557 Kokkos::deep_copy(exec, dbuf, sbuf); 558 PetscCall(PetscLogCpuToGpu(n)); 559 } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { 560 HostBuffer_t dbuf(static_cast<char *>(dst), n); 561 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 562 Kokkos::deep_copy(exec, dbuf, sbuf); 563 PetscCall(PetscLogGpuToCpu(n)); 564 } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { 565 deviceBuffer_t dbuf(static_cast<char *>(dst), n); 566 deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 567 Kokkos::deep_copy(exec, dbuf, sbuf); 568 } 569 } 570 PetscFunctionReturn(0); 571 } 572 573 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr) { 574 PetscFunctionBegin; 575 if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr)); 576 else if (PetscMemTypeDevice(mtype)) { 577 if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck()); 578 *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size); 579 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 580 PetscFunctionReturn(0); 581 } 582 583 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr) { 584 PetscFunctionBegin; 585 if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr)); 586 else if (PetscMemTypeDevice(mtype)) { 587 Kokkos::kokkos_free<DeviceMemorySpace>(ptr); 588 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 589 PetscFunctionReturn(0); 590 } 591 592 /* Destructor when the link uses MPI for communication */ 593 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link) { 594 PetscFunctionBegin; 595 for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) { 596 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 597 PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 598 } 599 PetscFunctionReturn(0); 600 } 601 602 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */ 603 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit) { 604 PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0; 605 PetscBool is2Int, is2PetscInt; 606 #if defined(PETSC_HAVE_COMPLEX) 607 PetscInt nPetscComplex = 0; 608 #endif 609 610 PetscFunctionBegin; 611 if (link->deviceinited) PetscFunctionReturn(0); 612 PetscCall(PetscKokkosInitializeCheck()); 613 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar)); 614 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar)); 615 /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 616 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt)); 617 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt)); 618 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal)); 619 #if defined(PETSC_HAVE_COMPLEX) 620 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex)); 621 #endif 622 PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int)); 623 PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt)); 624 625 if (is2Int) { 626 PackInit_PairType<Kokkos::pair<int, int>>(link); 627 } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */ 628 PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link); 629 } else if (nPetscReal) { 630 #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */ 631 if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link); 632 else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link); 633 else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link); 634 else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link); 635 else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link); 636 else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link); 637 else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link); 638 else if (nPetscReal % 1 == 0) 639 #endif 640 PackInit_RealType<PetscReal, 1, 0>(link); 641 } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) { 642 #if !defined(PETSC_HAVE_DEVICE) 643 if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link); 644 else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link); 645 else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link); 646 else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link); 647 else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link); 648 else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link); 649 else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link); 650 else if (nPetscInt % 1 == 0) 651 #endif 652 PackInit_IntegerType<llint, 1, 0>(link); 653 } else if (nInt) { 654 #if !defined(PETSC_HAVE_DEVICE) 655 if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link); 656 else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link); 657 else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link); 658 else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link); 659 else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link); 660 else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link); 661 else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link); 662 else if (nInt % 1 == 0) 663 #endif 664 PackInit_IntegerType<int, 1, 0>(link); 665 } else if (nSignedChar) { 666 #if !defined(PETSC_HAVE_DEVICE) 667 if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link); 668 else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link); 669 else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link); 670 else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link); 671 else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link); 672 else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link); 673 else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link); 674 else if (nSignedChar % 1 == 0) 675 #endif 676 PackInit_IntegerType<char, 1, 0>(link); 677 } else if (nUnsignedChar) { 678 #if !defined(PETSC_HAVE_DEVICE) 679 if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link); 680 else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link); 681 else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link); 682 else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link); 683 else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link); 684 else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link); 685 else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link); 686 else if (nUnsignedChar % 1 == 0) 687 #endif 688 PackInit_IntegerType<unsigned char, 1, 0>(link); 689 #if defined(PETSC_HAVE_COMPLEX) 690 } else if (nPetscComplex) { 691 #if !defined(PETSC_HAVE_DEVICE) 692 if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link); 693 else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link); 694 else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link); 695 else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link); 696 else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link); 697 else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link); 698 else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link); 699 else if (nPetscComplex % 1 == 0) 700 #endif 701 PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link); 702 #endif 703 } else { 704 MPI_Aint lb, nbyte; 705 PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte)); 706 PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb); 707 if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 708 #if !defined(PETSC_HAVE_DEVICE) 709 if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link); 710 else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link); 711 else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link); 712 else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link); 713 else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link); 714 else if (nbyte % 1 == 0) 715 #endif 716 PackInit_DumbType<char, 1, 0>(link); 717 } else { 718 nInt = nbyte / sizeof(int); 719 #if !defined(PETSC_HAVE_DEVICE) 720 if (nInt == 8) PackInit_DumbType<int, 8, 1>(link); 721 else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link); 722 else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link); 723 else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link); 724 else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link); 725 else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link); 726 else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link); 727 else if (nInt % 1 == 0) 728 #endif 729 PackInit_DumbType<int, 1, 0>(link); 730 } 731 } 732 733 link->SyncDevice = PetscSFLinkSyncDevice_Kokkos; 734 link->SyncStream = PetscSFLinkSyncStream_Kokkos; 735 link->Memcpy = PetscSFLinkMemcpy_Kokkos; 736 link->Destroy = PetscSFLinkDestroy_Kokkos; 737 link->deviceinited = PETSC_TRUE; 738 PetscFunctionReturn(0); 739 } 740