// NOTE(review): in this copy of the header every `#include` target and every
// template argument/parameter list (`template <...>`, `Interface<...>`,
// `std::forward<...>`, the first `std::is_same<...>` argument) appears to have
// been stripped. Restore them from the upstream PETSc source before compiling.
#ifndef PETSC_CUPM_THRUST_UTILITY_HPP
#define PETSC_CUPM_THRUST_UTILITY_HPP

#include
#include

#if defined(__cplusplus)
#include
#include

namespace Petsc {

namespace device {

namespace cupm {

namespace impl {

// thrust_call_par_on(func, s, ...) invokes `func` with the thrust execution
// policy of the active backend, bound to stream `s`:
//   - nvcc: `thrust::cuda::par_nosync.on(s)` in non-debug builds when
//     THRUST_VERSION >= 101600, otherwise `thrust::cuda::par.on(s)`;
//   - hcc:  `thrust::hip::par.on(s)`;
//   - otherwise: `func(...)` with no execution policy, so `s` is ignored.
#if PetscDefined(USING_NVCC)
  #if !defined(THRUST_VERSION)
    #error "THRUST_VERSION not defined!"
  #endif
  #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
    #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
  #else
    #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
  #endif
#elif PetscDefined(USING_HCC)
  // rocThrust has no par_nosync
  #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
#else
  #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
#endif

namespace detail {

// Scope guard for GPU timing: the constructor calls PetscLogGpuTimeBegin() and
// the destructor calls PetscLogGpuTimeEnd(). Both are noexcept, so a failure is
// handled by PetscCallAbort() on PETSC_COMM_SELF instead of being returned.
struct PetscLogGpuTimer {
  PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
  ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
};

// Empty tag passed as the first argument to select the thrust-calling
// ThrustApplyPointwise() overloads below (the intermediate overload passes
// `detail::private_tag{}` to reach them).
struct private_tag { };

} // namespace detail

// THRUST_CALL(func, stream, args...) forwards to thrust_call_par_on() inside an
// immediately-invoked lambda. The lambda keeps a detail::PetscLogGpuTimer alive
// for the duration of the call and yields whatever `func` returns.
#define THRUST_CALL(...) \
  [&] { \
    const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
    return thrust_call_par_on(__VA_ARGS__); \
  }()

// PetscCallThrust(expr) evaluates `expr`. A thrown thrust::system_error is
// reported through SETERRQ with PETSC_ERR_LIB and the exception's what() text.
// NOTE(review): only thrust::system_error is caught; any other exception type
// propagates out of the enclosing function unchanged.
#define PetscCallThrust(...) \
  do { \
    try { \
      __VA_ARGS__; \
    } catch (const thrust::system_error &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
    } \
  } while (0)

// PetscValidDevicePointer(ptr, argno) only asserts that `ptr` is non-null. It
// does not check which memory space `ptr` lives in. `argno` is used only in the
// error message. The macro is #undef'd after ThrustSet() below.
// NOTE(review): the expansion ends in `;`, so every use adds an empty statement.
#define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno);

// Actual implementation that calls thrust, 2-argument (unary functor) version.
//
// Applies `functor` to the `n` elements starting at `xinout` with
// thrust::transform on `stream`. Results go to `yin` when it is non-null and
// differs from `xinout`; otherwise `xinout` is overwritten in place. Only
// `xinout` is null-checked.
// NOTE(review): despite its name, `yin` is the output buffer here. The argument
// numbers in error messages appear to follow the public PetscDeviceContext
// overload (dctx = 1, functor = 2, n = 3, ...). Confirm.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
{
  const auto xptr   = thrust::device_pointer_cast(xinout);
  // write into yin only if it is a separate buffer, otherwise operate in place
  const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;

  PetscFunctionBegin;
  PetscValidDevicePointer(xinout, 4);
  PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward(functor)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Actual implementation that calls thrust, 3-argument (binary functor) version.
//
// Computes zin[i] = functor(xin[i], yin[i]) for i in [0, n) with the binary
// form of thrust::transform on `stream`. All three pointers are null-checked.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
{
  const auto xptr = thrust::device_pointer_cast(xin);

  PetscFunctionBegin;
  PetscValidDevicePointer(xin, 4);
  PetscValidDevicePointer(yin, 5);
  PetscValidDevicePointer(zin, 6);
  PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward(functor)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// One last intermediate function that checks n and logs flops for everything.
//
// Asserts that `n` is non-negative and does nothing when `n == 0`. Otherwise it
// forwards `rest` to the private_tag overload with the matching arity.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
{
  PetscFunctionBegin;
  PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
  if (PetscLikely(n)) {
    PetscCall(ThrustApplyPointwise(detail::private_tag{}, stream, std::forward(functor), n, std::forward(rest)...));
    // exactly n flops are logged, independent of what the functor computes
    PetscCall(PetscLogGpuFlops(n));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Public entry point that serves as setup for the implementations above.
//
// Validates `dctx`, fetches its stream handle and forwards to the stream
// overload. At most 3 trailing array arguments are accepted; this is enforced
// at compile time by the static_assert.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
{
  typename Interface::cupmStream_t stream;

  PetscFunctionBegin;
  static_assert(sizeof...(Args) <= 3, "");
  PetscValidDeviceContext(dctx, 1);
  PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
  PetscCall(ThrustApplyPointwise(stream, std::forward(functor), n, std::forward(rest)...));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscCallCUPM_(expr) wraps PetscCallCUPM(expr). It first defines local names
// bound to `Interface`: cupmError_t, cupmName, cupmGetErrorName,
// cupmGetErrorString and cupmSuccess. Presumably PetscCallCUPM refers to these
// names at the call site; confirm. The macro is #undef'd after ThrustSet() below.
#define PetscCallCUPM_(...) \
  do { \
    using interface                = Interface; \
    using cupmError_t              = typename interface::cupmError_t; \
    const auto cupmName            = []() { return interface::cupmName(); }; \
    const auto cupmGetErrorName    = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
    const auto cupmGetErrorString  = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
    const auto cupmSuccess         = interface::cupmSuccess; \
    PetscCallCUPM(__VA_ARGS__); \
  } while (0)

// Sets the `n` elements at device pointer `ptr` to `*val` on `stream`. `*val`
// is read on the host.
//
// A value that compares equal to T{0} is written with cupmMemsetAsync() of
// n * sizeof(T) zero bytes. Any other value uses thrust::fill and logs
// n * sizeof(T) bytes of CPU -> GPU traffic; it is logged as scalar traffic when
// the `std::is_same<..., PetscScalar>` check holds. `val` is always validated;
// `ptr` is validated only when n != 0.
// NOTE(review): unlike ThrustApplyPointwise(), this does not assert n >= 0, so
// a negative `n` is not rejected before `size` is computed.
// NOTE(review): the memset path writes all-zero bytes, so a floating-point
// `*val` of -0.0 (which compares equal to 0) is stored as +0.0. Confirm this is
// acceptable.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
{
  PetscFunctionBegin;
  PetscValidPointer(val, 4);
  if (n) {
    const auto size = n * sizeof(T);

    PetscValidDevicePointer(ptr, 3);
    if (*val == T{0}) {
      PetscCallCUPM_(Interface::cupmMemsetAsync(ptr, 0, size, stream));
    } else {
      const auto xptr = thrust::device_pointer_cast(ptr);

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
      if (std::is_same, PetscScalar>::value) {
        PetscCall(PetscLogCpuToGpuScalar(size));
      } else {
        PetscCall(PetscLogCpuToGpu(size));
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef PetscCallCUPM_
#undef PetscValidDevicePointer

// Public entry point: validates `dctx`, fetches its stream handle and forwards
// to the stream overload of ThrustSet() above.
template
PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
{
  typename Interface::cupmStream_t stream;

  PetscFunctionBegin;
  PetscValidDeviceContext(dctx, 1);
  PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
  PetscCall(ThrustSet(stream, n, ptr, val));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace impl

} // namespace cupm

} // namespace device

} // namespace Petsc

#endif // __cplusplus

#endif // PETSC_CUPM_THRUST_UTILITY_HPP