/*
  Kokkos implementation of the PetscSF device pack/unpack/scatter kernels and device utilities
  (memcpy, malloc/free, synchronization). PetscSFLinkSetUp_Kokkos() at the end of the file inspects the
  MPI datatype of a link and stores the matching kernel instantiations in the link's d_* (plain) and
  da_* (atomic) function pointers.

  NOTE(review): in this copy of the file, text that was between angle brackets appears to be missing:
  the targets of two #include lines, all template parameter lists (e.g. "template struct Insert"),
  the template arguments of Kokkos::View, Kokkos::RangePolicy, static_cast and reinterpret_cast, and the
  template arguments of the PackInit_* / UnpackAndOp / ScatterAndOp / FetchAndOp instantiations
  (e.g. "UnpackAndOp, BS, EQ>"). The code cannot compile as written; TODO restore these from upstream.
  The comments below describe only the logic that is visible here.
*/
#include <../src/vec/is/sf/impls/basic/sfpack.h>
#include
#include

/* All kernel launches and deep copies below run on the instance returned by PetscGetKokkosExecutionSpace() */
using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;

/* View types used to wrap raw pointers for Kokkos::deep_copy. At every construction site the extent passed
   is a byte count (count * link->unitbytes, or n). Which memory space and constness each alias has is
   presumably given by the missing template arguments; the names suggest device/host and const variants
   -- TODO confirm. */
typedef Kokkos::View deviceBuffer_t;
typedef Kokkos::View HostBuffer_t;
typedef Kokkos::View deviceConstBuffer_t;
typedef Kokkos::View HostConstBuffer_t;

/*====================================================================================*/
/*                             Regular operations                                     */
/*====================================================================================*/
/* Non-atomic binary operators. Each functor updates x in place using y and returns the value x had
   before the update; FetchAndOp/FetchAndOpLocal use that returned old value. */

/* x = y */
template struct Insert {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = y;
    return old;
  }
};

/* x += y */
template struct Add {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x += y;
    return old;
  }
};

/* x *= y */
template struct Mult {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x *= y;
    return old;
  }
};

/* x = min(x, y) */
template struct Min {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = PetscMin(x, y);
    return old;
  }
};

/* x = max(x, y) */
template struct Max {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = PetscMax(x, y);
    return old;
  }
};

/* Logical AND; the result stored in x is the bool value converted back to Type */
template struct LAND {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = x && y;
    return old;
  }
};

/* Logical OR */
template struct LOR {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = x || y;
    return old;
  }
};

/* Logical XOR: the negations normalize both operands to bool before comparing */
template struct LXOR {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = !x != !y;
    return old;
  }
};

/* Bitwise AND */
template struct BAND {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = x & y;
    return old;
  }
};

/* Bitwise OR */
template struct BOR {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = x | y;
    return old;
  }
};

/* Bitwise XOR */
template struct BXOR {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
  {
    Type old = x;
    x = x ^ y;
    return old;
  }
};

/* (value, index) pair reduction: keep the pair with the smaller .first; on a tie keep the smaller .second */
template struct Minloc {
  KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
  {
    PairType old = x;
    if (y.first < x.first) x = y;
    else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
    return old;
  }
};

/* (value, index) pair reduction: keep the pair with the larger .first; on a tie keep the SMALLER .second */
template struct Maxloc {
  KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
  {
    PairType old = x;
    if (y.first > x.first) x = y;
    else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* Min is intentional: on ties MPI MAXLOC also keeps the smaller index */
    return old;
  }
};

/*====================================================================================*/
/*                             Atomic operations                                      */
/*====================================================================================*/
/* Atomic counterparts used by the da_* function pointers, where several threads may update the same
   entry. Except for AtomicFetchAdd they return nothing, so the old value is not available. */

template struct AtomicInsert {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_store(&x, y); }
};
template struct AtomicAdd {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
};
template struct AtomicBAND {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
};
template struct AtomicBOR {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
};
template struct AtomicBXOR {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
};
/* Logical AND expressed as a bitwise AND: a false y ANDs x with 0 (x becomes 0); a true y ANDs x with
   all bits set (~0), leaving x unchanged, so a nonzero x stays nonzero */
