10e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPM_THRUST_UTILITY_HPP 20e6b6b59SJacob Faibussowitsch #define PETSC_CUPM_THRUST_UTILITY_HPP 30e6b6b59SJacob Faibussowitsch 40e6b6b59SJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 50e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp> 60e6b6b59SJacob Faibussowitsch 70e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 80e6b6b59SJacob Faibussowitsch #include <thrust/device_ptr.h> 90e6b6b59SJacob Faibussowitsch #include <thrust/transform.h> 100e6b6b59SJacob Faibussowitsch 11*d71ae5a4SJacob Faibussowitsch namespace Petsc 12*d71ae5a4SJacob Faibussowitsch { 130e6b6b59SJacob Faibussowitsch 14*d71ae5a4SJacob Faibussowitsch namespace device 15*d71ae5a4SJacob Faibussowitsch { 160e6b6b59SJacob Faibussowitsch 17*d71ae5a4SJacob Faibussowitsch namespace cupm 18*d71ae5a4SJacob Faibussowitsch { 190e6b6b59SJacob Faibussowitsch 20*d71ae5a4SJacob Faibussowitsch namespace impl 21*d71ae5a4SJacob Faibussowitsch { 220e6b6b59SJacob Faibussowitsch 230e6b6b59SJacob Faibussowitsch #if PetscDefined(USING_NVCC) 240e6b6b59SJacob Faibussowitsch #if !defined(THRUST_VERSION) 250e6b6b59SJacob Faibussowitsch #error "THRUST_VERSION not defined!" 260e6b6b59SJacob Faibussowitsch #endif 270e6b6b59SJacob Faibussowitsch #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600) 280e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__) 290e6b6b59SJacob Faibussowitsch #else 300e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__) 310e6b6b59SJacob Faibussowitsch #endif 320e6b6b59SJacob Faibussowitsch #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync 330e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__) 340e6b6b59SJacob Faibussowitsch #else 350e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__) 360e6b6b59SJacob Faibussowitsch #endif 370e6b6b59SJacob Faibussowitsch 38*d71ae5a4SJacob Faibussowitsch namespace detail 39*d71ae5a4SJacob Faibussowitsch { 400e6b6b59SJacob Faibussowitsch 410e6b6b59SJacob Faibussowitsch struct PetscLogGpuTimer { 420e6b6b59SJacob Faibussowitsch PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); } 430e6b6b59SJacob Faibussowitsch ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); } 440e6b6b59SJacob Faibussowitsch }; 450e6b6b59SJacob Faibussowitsch 460e6b6b59SJacob Faibussowitsch struct private_tag { }; 470e6b6b59SJacob Faibussowitsch 480e6b6b59SJacob Faibussowitsch } // namespace detail 490e6b6b59SJacob Faibussowitsch 500e6b6b59SJacob Faibussowitsch #define THRUST_CALL(...) \ 510e6b6b59SJacob Faibussowitsch [&] { \ 520e6b6b59SJacob Faibussowitsch const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \ 530e6b6b59SJacob Faibussowitsch return thrust_call_par_on(__VA_ARGS__); \ 540e6b6b59SJacob Faibussowitsch }() 550e6b6b59SJacob Faibussowitsch 560e6b6b59SJacob Faibussowitsch #define PetscCallThrust(...) \ 570e6b6b59SJacob Faibussowitsch do { \ 580e6b6b59SJacob Faibussowitsch try { \ 590e6b6b59SJacob Faibussowitsch __VA_ARGS__; \ 60*d71ae5a4SJacob Faibussowitsch } catch (const thrust::system_error &ex) { \ 61*d71ae5a4SJacob Faibussowitsch SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \ 62*d71ae5a4SJacob Faibussowitsch } \ 630e6b6b59SJacob Faibussowitsch } while (0) 640e6b6b59SJacob Faibussowitsch 650e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator> 660e6b6b59SJacob Faibussowitsch struct shift_operator { 670e6b6b59SJacob Faibussowitsch const T *const s; 680e6b6b59SJacob Faibussowitsch const BinaryOperator op; 690e6b6b59SJacob Faibussowitsch 700e6b6b59SJacob Faibussowitsch PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s)) 710e6b6b59SJacob Faibussowitsch }; 720e6b6b59SJacob Faibussowitsch 730e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator> 740e6b6b59SJacob Faibussowitsch static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)}); 750e6b6b59SJacob Faibussowitsch 760e6b6b59SJacob Faibussowitsch #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno); 770e6b6b59SJacob Faibussowitsch 780e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 2 argument version 790e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T> 80*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) 81*d71ae5a4SJacob Faibussowitsch { 820e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(xinout); 830e6b6b59SJacob Faibussowitsch const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr; 840e6b6b59SJacob Faibussowitsch 850e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 860e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(xinout, 4); 870e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor))); 880e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 890e6b6b59SJacob Faibussowitsch } 900e6b6b59SJacob Faibussowitsch 910e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 3 argument version 920e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T> 93*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) 94*d71ae5a4SJacob Faibussowitsch { 950e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(xin); 960e6b6b59SJacob Faibussowitsch 970e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 980e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(xin, 4); 990e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(yin, 5); 1000e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(zin, 6); 1010e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor))); 1020e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1030e6b6b59SJacob Faibussowitsch } 1040e6b6b59SJacob Faibussowitsch 1050e6b6b59SJacob Faibussowitsch // one last intermediate function to check n, and log flops for everything 1060e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename F, typename... T> 107*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) 108*d71ae5a4SJacob Faibussowitsch { 1090e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1100e6b6b59SJacob Faibussowitsch PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n); 1110e6b6b59SJacob Faibussowitsch if (PetscLikely(n)) { 1120e6b6b59SJacob Faibussowitsch PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...)); 1130e6b6b59SJacob Faibussowitsch PetscCall(PetscLogGpuFlops(n)); 1140e6b6b59SJacob Faibussowitsch } 1150e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1160e6b6b59SJacob Faibussowitsch } 1170e6b6b59SJacob Faibussowitsch 1180e6b6b59SJacob Faibussowitsch // serves as setup to the real implementation above 1190e6b6b59SJacob Faibussowitsch template <DeviceType T, typename F, typename... Args> 120*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) 121*d71ae5a4SJacob Faibussowitsch { 1220e6b6b59SJacob Faibussowitsch typename Interface<T>::cupmStream_t stream; 1230e6b6b59SJacob Faibussowitsch 1240e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1250e6b6b59SJacob Faibussowitsch static_assert(sizeof...(Args) <= 3, ""); 1260e6b6b59SJacob Faibussowitsch PetscValidDeviceContext(dctx, 1); 1270e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 1280e6b6b59SJacob Faibussowitsch PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...)); 1290e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1300e6b6b59SJacob Faibussowitsch } 1310e6b6b59SJacob Faibussowitsch 1320e6b6b59SJacob Faibussowitsch #define PetscCallCUPM_(...) \ 1330e6b6b59SJacob Faibussowitsch do { \ 1340e6b6b59SJacob Faibussowitsch using interface = Interface<DT>; \ 1350e6b6b59SJacob Faibussowitsch using cupmError_t = typename interface::cupmError_t; \ 1360e6b6b59SJacob Faibussowitsch const auto cupmName = []() { return interface::cupmName(); }; \ 1370e6b6b59SJacob Faibussowitsch const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \ 1380e6b6b59SJacob Faibussowitsch const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \ 1390e6b6b59SJacob Faibussowitsch const auto cupmSuccess = interface::cupmSuccess; \ 1400e6b6b59SJacob Faibussowitsch PetscCallCUPM(__VA_ARGS__); \ 1410e6b6b59SJacob Faibussowitsch } while (0) 1420e6b6b59SJacob Faibussowitsch 1430e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T> 144*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) 145*d71ae5a4SJacob Faibussowitsch { 1460e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1470e6b6b59SJacob Faibussowitsch PetscValidPointer(val, 4); 1480e6b6b59SJacob Faibussowitsch if (n) { 1490e6b6b59SJacob Faibussowitsch const auto size = n * sizeof(T); 1500e6b6b59SJacob Faibussowitsch 1510e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(ptr, 3); 1520e6b6b59SJacob Faibussowitsch if (*val == T{0}) { 1530e6b6b59SJacob Faibussowitsch PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream)); 1540e6b6b59SJacob Faibussowitsch } else { 1550e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(ptr); 1560e6b6b59SJacob Faibussowitsch 1570e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val)); 1580e6b6b59SJacob Faibussowitsch if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) { 1590e6b6b59SJacob Faibussowitsch PetscCall(PetscLogCpuToGpuScalar(size)); 1600e6b6b59SJacob Faibussowitsch } else { 1610e6b6b59SJacob Faibussowitsch PetscCall(PetscLogCpuToGpu(size)); 1620e6b6b59SJacob Faibussowitsch } 1630e6b6b59SJacob Faibussowitsch } 1640e6b6b59SJacob Faibussowitsch } 1650e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1660e6b6b59SJacob Faibussowitsch } 1670e6b6b59SJacob Faibussowitsch 1680e6b6b59SJacob Faibussowitsch #undef PetscCallCUPM_ 1690e6b6b59SJacob Faibussowitsch #undef PetscValidDevicePointer 1700e6b6b59SJacob Faibussowitsch 1710e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T> 172*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) 173*d71ae5a4SJacob Faibussowitsch { 1740e6b6b59SJacob Faibussowitsch typename Interface<DT>::cupmStream_t stream; 1750e6b6b59SJacob Faibussowitsch 1760e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1770e6b6b59SJacob Faibussowitsch PetscValidDeviceContext(dctx, 1); 1780e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 1790e6b6b59SJacob Faibussowitsch PetscCall(ThrustSet(stream, n, ptr, val)); 1800e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 1810e6b6b59SJacob Faibussowitsch } 1820e6b6b59SJacob Faibussowitsch 1830e6b6b59SJacob Faibussowitsch } // namespace impl 1840e6b6b59SJacob Faibussowitsch 1850e6b6b59SJacob Faibussowitsch } // namespace cupm 1860e6b6b59SJacob Faibussowitsch 1870e6b6b59SJacob Faibussowitsch } // namespace device 1880e6b6b59SJacob Faibussowitsch 1890e6b6b59SJacob Faibussowitsch } // namespace Petsc 1900e6b6b59SJacob Faibussowitsch 1910e6b6b59SJacob Faibussowitsch #endif // __cplusplus 1920e6b6b59SJacob Faibussowitsch 1930e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPM_THRUST_UTILITY_HPP 194