1 #pragma once 2 3 #include "sfcupm.hpp" 4 #include <../src/sys/objects/device/impls/cupm/kernels.hpp> 5 #include <petsc/private/cupmatomics.hpp> 6 7 namespace Petsc 8 { 9 10 namespace sf 11 { 12 13 namespace cupm 14 { 15 16 namespace kernels 17 { 18 19 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */ 20 PETSC_NODISCARD static PETSC_DEVICE_INLINE_DECL PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) noexcept 21 { 22 PetscInt i, j, k, m, n, r; 23 const PetscInt *offset, *start, *dx, *dy, *X, *Y; 24 25 n = opt[0]; 26 offset = opt + 1; 27 start = opt + n + 2; 28 dx = opt + 2 * n + 2; 29 dy = opt + 3 * n + 2; 30 X = opt + 5 * n + 2; 31 Y = opt + 6 * n + 2; 32 for (r = 0; r < n; r++) { 33 if (tid < offset[r + 1]) break; 34 } 35 m = (tid - offset[r]); 36 k = m / (dx[r] * dy[r]); 37 j = (m - k * dx[r] * dy[r]) / dx[r]; 38 i = m - k * dx[r] * dy[r] - j * dx[r]; 39 40 return start[r] + k * X[r] * Y[r] + j * X[r] + i; 41 } 42 43 /*====================================================================================*/ 44 /* Templated CUPM kernels for pack/unpack. The Op can be regular or atomic */ 45 /*====================================================================================*/ 46 47 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then 48 <Type> is PetscReal, which is the primitive type we operate on. 49 <bs> is 16, which says <unit> contains 16 primitive types. 50 <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 51 <EQ> is 0, which is (bs == BS ? 1 : 0) 52 53 If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 54 For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 55 */ 56 template <class Type, PetscInt BS, PetscInt EQ> 57 PETSC_KERNEL_DECL static void d_Pack(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, const Type *data, Type *buf) 58 { 59 const PetscInt M = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */ 60 const PetscInt MBS = M * BS; /* MBS=bs. We turn MBS into a compile-time const when EQ=1. */ 61 62 ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) { 63 PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 64 PetscInt s = tid * MBS; 65 for (PetscInt i = 0; i < MBS; i++) buf[s + i] = data[t + i]; 66 }); 67 } 68 69 template <class Type, class Op, PetscInt BS, PetscInt EQ> 70 PETSC_KERNEL_DECL static void d_UnpackAndOp(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, Type *data, const Type *buf) 71 { 72 const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS; 73 Op op; 74 75 ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) { 76 PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 77 PetscInt s = tid * MBS; 78 for (PetscInt i = 0; i < MBS; i++) op(data[t + i], buf[s + i]); 79 }); 80 } 81 82 template <class Type, class Op, PetscInt BS, PetscInt EQ> 83 PETSC_KERNEL_DECL static void d_FetchAndOp(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, Type *leafbuf) 84 { 85 const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS; 86 Op op; 87 88 ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) { 89 PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS; 90 PetscInt l = tid * MBS; 91 for (PetscInt i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]); 92 }); 93 } 94 95 template <class Type, class Op, PetscInt BS, PetscInt EQ> 96 PETSC_KERNEL_DECL static void d_ScatterAndOp(PetscInt bs, PetscInt count, PetscInt srcx, PetscInt srcy, PetscInt srcX, PetscInt srcY, PetscInt srcStart, const PetscInt *srcIdx, const Type *src, PetscInt dstx, PetscInt dsty, PetscInt dstX, PetscInt dstY, PetscInt dstStart, const PetscInt *dstIdx, Type *dst) 97 { 98 const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS; 99 Op op; 100 101 ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) { 102 PetscInt s, t; 103 104 if (!srcIdx) { /* src is either contiguous or 3D */ 105 PetscInt k = tid / (srcx * srcy); 106 PetscInt j = (tid - k * srcx * srcy) / srcx; 107 PetscInt i = tid - k * srcx * srcy - j * srcx; 108 109 s = srcStart + k * srcX * srcY + j * srcX + i; 110 } else { 111 s = srcIdx[tid]; 112 } 113 114 if (!dstIdx) { /* dst is either contiguous or 3D */ 115 PetscInt k = tid / (dstx * dsty); 116 PetscInt j = (tid - k * dstx * dsty) / dstx; 117 PetscInt i = tid - k * dstx * dsty - j * dstx; 118 119 t = dstStart + k * dstX * dstY + j * dstX + i; 120 } else { 121 t = dstIdx[tid]; 122 } 123 124 s *= MBS; 125 t *= MBS; 126 for (PetscInt i = 0; i < MBS; i++) op(dst[t + i], src[s + i]); 127 }); 128 } 129 130 template <class Type, class Op, PetscInt BS, PetscInt EQ> 131 PETSC_KERNEL_DECL static void d_FetchAndOpLocal(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, PetscInt leafstart, const PetscInt *leafopt, const PetscInt *leafidx, const Type *leafdata, Type *leafupdate) 132 { 133 const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS; 134 Op op; 135 136 ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) { 137 PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS; 138 PetscInt l = (leafopt ? MapTidToIndex(leafopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS; 139 for (PetscInt i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]); 140 }); 141 } 142 143 /*====================================================================================*/ 144 /* Regular operations on device */ 145 /*====================================================================================*/ 146 template <typename Type> 147 struct Insert { 148 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 149 { 150 Type old = x; 151 x = y; 152 return old; 153 } 154 }; 155 template <typename Type> 156 struct Add { 157 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 158 { 159 Type old = x; 160 x += y; 161 return old; 162 } 163 }; 164 template <typename Type> 165 struct Mult { 166 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 167 { 168 Type old = x; 169 x *= y; 170 return old; 171 } 172 }; 173 template <typename Type> 174 struct Min { 175 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 176 { 177 Type old = x; 178 x = PetscMin(x, y); 179 return old; 180 } 181 }; 182 template <typename Type> 183 struct Max { 184 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 185 { 186 Type old = x; 187 x = PetscMax(x, y); 188 return old; 189 } 190 }; 191 template <typename Type> 192 struct LAND { 193 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 194 { 195 Type old = x; 196 x = x && y; 197 return old; 198 } 199 }; 200 template <typename Type> 201 struct LOR { 202 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 203 { 204 Type old = x; 205 x = x || y; 206 return old; 207 } 208 }; 209 template <typename Type> 210 struct LXOR { 211 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 212 { 213 Type old = x; 214 x = !x != !y; 215 return old; 216 } 217 }; 218 template <typename Type> 219 struct BAND { 220 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 221 { 222 Type old = x; 223 x = x & y; 224 return old; 225 } 226 }; 227 template <typename Type> 228 struct BOR { 229 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 230 { 231 Type old = x; 232 x = x | y; 233 return old; 234 } 235 }; 236 template <typename Type> 237 struct BXOR { 238 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 239 { 240 Type old = x; 241 x = x ^ y; 242 return old; 243 } 244 }; 245 template <typename Type> 246 struct Minloc { 247 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 248 { 249 Type old = x; 250 if (y.a < x.a) x = y; 251 else if (y.a == x.a) x.b = min(x.b, y.b); 252 return old; 253 } 254 }; 255 template <typename Type> 256 struct Maxloc { 257 PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const 258 { 259 Type old = x; 260 if (y.a > x.a) x = y; 261 else if (y.a == x.a) x.b = min(x.b, y.b); /* See MPI MAXLOC */ 262 return old; 263 } 264 }; 265 266 } // namespace kernels 267 268 namespace impl 269 { 270 271 /*====================================================================================*/ 272 /* Wrapper functions of cupm kernels. Function pointers are stored in 'link' */ 273 /*====================================================================================*/ 274 template <device::cupm::DeviceType T> 275 template <typename Type, PetscInt BS, PetscInt EQ> 276 inline PetscErrorCode SfInterface<T>::Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data, void *buf) noexcept 277 { 278 const PetscInt *iarray = opt ? opt->array : NULL; 279 280 PetscFunctionBegin; 281 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 282 if (PetscDefined(USING_NVCC) && !opt && !idx) { /* It is a 'CUDA data to nvshmem buf' memory copy */ 283 PetscCallCUPM(cupmMemcpyAsync(buf, (char *)data + start * link->unitbytes, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream)); 284 } else { 285 PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_Pack<Type, BS, EQ>, link->bs, count, start, iarray, idx, (const Type *)data, (Type *)buf)); 286 } 287 PetscFunctionReturn(PETSC_SUCCESS); 288 } 289 290 template <device::cupm::DeviceType T> 291 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 292 inline PetscErrorCode SfInterface<T>::UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, const void *buf) noexcept 293 { 294 const PetscInt *iarray = opt ? opt->array : NULL; 295 296 PetscFunctionBegin; 297 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 298 if (PetscDefined(USING_NVCC) && std::is_same<Op, kernels::Insert<Type>>::value && !opt && !idx) { /* It is a 'nvshmem buf to CUDA data' memory copy */ 299 PetscCallCUPM(cupmMemcpyAsync((char *)data + start * link->unitbytes, buf, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream)); 300 } else { 301 PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_UnpackAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf)); 302 } 303 PetscFunctionReturn(PETSC_SUCCESS); 304 } 305 306 template <device::cupm::DeviceType T> 307 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 308 inline PetscErrorCode SfInterface<T>::FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) noexcept 309 { 310 const PetscInt *iarray = opt ? opt->array : NULL; 311 312 PetscFunctionBegin; 313 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 314 PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf)); 315 PetscFunctionReturn(PETSC_SUCCESS); 316 } 317 318 template <device::cupm::DeviceType T> 319 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 320 inline PetscErrorCode SfInterface<T>::ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept 321 { 322 PetscInt nthreads = 256; 323 PetscInt nblocks = (count + nthreads - 1) / nthreads; 324 PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0; 325 326 PetscFunctionBegin; 327 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 328 nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads); 329 330 /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use 3D grid and block */ 331 if (srcOpt) { 332 srcx = srcOpt->dx[0]; 333 srcy = srcOpt->dy[0]; 334 srcX = srcOpt->X[0]; 335 srcY = srcOpt->Y[0]; 336 srcStart = srcOpt->start[0]; 337 srcIdx = NULL; 338 } else if (!srcIdx) { 339 srcx = srcX = count; 340 srcy = srcY = 1; 341 } 342 343 if (dstOpt) { 344 dstx = dstOpt->dx[0]; 345 dsty = dstOpt->dy[0]; 346 dstX = dstOpt->X[0]; 347 dstY = dstOpt->Y[0]; 348 dstStart = dstOpt->start[0]; 349 dstIdx = NULL; 350 } else if (!dstIdx) { 351 dstx = dstX = count; 352 dsty = dstY = 1; 353 } 354 355 PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_ScatterAndOp<Type, Op, BS, EQ>, link->bs, count, srcx, srcy, srcX, srcY, srcStart, srcIdx, (const Type *)src, dstx, dsty, dstX, dstY, dstStart, dstIdx, (Type *)dst)); 356 PetscFunctionReturn(PETSC_SUCCESS); 357 } 358 359 template <device::cupm::DeviceType T> 360 /* Specialization for Insert since we may use cupmMemcpyAsync */ 361 template <typename Type, PetscInt BS, PetscInt EQ> 362 inline PetscErrorCode SfInterface<T>::ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept 363 { 364 PetscFunctionBegin; 365 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 366 /*src and dst are contiguous */ 367 if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 368 PetscCallCUPM(cupmMemcpyAsync((Type *)dst + dstStart * link->bs, (const Type *)src + srcStart * link->bs, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream)); 369 } else { 370 PetscCall(ScatterAndOp<Type, kernels::Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst)); 371 } 372 PetscFunctionReturn(PETSC_SUCCESS); 373 } 374 375 template <device::cupm::DeviceType T> 376 template <typename Type, class Op, PetscInt BS, PetscInt EQ> 377 inline PetscErrorCode SfInterface<T>::FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate) noexcept 378 { 379 const PetscInt *rarray = rootopt ? rootopt->array : NULL; 380 const PetscInt *larray = leafopt ? leafopt->array : NULL; 381 382 PetscFunctionBegin; 383 if (!count) PetscFunctionReturn(PETSC_SUCCESS); 384 PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOpLocal<Type, Op, BS, EQ>, link->bs, count, rootstart, rarray, rootidx, (Type *)rootdata, leafstart, larray, leafidx, (const Type *)leafdata, (Type *)leafupdate)); 385 PetscFunctionReturn(PETSC_SUCCESS); 386 } 387 388 /*====================================================================================*/ 389 /* Init various types and instantiate pack/unpack function pointers */ 390 /*====================================================================================*/ 391 template <device::cupm::DeviceType T> 392 template <typename Type, PetscInt BS, PetscInt EQ> 393 inline void SfInterface<T>::PackInit_RealType(PetscSFLink link) noexcept 394 { 395 /* Pack/unpack for remote communication */ 396 link->d_Pack = Pack<Type, BS, EQ>; 397 link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>; 398 link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>; 399 link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>; 400 link->d_UnpackAndMin = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>; 401 link->d_UnpackAndMax = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>; 402 link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>; 403 404 /* Scatter for local communication */ 405 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */ 406 link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>; 407 link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>; 408 link->d_ScatterAndMin = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>; 409 link->d_ScatterAndMax = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>; 410 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>; 411 412 /* Atomic versions when there are data-race possibilities */ 413 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 414 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 415 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 416 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 417 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 418 link->da_FetchAndAdd = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>; 419 420 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 421 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 422 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 423 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 424 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 425 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>; 426 } 427 428 /* Have this templated class to specialize for char integers */ 429 template <device::cupm::DeviceType T> 430 template <typename Type, PetscInt BS, PetscInt EQ, PetscInt size /*sizeof(Type)*/> 431 struct SfInterface<T>::PackInit_IntegerType_Atomic { 432 static inline void Init(PetscSFLink link) noexcept 433 { 434 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 435 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 436 link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 437 link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 438 link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 439 link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>; 440 link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>; 441 link->da_UnpackAndLXOR = UnpackAndOp<Type, AtomicLXOR<Type>, BS, EQ>; 442 link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>; 443 link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>; 444 link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 445 link->da_FetchAndAdd = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>; 446 447 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 448 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 449 link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 450 link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 451 link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 452 link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>; 453 link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>; 454 link->da_ScatterAndLXOR = ScatterAndOp<Type, AtomicLXOR<Type>, BS, EQ>; 455 link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>; 456 link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>; 457 link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 458 link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>; 459 } 460 }; 461 462 /* CUDA does not support atomics on chars. It is TBD in PETSc. */ 463 template <device::cupm::DeviceType T> 464 template <typename Type, PetscInt BS, PetscInt EQ> 465 struct SfInterface<T>::PackInit_IntegerType_Atomic<Type, BS, EQ, 1> { 466 static inline void Init(PetscSFLink) 467 { /* Nothing to leave function pointers NULL */ 468 } 469 }; 470 471 template <device::cupm::DeviceType T> 472 template <typename Type, PetscInt BS, PetscInt EQ> 473 inline void SfInterface<T>::PackInit_IntegerType(PetscSFLink link) noexcept 474 { 475 link->d_Pack = Pack<Type, BS, EQ>; 476 link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>; 477 link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>; 478 link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>; 479 link->d_UnpackAndMin = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>; 480 link->d_UnpackAndMax = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>; 481 link->d_UnpackAndLAND = UnpackAndOp<Type, kernels::LAND<Type>, BS, EQ>; 482 link->d_UnpackAndLOR = UnpackAndOp<Type, kernels::LOR<Type>, BS, EQ>; 483 link->d_UnpackAndLXOR = UnpackAndOp<Type, kernels::LXOR<Type>, BS, EQ>; 484 link->d_UnpackAndBAND = UnpackAndOp<Type, kernels::BAND<Type>, BS, EQ>; 485 link->d_UnpackAndBOR = UnpackAndOp<Type, kernels::BOR<Type>, BS, EQ>; 486 link->d_UnpackAndBXOR = UnpackAndOp<Type, kernels::BXOR<Type>, BS, EQ>; 487 link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>; 488 489 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 490 link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>; 491 link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>; 492 link->d_ScatterAndMin = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>; 493 link->d_ScatterAndMax = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>; 494 link->d_ScatterAndLAND = ScatterAndOp<Type, kernels::LAND<Type>, BS, EQ>; 495 link->d_ScatterAndLOR = ScatterAndOp<Type, kernels::LOR<Type>, BS, EQ>; 496 link->d_ScatterAndLXOR = ScatterAndOp<Type, kernels::LXOR<Type>, BS, EQ>; 497 link->d_ScatterAndBAND = ScatterAndOp<Type, kernels::BAND<Type>, BS, EQ>; 498 link->d_ScatterAndBOR = ScatterAndOp<Type, kernels::BOR<Type>, BS, EQ>; 499 link->d_ScatterAndBXOR = ScatterAndOp<Type, kernels::BXOR<Type>, BS, EQ>; 500 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>; 501 PackInit_IntegerType_Atomic<Type, BS, EQ, sizeof(Type)>::Init(link); 502 } 503 504 #if defined(PETSC_HAVE_COMPLEX) 505 template <device::cupm::DeviceType T> 506 template <typename Type, PetscInt BS, PetscInt EQ> 507 inline void SfInterface<T>::PackInit_ComplexType(PetscSFLink link) noexcept 508 { 509 link->d_Pack = Pack<Type, BS, EQ>; 510 link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>; 511 link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>; 512 link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>; 513 link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>; 514 515 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 516 link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>; 517 link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>; 518 link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>; 519 520 link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 521 link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 522 link->da_UnpackAndMult = NULL; /* Not implemented yet */ 523 link->da_FetchAndAdd = NULL; /* Return value of atomicAdd on complex is not atomic */ 524 525 link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 526 link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 527 } 528 #endif 529 530 typedef signed char SignedChar; 531 typedef unsigned char UnsignedChar; 532 typedef struct { 533 int a; 534 int b; 535 } PairInt; 536 typedef struct { 537 PetscInt a; 538 PetscInt b; 539 } PairPetscInt; 540 541 template <device::cupm::DeviceType T> 542 template <typename Type> 543 inline void SfInterface<T>::PackInit_PairType(PetscSFLink link) noexcept 544 { 545 link->d_Pack = Pack<Type, 1, 1>; 546 link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, 1, 1>; 547 link->d_UnpackAndMaxloc = UnpackAndOp<Type, kernels::Maxloc<Type>, 1, 1>; 548 link->d_UnpackAndMinloc = UnpackAndOp<Type, kernels::Minloc<Type>, 1, 1>; 549 550 link->d_ScatterAndInsert = ScatterAndOp<Type, kernels::Insert<Type>, 1, 1>; 551 link->d_ScatterAndMaxloc = ScatterAndOp<Type, kernels::Maxloc<Type>, 1, 1>; 552 link->d_ScatterAndMinloc = ScatterAndOp<Type, kernels::Minloc<Type>, 1, 1>; 553 /* Atomics for pair types are not implemented yet */ 554 } 555 556 template <device::cupm::DeviceType T> 557 template <typename Type, PetscInt BS, PetscInt EQ> 558 inline void SfInterface<T>::PackInit_DumbType(PetscSFLink link) noexcept 559 { 560 link->d_Pack = Pack<Type, BS, EQ>; 561 link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>; 562 link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 563 /* Atomics for dumb types are not implemented yet */ 564 } 565 566 /* Some device-specific utilities */ 567 template <device::cupm::DeviceType T> 568 inline PetscErrorCode SfInterface<T>::LinkSyncDevice(PetscSFLink) noexcept 569 { 570 PetscFunctionBegin; 571 PetscCallCUPM(cupmDeviceSynchronize()); 572 PetscFunctionReturn(PETSC_SUCCESS); 573 } 574 575 template <device::cupm::DeviceType T> 576 inline PetscErrorCode SfInterface<T>::LinkSyncStream(PetscSFLink link) noexcept 577 { 578 PetscFunctionBegin; 579 PetscCallCUPM(cupmStreamSynchronize(link->stream)); 580 PetscFunctionReturn(PETSC_SUCCESS); 581 } 582 583 template <device::cupm::DeviceType T> 584 inline PetscErrorCode SfInterface<T>::LinkMemcpy(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) noexcept 585 { 586 PetscFunctionBegin; 587 cupmMemcpyKind_t kinds[2][2] = { 588 {cupmMemcpyHostToHost, cupmMemcpyHostToDevice }, 589 {cupmMemcpyDeviceToHost, cupmMemcpyDeviceToDevice} 590 }; 591 592 if (n) { 593 if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { /* Separate HostToHost so that pure-cpu code won't call cupm runtime */ 594 PetscCall(PetscMemcpy(dst, src, n)); 595 } else { 596 int stype = PetscMemTypeDevice(srcmtype) ? 1 : 0; 597 int dtype = PetscMemTypeDevice(dstmtype) ? 1 : 0; 598 PetscCallCUPM(cupmMemcpyAsync(dst, src, n, kinds[stype][dtype], link->stream)); 599 } 600 } 601 PetscFunctionReturn(PETSC_SUCCESS); 602 } 603 604 template <device::cupm::DeviceType T> 605 inline PetscErrorCode SfInterface<T>::Malloc(PetscMemType mtype, size_t size, void **ptr) noexcept 606 { 607 PetscFunctionBegin; 608 if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr)); 609 else if (PetscMemTypeDevice(mtype)) { 610 PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 611 PetscCallCUPM(cupmMalloc(ptr, size)); 612 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 613 PetscFunctionReturn(PETSC_SUCCESS); 614 } 615 616 template <device::cupm::DeviceType T> 617 inline PetscErrorCode SfInterface<T>::Free(PetscMemType mtype, void *ptr) noexcept 618 { 619 PetscFunctionBegin; 620 if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr)); 621 else if (PetscMemTypeDevice(mtype)) PetscCallCUPM(cupmFree(ptr)); 622 else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 623 PetscFunctionReturn(PETSC_SUCCESS); 624 } 625 626 /* Destructor when the link uses MPI for communication on CUPM device */ 627 template <device::cupm::DeviceType T> 628 inline PetscErrorCode SfInterface<T>::LinkDestroy_MPI(PetscSF, PetscSFLink link) noexcept 629 { 630 PetscFunctionBegin; 631 for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) { 632 PetscCallCUPM(cupmFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 633 PetscCallCUPM(cupmFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 634 } 635 PetscFunctionReturn(PETSC_SUCCESS); 636 } 637 638 /*====================================================================================*/ 639 /* Main driver to init MPI datatype on device */ 640 /*====================================================================================*/ 641 642 /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */ 643 template <device::cupm::DeviceType T> 644 inline PetscErrorCode SfInterface<T>::LinkSetUp(PetscSF sf, PetscSFLink link, MPI_Datatype unit) noexcept 645 { 646 PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0; 647 PetscBool is2Int, is2PetscInt; 648 #if defined(PETSC_HAVE_COMPLEX) 649 PetscInt nPetscComplex = 0; 650 #endif 651 652 PetscFunctionBegin; 653 if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS); 654 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar)); 655 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar)); 656 /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 657 PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt)); 658 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt)); 659 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal)); 660 #if defined(PETSC_HAVE_COMPLEX) 661 PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex)); 662 #endif 663 PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int)); 664 PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt)); 665 666 if (is2Int) { 667 PackInit_PairType<PairInt>(link); 668 } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */ 669 PackInit_PairType<PairPetscInt>(link); 670 } else if (nPetscReal) { 671 #if !defined(PETSC_HAVE_DEVICE) 672 if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link); 673 else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link); 674 else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link); 675 else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link); 676 else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link); 677 else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link); 678 else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link); 679 else if (nPetscReal % 1 == 0) 680 #endif 681 PackInit_RealType<PetscReal, 1, 0>(link); 682 } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) { 683 #if !defined(PETSC_HAVE_DEVICE) 684 if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link); 685 else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link); 686 else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link); 687 else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link); 688 else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link); 689 else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link); 690 else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link); 691 else if (nPetscInt % 1 == 0) 692 #endif 693 PackInit_IntegerType<llint, 1, 0>(link); 694 } else if (nInt) { 695 #if !defined(PETSC_HAVE_DEVICE) 696 if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link); 697 else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link); 698 else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link); 699 else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link); 700 else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link); 701 else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link); 702 else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link); 703 else if (nInt % 1 == 0) 704 #endif 705 PackInit_IntegerType<int, 1, 0>(link); 706 } else if (nSignedChar) { 707 #if !defined(PETSC_HAVE_DEVICE) 708 if (nSignedChar == 8) PackInit_IntegerType<SignedChar, 8, 1>(link); 709 else if (nSignedChar % 8 == 0) PackInit_IntegerType<SignedChar, 8, 0>(link); 710 else if (nSignedChar == 4) PackInit_IntegerType<SignedChar, 4, 1>(link); 711 else if (nSignedChar % 4 == 0) PackInit_IntegerType<SignedChar, 4, 0>(link); 712 else if (nSignedChar == 2) PackInit_IntegerType<SignedChar, 2, 1>(link); 713 else if (nSignedChar % 2 == 0) PackInit_IntegerType<SignedChar, 2, 0>(link); 714 else if (nSignedChar == 1) PackInit_IntegerType<SignedChar, 1, 1>(link); 715 else if (nSignedChar % 1 == 0) 716 #endif 717 PackInit_IntegerType<SignedChar, 1, 0>(link); 718 } else if (nUnsignedChar) { 719 #if !defined(PETSC_HAVE_DEVICE) 720 if (nUnsignedChar == 8) PackInit_IntegerType<UnsignedChar, 8, 1>(link); 721 else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<UnsignedChar, 8, 0>(link); 722 else if (nUnsignedChar == 4) PackInit_IntegerType<UnsignedChar, 4, 1>(link); 723 else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<UnsignedChar, 4, 0>(link); 724 else if (nUnsignedChar == 2) PackInit_IntegerType<UnsignedChar, 2, 1>(link); 725 else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<UnsignedChar, 2, 0>(link); 726 else if (nUnsignedChar == 1) PackInit_IntegerType<UnsignedChar, 1, 1>(link); 727 else if (nUnsignedChar % 1 == 0) 728 #endif 729 PackInit_IntegerType<UnsignedChar, 1, 0>(link); 730 #if defined(PETSC_HAVE_COMPLEX) 731 } else if (nPetscComplex) { 732 #if !defined(PETSC_HAVE_DEVICE) 733 if (nPetscComplex == 8) PackInit_ComplexType<PetscComplex, 8, 1>(link); 734 else if (nPetscComplex % 8 == 0) PackInit_ComplexType<PetscComplex, 8, 0>(link); 735 else if (nPetscComplex == 4) PackInit_ComplexType<PetscComplex, 4, 1>(link); 736 else if (nPetscComplex % 4 == 0) PackInit_ComplexType<PetscComplex, 4, 0>(link); 737 else if (nPetscComplex == 2) PackInit_ComplexType<PetscComplex, 2, 1>(link); 738 else if (nPetscComplex % 2 == 0) PackInit_ComplexType<PetscComplex, 2, 0>(link); 739 else if (nPetscComplex == 1) PackInit_ComplexType<PetscComplex, 1, 1>(link); 740 else if (nPetscComplex % 1 == 0) 741 #endif 742 PackInit_ComplexType<PetscComplex, 1, 0>(link); 743 #endif 744 } else { 745 MPI_Aint lb, nbyte; 746 PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte)); 747 PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb); 748 if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 749 #if !defined(PETSC_HAVE_DEVICE) 750 if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link); 751 else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link); 752 else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link); 753 else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link); 754 else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link); 755 else if (nbyte % 1 == 0) 756 #endif 757 PackInit_DumbType<char, 1, 0>(link); 758 } else { 759 nInt = nbyte / sizeof(int); 760 #if !defined(PETSC_HAVE_DEVICE) 761 if (nInt == 8) PackInit_DumbType<int, 8, 1>(link); 762 else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link); 763 else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link); 764 else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link); 765 else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link); 766 else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link); 767 else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link); 768 else if (nInt % 1 == 0) 769 #endif 770 PackInit_DumbType<int, 1, 0>(link); 771 } 772 } 773 774 if (!sf->maxResidentThreadsPerGPU) { /* Not initialized */ 775 int device; 776 cupmDeviceProp_t props; 777 PetscCallCUPM(cupmGetDevice(&device)); 778 PetscCallCUPM(cupmGetDeviceProperties(&props, device)); 779 sf->maxResidentThreadsPerGPU = props.maxThreadsPerMultiProcessor * props.multiProcessorCount; 780 } 781 link->maxResidentThreadsPerGPU = sf->maxResidentThreadsPerGPU; 782 783 { 784 cupmStream_t *stream; 785 PetscDeviceContext dctx; 786 787 PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUPM())); 788 PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 789 link->stream = *stream; 790 } 791 link->Destroy = LinkDestroy_MPI; 792 link->SyncDevice = LinkSyncDevice; 793 link->SyncStream = LinkSyncStream; 794 link->Memcpy = LinkMemcpy; 795 link->deviceinited = PETSC_TRUE; 796 PetscFunctionReturn(PETSC_SUCCESS); 797 } 798 799 } // namespace impl 800 801 } // namespace cupm 802 803 } // namespace sf 804 805 } // namespace Petsc 806