template struct AtomicLAND {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
  {
    const Type zero = 0, one = ~0;
    Kokkos::atomic_and(&x, y ? one : zero);
  }
};
/* Logical OR expressed as a bitwise OR: a true y ORs in 1 (x becomes nonzero); a false y ORs in 0 (no change) */
template struct AtomicLOR {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
  {
    const Type zero = 0, one = 1;
    Kokkos::atomic_or(&x, y ? one : zero);
  }
};
template struct AtomicMult {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
};
template struct AtomicMin {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
};
template struct AtomicMax {
  KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
};
/* TODO: struct AtomicLXOR (until it exists, no da_*LXOR pointers are set in PackInit_IntegerType) */

/* Atomic add that returns the value x held before the addition (used by the da_FetchAndAdd* pointers) */
template struct AtomicFetchAdd {
  KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
};

/*
  Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt.

  Layout of opt as read here (n = opt[0] subdomains):
    offset = opt + 1       : n+1 entries; subdomain r owns thread ids [offset[r], offset[r+1])
    start  = opt + n + 2   : first index of each subdomain
    dx, dy = opt + 2n + 2, opt + 3n + 2 : subdomain extents along the two fastest dimensions
    (opt + 4n + 2 is skipped here -- presumably dz)
    X, Y   = opt + 5n + 2, opt + 6n + 2 : row stride X and plane stride X*Y of the enclosing domain
  The thread's offset m inside subdomain r is split into (i, j, k) with i fastest, and the result is
  start[r] + k*X*Y + j*X + i.
  NOTE(review): assumes 0 <= tid < offset[n]; otherwise the search loop exits with r == n and reads past
  the offset/start/dx/dy/X/Y arrays.
*/
static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
{
  PetscInt        i, j, k, m, n, r;
  const PetscInt *offset, *start, *dx, *dy, *X, *Y;

  n      = opt[0];
  offset = opt + 1;
  start  = opt + n + 2;
  dx     = opt + 2 * n + 2;
  dy     = opt + 3 * n + 2;
  X      = opt + 5 * n + 2;
  Y      = opt + 6 * n + 2;
  /* Linear search for the subdomain containing tid */
  for (r = 0; r < n; r++) {
    if (tid < offset[r + 1]) break;
  }
  m = (tid - offset[r]);
  k = m / (dx[r] * dy[r]);
  j = (m - k * dx[r] * dy[r]) / dx[r];
  i = m - k * dx[r] * dy[r] - j * dx[r];

  return start[r] + k * X[r] * Y[r] + j * X[r] + i;
}

/*====================================================================================*/
/*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'   */
/*====================================================================================*/

/*
  Template parameters of the kernels below (their declarations are missing in this copy; see the note at the top):
  Suppose user calls PetscSFReduce(sf,unit,...) and 'unit' is an MPI data type made of 16 PetscReals, then
    Type is PetscReal, which is the primitive type we operate on.
    bs   (link->bs) is 16, which says 'unit' contains 16 primitive types.
    BS   is 8, which is the maximal SIMD width we will try to vectorize operations on 'unit'.
    EQ   is 0, which is (bs == BS ? 1 : 0)

  If instead, 'unit' has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.

  In every kernel each thread handles one unit, i.e. MBS = (bs/BS)*BS consecutive Type entries. Kernels are
  launched on exec without a fence here.
*/

/* Pack 'count' units of data into the contiguous buffer buf: unit tid of buf receives unit t of data, where t is
   MapTidToIndex(opt) when opt is given, else idx[tid] when idx is given, else start + tid */
template static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
{
  const PetscInt *iopt = opt ? opt->array : NULL;
  const PetscInt  M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
  const Type     *data = static_cast(data_);
  Type           *buf  = static_cast(buf_);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  Kokkos::parallel_for(
    Kokkos::RangePolicy(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
      /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
         iopt == NULL && idx == NULL ==> the indices are contiguous;
       */
      PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
      PetscInt s = tid * MBS;
      for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
    });
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Unpack the contiguous buffer buf into data: for each entry, op(data[t + i], buf[s + i]), with t selected
   exactly as in Pack. With a non-atomic Op, repeated target indices would race; the da_* pointers use the
   Atomic* functors for that case. */
template static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
{
  Op              op;
  const PetscInt *iopt = opt ? opt->array : NULL;
  const PetscInt  M = EQ ? 1 : link->bs / BS, MBS = M * BS;
  Type           *data = static_cast(data_);
  const Type     *buf  = static_cast(buf_);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  Kokkos::parallel_for(
    Kokkos::RangePolicy(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
      PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
      PetscInt s = tid * MBS;
      for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
    });
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Fetch-and-op between root data and a contiguous leaf buffer: rootdata[r + i] is updated with leafbuf[l + i]
   via op, and leafbuf[l + i] is overwritten with the value op returns (the root value before the update) */
template static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
{
  Op              op;
  const PetscInt *ropt = opt ? opt->array : NULL;
  const PetscInt  M = EQ ? 1 : link->bs / BS, MBS = M * BS;
  Type           *rootdata = static_cast(data), *leafbuf = static_cast(buf);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  Kokkos::parallel_for(
    Kokkos::RangePolicy(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
      PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
      PetscInt l = tid * MBS;
      for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
    });
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Local scatter: for each tid, op(dst[t + i], src[s + i]). Each side is either an explicit index list
   (srcIdx/dstIdx) or a single 3D box (start, x, y extents, X, Y strides). An absent opt and absent index list
   means a contiguous range, encoded as the box x = X = count, y = Y = 1.
   NOTE(review): only subdomain 0 of srcOpt/dstOpt (dx[0], dy[0], X[0], Y[0], start[0]) is read, so this
   presumably relies on callers passing opts with a single subdomain -- TODO confirm. */
template static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
{
  PetscInt    srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
  const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
  const Type *src = static_cast(src_);
  Type       *dst = static_cast(dst_);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
  if (srcOpt) {
    srcx     = srcOpt->dx[0];
    srcy     = srcOpt->dy[0];
    srcX     = srcOpt->X[0];
    srcY     = srcOpt->Y[0];
    srcStart = srcOpt->start[0];
    srcIdx   = NULL; /* the box description takes precedence over any index list */
  } else if (!srcIdx) {
    srcx = srcX = count; /* contiguous range viewed as a count x 1 box */
    srcy = srcY = 1;
  }

  if (dstOpt) {
    dstx     = dstOpt->dx[0];
    dsty     = dstOpt->dy[0];
    dstX     = dstOpt->X[0];
    dstY     = dstOpt->Y[0];
    dstStart = dstOpt->start[0];
    dstIdx   = NULL;
  } else if (!dstIdx) {
    dstx = dstX = count;
    dsty = dstY = 1;
  }

  Kokkos::parallel_for(
    Kokkos::RangePolicy(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
      PetscInt i, j, k, s, t;
      Op       op;
      if (!srcIdx) { /* src is a 3D box (a contiguous range is the count x 1 case) */
        k = tid / (srcx * srcy);
        j = (tid - k * srcx * srcy) / srcx;
        i = tid - k * srcx * srcy - j * srcx;
        s = srcStart + k * srcX * srcY + j * srcX + i;
      } else { /* src is given by an explicit index list */
        s = srcIdx[tid];
      }

      if (!dstIdx) { /* dst is a 3D box (a contiguous range is the count x 1 case) */
        k = tid / (dstx * dsty);
        j = (tid - k * dstx * dsty) / dstx;
        i = tid - k * dstx * dsty - j * dstx;
        t = dstStart + k * dstX * dstY + j * dstX + i;
      } else { /* dst is given by an explicit index list */
        t = dstIdx[tid];
      }

      s *= MBS;
      t *= MBS;
      for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
    });
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Specialization for Insert since we may use memcpy: when both sides are plain contiguous ranges and src != dst,
   copy count * link->unitbytes bytes with one deep_copy; otherwise fall back to the generic ScatterAndOp
   kernel (presumably instantiated with Insert; its template arguments are missing in this copy). */
template static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
{
  const Type *src = static_cast(src_);
  Type       *dst = static_cast(dst_);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  if (!count) PetscFunctionReturn(PETSC_SUCCESS);
  /* src and dst are contiguous */
  if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
    size_t         sz = count * link->unitbytes;
    /* the starts are in units; link->bs converts them to Type entries */
    deviceBuffer_t dbuf(reinterpret_cast(dst + dstStart * link->bs), sz);
    deviceConstBuffer_t sbuf(reinterpret_cast(src + srcStart * link->bs), sz);
    Kokkos::deep_copy(exec, dbuf, sbuf);
  } else {
    PetscCall(ScatterAndOp, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Local fetch-and-op: rootdata[r + i] is updated with leafdata[l + i] via op, and the value op returns (the root
   value before the update) is stored in leafupdate[l + i]. Root and leaf positions are each selected from
   opt / index list / start + tid, as in Pack. */
template static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
{
  Op              op;
  const PetscInt  M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
  const PetscInt *ropt = rootopt ? rootopt->array : NULL;
  const PetscInt *lopt = leafopt ? leafopt->array : NULL;
  Type           *rootdata = static_cast(rootdata_), *leafupdate = static_cast(leafupdate_);
  const Type     *leafdata = static_cast(leafdata_);
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  Kokkos::parallel_for(
    Kokkos::RangePolicy(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
      PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
      PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
      for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
    });
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*====================================================================================*/
/*  Init various types and instantiate pack/unpack function pointers                  */
/*====================================================================================*/

/* Real types: Insert/Add/Mult/Min/Max, plus FetchAndAdd, in plain (d_*) and atomic (da_*) variants */
template static void PackInit_RealType(PetscSFLink link)
{
  /* Pack/unpack for remote communication */
  link->d_Pack            = Pack;
  link->d_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMin    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMax    = UnpackAndOp, BS, EQ>;
  link->d_FetchAndAdd     = FetchAndOp, BS, EQ>;

  /* Scatter for local communication */
  link->d_ScatterAndInsert = ScatterAndInsert; /* Has special optimizations */
  link->d_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMin    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMax    = ScatterAndOp, BS, EQ>;
  link->d_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;

  /* Atomic versions when there are data-race possibilities */
  link->da_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMin    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMax    = UnpackAndOp, BS, EQ>;
  link->da_FetchAndAdd     = FetchAndOp, BS, EQ>;

  link->da_ScatterAndInsert = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMin    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMax    = ScatterAndOp, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;
}

/* Integer types: the real-type ops plus logical (LAND/LOR/LXOR) and bitwise (BAND/BOR/BXOR) ops.
   The atomic (da_*) set has no LXOR entries, since AtomicLXOR is not implemented (see TODO above). */
template static void PackInit_IntegerType(PetscSFLink link)
{
  link->d_Pack            = Pack;
  link->d_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMin    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMax    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndLAND   = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndLOR    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndLXOR   = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndBAND   = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndBOR    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndBXOR   = UnpackAndOp, BS, EQ>;
  link->d_FetchAndAdd     = FetchAndOp, BS, EQ>;

  link->d_ScatterAndInsert = ScatterAndInsert;
  link->d_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMin    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMax    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndLAND   = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndLOR    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndLXOR   = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndBAND   = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndBOR    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndBXOR   = ScatterAndOp, BS, EQ>;
  link->d_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;

  link->da_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMin    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMax    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndLAND   = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndLOR    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndBAND   = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndBOR    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndBXOR   = UnpackAndOp, BS, EQ>;
  link->da_FetchAndAdd     = FetchAndOp, BS, EQ>;

  link->da_ScatterAndInsert = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMin    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMax    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndLAND   = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndLOR    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndBAND   = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndBOR    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndBXOR   = ScatterAndOp, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;
}

#if defined(PETSC_HAVE_COMPLEX)
/* Complex types: only Insert/Add/Mult and FetchAndAdd (no ordering-based Min/Max) */
template static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack            = Pack;
  link->d_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->d_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->d_FetchAndAdd     = FetchAndOp, BS, EQ>;

  link->d_ScatterAndInsert = ScatterAndInsert;
  link->d_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->d_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->d_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;

  link->da_UnpackAndInsert = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndAdd    = UnpackAndOp, BS, EQ>;
  link->da_UnpackAndMult   = UnpackAndOp, BS, EQ>;
  link->da_FetchAndAdd     = FetchAndOp, BS, EQ>;

  link->da_ScatterAndInsert = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp, BS, EQ>;
  link->da_ScatterAndMult   = ScatterAndOp, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal, BS, EQ>;
}
#endif

/* (value, index) pair types: Pack plus Insert/Maxloc/Minloc, instantiated with BS = 1, EQ = 1; plain versions only */
template static void PackInit_PairType(PetscSFLink link)
{
  link->d_Pack            = Pack;
  link->d_UnpackAndInsert = UnpackAndOp, 1, 1>;
  link->d_UnpackAndMaxloc = UnpackAndOp, 1, 1>;
  link->d_UnpackAndMinloc = UnpackAndOp, 1, 1>;

  link->d_ScatterAndInsert = ScatterAndOp, 1, 1>;
  link->d_ScatterAndMaxloc = ScatterAndOp, 1, 1>;
  link->d_ScatterAndMinloc = ScatterAndOp, 1, 1>;
  /* Atomics for pair types are not implemented yet */
}

/* Types with no supported reduction ("dumb" types): only Pack and Insert are provided */
template static void PackInit_DumbType(PetscSFLink link)
{
  link->d_Pack             = Pack;
  link->d_UnpackAndInsert  = UnpackAndOp, BS, EQ>;
  link->d_ScatterAndInsert = ScatterAndInsert;
  /* Atomics for dumb types are not implemented yet */
}

/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug: the object cannot
  be repeatedly created and destroyed. SF's original design gave each SFLink a stream (NULL or not) and hence
  an execution space object. The bug prevents us from destroying multiple SFLinks with a NULL stream and the
  default execution space object. To avoid memory leaks, SF_Kokkos only supports the NULL stream, which is also
  PETSc's default scheme. SF_Kokkos does not do its own new/delete; it just uses Kokkos::DefaultExecutionSpace(),
  which is a singleton object in Kokkos.
*/
/* Obsolete stub kept for reference (the active PetscSFLinkDestroy_Kokkos is defined further below):
static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
{
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}
*/

/* Some device-specific utilities */

/* Global fence: waits for all outstanding Kokkos work, not only work on PETSc's execution space instance */
static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
{
  PetscFunctionBegin;
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Fence only PETSc's execution space instance (the one every kernel in this file is launched on) */
static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
{
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
  PetscFunctionBegin;
  exec.fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Copy n bytes from src to dst, dispatching on the host/device memory types of both sides.
   H2H fences first and then uses PetscMemcpy; D2H fences after the copy so dst is usable on return;
   H2D and D2D are only enqueued on exec. H2D and D2H transfers are logged. */
static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
{
  DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

  PetscFunctionBegin;
  if (!n) PetscFunctionReturn(PETSC_SUCCESS);
  if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2H
    PetscCallCXX(exec.fence()); // make sure async kernels on src are finished, in case of unified memory as on AMD MI300A.
    PetscCall(PetscMemcpy(dst, src, n));
  } else {
    if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
      deviceBuffer_t    dbuf(static_cast(dst), n);
      HostConstBuffer_t sbuf(static_cast(src), n);
      PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
      PetscCall(PetscLogCpuToGpu(n));
    } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
      HostBuffer_t        dbuf(static_cast(dst), n);
      deviceConstBuffer_t sbuf(static_cast(src), n);
      PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
      PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
      PetscCall(PetscLogGpuToCpu(n));
    } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2D
      deviceBuffer_t      dbuf(static_cast(dst), n);
      deviceConstBuffer_t sbuf(static_cast(src), n);
      PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Allocate size bytes: host memory via PetscMalloc, device memory via Kokkos::kokkos_malloc (initializing Kokkos
   first if needed). Any other memory type is an error. */
PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
{
  PetscFunctionBegin;
  if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
  else if (PetscMemTypeDevice(mtype)) {
    if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
    PetscCallCXX(*ptr = Kokkos::kokkos_malloc(size));
  } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Free memory obtained from PetscSFMalloc_Kokkos with the same memory type */
PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
{
  PetscFunctionBegin;
  if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
  else if (PetscMemTypeDevice(mtype)) {
    PetscCallCXX(Kokkos::kokkos_free(ptr));
  } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Destructor when the link uses MPI for communication: frees the device root and leaf buffers
   for every scope from PETSCSF_LOCAL through PETSCSF_REMOTE */
static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
{
  PetscFunctionBegin;
  for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
    PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
    PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on device.

  Runs once per link (guarded by link->deviceinited). The unit datatype is matched in this order:
  MPI_2INT, MPIU_2INT, MPIU_REAL, MPIU_INT (only when sizeof(PetscInt) == sizeof(llint)), MPI_INT,
  MPI_SIGNED_CHAR, MPI_UNSIGNED_CHAR, MPIU_COMPLEX (if available), and otherwise a dumb type chosen by byte size.
  Without PETSC_HAVE_DEVICE, the if/else chains pick the largest BS in {8, 4, 2, 1} that matches the unit's
  element count, with EQ = 1 on an exact match and 0 on a multiple. With PETSC_HAVE_DEVICE, only the
  final, unconditional instantiation of each chain is compiled.
*/
PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
{
  PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
  PetscBool is2Int, is2PetscInt;
#if defined(PETSC_HAVE_COMPLEX)
  PetscInt nPetscComplex = 0;
#endif

  PetscFunctionBegin;
  if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscKokkosInitializeCheck());
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
  /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
#if defined(PETSC_HAVE_COMPLEX)
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
#endif
  PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
  PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));

  if (is2Int) {
    PackInit_PairType>(link);
  } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
    PackInit_PairType>(link);
  } else if (nPetscReal) {
#if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
    if (nPetscReal == 8) PackInit_RealType(link);
    else if (nPetscReal % 8 == 0) PackInit_RealType(link);
    else if (nPetscReal == 4) PackInit_RealType(link);
    else if (nPetscReal % 4 == 0) PackInit_RealType(link);
    else if (nPetscReal == 2) PackInit_RealType(link);
    else if (nPetscReal % 2 == 0) PackInit_RealType(link);
    else if (nPetscReal == 1) PackInit_RealType(link);
    else if (nPetscReal % 1 == 0)
#endif
      PackInit_RealType(link);
  } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nPetscInt == 8) PackInit_IntegerType(link);
    else if (nPetscInt % 8 == 0) PackInit_IntegerType(link);
    else if (nPetscInt == 4) PackInit_IntegerType(link);
    else if (nPetscInt % 4 == 0) PackInit_IntegerType(link);
    else if (nPetscInt == 2) PackInit_IntegerType(link);
    else if (nPetscInt % 2 == 0) PackInit_IntegerType(link);
    else if (nPetscInt == 1) PackInit_IntegerType(link);
    else if (nPetscInt % 1 == 0)
#endif
      PackInit_IntegerType(link);
  } else if (nInt) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nInt == 8) PackInit_IntegerType(link);
    else if (nInt % 8 == 0) PackInit_IntegerType(link);
    else if (nInt == 4) PackInit_IntegerType(link);
    else if (nInt % 4 == 0) PackInit_IntegerType(link);
    else if (nInt == 2) PackInit_IntegerType(link);
    else if (nInt % 2 == 0) PackInit_IntegerType(link);
    else if (nInt == 1) PackInit_IntegerType(link);
    else if (nInt % 1 == 0)
#endif
      PackInit_IntegerType(link);
  } else if (nSignedChar) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nSignedChar == 8) PackInit_IntegerType(link);
    else if (nSignedChar % 8 == 0) PackInit_IntegerType(link);
    else if (nSignedChar == 4) PackInit_IntegerType(link);
    else if (nSignedChar % 4 == 0) PackInit_IntegerType(link);
    else if (nSignedChar == 2) PackInit_IntegerType(link);
    else if (nSignedChar % 2 == 0) PackInit_IntegerType(link);
    else if (nSignedChar == 1) PackInit_IntegerType(link);
    else if (nSignedChar % 1 == 0)
#endif
      PackInit_IntegerType(link);
  } else if (nUnsignedChar) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nUnsignedChar == 8) PackInit_IntegerType(link);
    else if (nUnsignedChar % 8 == 0) PackInit_IntegerType(link);
    else if (nUnsignedChar == 4) PackInit_IntegerType(link);
    else if (nUnsignedChar % 4 == 0) PackInit_IntegerType(link);
    else if (nUnsignedChar == 2) PackInit_IntegerType(link);
    else if (nUnsignedChar % 2 == 0) PackInit_IntegerType(link);
    else if (nUnsignedChar == 1) PackInit_IntegerType(link);
    else if (nUnsignedChar % 1 == 0)
#endif
      PackInit_IntegerType(link);
#if defined(PETSC_HAVE_COMPLEX)
  } else if (nPetscComplex) {
  #if !defined(PETSC_HAVE_DEVICE)
    if (nPetscComplex == 8) PackInit_ComplexType, 8, 1>(link);
    else if (nPetscComplex % 8 == 0) PackInit_ComplexType, 8, 0>(link);
    else if (nPetscComplex == 4) PackInit_ComplexType, 4, 1>(link);
    else if (nPetscComplex % 4 == 0) PackInit_ComplexType, 4, 0>(link);
    else if (nPetscComplex == 2) PackInit_ComplexType, 2, 1>(link);
    else if (nPetscComplex % 2 == 0) PackInit_ComplexType, 2, 0>(link);
    else if (nPetscComplex == 1) PackInit_ComplexType, 1, 1>(link);
    else if (nPetscComplex % 1 == 0)
  #endif
      PackInit_ComplexType, 1, 0>(link);
#endif
  } else {
    /* Not a recognized primitive: treat the unit as opaque bytes (Pack/Insert only) */
    MPI_Aint nbyte;
    PetscCall(PetscSFGetDatatypeSize_Internal(PETSC_COMM_SELF, unit, &nbyte));
    if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
      /* NOTE(review): assuming sizeof(int) == 4, the nbyte == 4 and nbyte % 4 == 0 branches cannot be taken here */
#if !defined(PETSC_HAVE_DEVICE)
      if (nbyte == 4) PackInit_DumbType(link);
      else if (nbyte % 4 == 0) PackInit_DumbType(link);
      else if (nbyte == 2) PackInit_DumbType(link);
      else if (nbyte % 2 == 0) PackInit_DumbType(link);
      else if (nbyte == 1) PackInit_DumbType(link);
      else if (nbyte % 1 == 0)
#endif
        PackInit_DumbType(link);
    } else {
      /* Size is a multiple of int: reuse nInt as the unit size counted in ints */
      PetscCall(PetscIntCast(nbyte / sizeof(int), &nInt));
#if !defined(PETSC_HAVE_DEVICE)
      if (nInt == 8) PackInit_DumbType(link);
      else if (nInt % 8 == 0) PackInit_DumbType(link);
      else if (nInt == 4) PackInit_DumbType(link);
      else if (nInt % 4 == 0) PackInit_DumbType(link);
      else if (nInt == 2) PackInit_DumbType(link);
      else if (nInt % 2 == 0) PackInit_DumbType(link);
      else if (nInt == 1) PackInit_DumbType(link);
      else if (nInt % 1 == 0)
#endif
        PackInit_DumbType(link);
    }
  }

  /* Device utilities shared by all datatypes */
  link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
  link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
  link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
  link->Destroy      = PetscSFLinkDestroy_Kokkos;
  link->deviceinited = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}