1*0e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPM_THRUST_UTILITY_HPP 2*0e6b6b59SJacob Faibussowitsch #define PETSC_CUPM_THRUST_UTILITY_HPP 3*0e6b6b59SJacob Faibussowitsch 4*0e6b6b59SJacob Faibussowitsch #include <petsc/private/deviceimpl.h> 5*0e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp> 6*0e6b6b59SJacob Faibussowitsch 7*0e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 8*0e6b6b59SJacob Faibussowitsch #include <thrust/device_ptr.h> 9*0e6b6b59SJacob Faibussowitsch #include <thrust/transform.h> 10*0e6b6b59SJacob Faibussowitsch 11*0e6b6b59SJacob Faibussowitsch namespace Petsc { 12*0e6b6b59SJacob Faibussowitsch 13*0e6b6b59SJacob Faibussowitsch namespace device { 14*0e6b6b59SJacob Faibussowitsch 15*0e6b6b59SJacob Faibussowitsch namespace cupm { 16*0e6b6b59SJacob Faibussowitsch 17*0e6b6b59SJacob Faibussowitsch namespace impl { 18*0e6b6b59SJacob Faibussowitsch 19*0e6b6b59SJacob Faibussowitsch #if PetscDefined(USING_NVCC) 20*0e6b6b59SJacob Faibussowitsch #if !defined(THRUST_VERSION) 21*0e6b6b59SJacob Faibussowitsch #error "THRUST_VERSION not defined!" 22*0e6b6b59SJacob Faibussowitsch #endif 23*0e6b6b59SJacob Faibussowitsch #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600) 24*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__) 25*0e6b6b59SJacob Faibussowitsch #else 26*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__) 27*0e6b6b59SJacob Faibussowitsch #endif 28*0e6b6b59SJacob Faibussowitsch #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync 29*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__) 30*0e6b6b59SJacob Faibussowitsch #else 31*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__) 32*0e6b6b59SJacob Faibussowitsch #endif 33*0e6b6b59SJacob Faibussowitsch 34*0e6b6b59SJacob Faibussowitsch namespace detail { 35*0e6b6b59SJacob Faibussowitsch 36*0e6b6b59SJacob Faibussowitsch struct PetscLogGpuTimer { 37*0e6b6b59SJacob Faibussowitsch PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); } 38*0e6b6b59SJacob Faibussowitsch ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); } 39*0e6b6b59SJacob Faibussowitsch }; 40*0e6b6b59SJacob Faibussowitsch 41*0e6b6b59SJacob Faibussowitsch struct private_tag { }; 42*0e6b6b59SJacob Faibussowitsch 43*0e6b6b59SJacob Faibussowitsch } // namespace detail 44*0e6b6b59SJacob Faibussowitsch 45*0e6b6b59SJacob Faibussowitsch #define THRUST_CALL(...) \ 46*0e6b6b59SJacob Faibussowitsch [&] { \ 47*0e6b6b59SJacob Faibussowitsch const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \ 48*0e6b6b59SJacob Faibussowitsch return thrust_call_par_on(__VA_ARGS__); \ 49*0e6b6b59SJacob Faibussowitsch }() 50*0e6b6b59SJacob Faibussowitsch 51*0e6b6b59SJacob Faibussowitsch #define PetscCallThrust(...) \ 52*0e6b6b59SJacob Faibussowitsch do { \ 53*0e6b6b59SJacob Faibussowitsch try { \ 54*0e6b6b59SJacob Faibussowitsch __VA_ARGS__; \ 55*0e6b6b59SJacob Faibussowitsch } catch (const thrust::system_error &ex) { SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); } \ 56*0e6b6b59SJacob Faibussowitsch } while (0) 57*0e6b6b59SJacob Faibussowitsch 58*0e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator> 59*0e6b6b59SJacob Faibussowitsch struct shift_operator { 60*0e6b6b59SJacob Faibussowitsch const T *const s; 61*0e6b6b59SJacob Faibussowitsch const BinaryOperator op; 62*0e6b6b59SJacob Faibussowitsch 63*0e6b6b59SJacob Faibussowitsch PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s)) 64*0e6b6b59SJacob Faibussowitsch }; 65*0e6b6b59SJacob Faibussowitsch 66*0e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator> 67*0e6b6b59SJacob Faibussowitsch static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)}); 68*0e6b6b59SJacob Faibussowitsch 69*0e6b6b59SJacob Faibussowitsch #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno); 70*0e6b6b59SJacob Faibussowitsch 71*0e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 2 argument version 72*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T> 73*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) { 74*0e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(xinout); 75*0e6b6b59SJacob Faibussowitsch const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr; 76*0e6b6b59SJacob Faibussowitsch 77*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 78*0e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(xinout, 4); 79*0e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor))); 80*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 81*0e6b6b59SJacob Faibussowitsch } 82*0e6b6b59SJacob Faibussowitsch 83*0e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 3 argument version 84*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T> 85*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) { 86*0e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(xin); 87*0e6b6b59SJacob Faibussowitsch 88*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 89*0e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(xin, 4); 90*0e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(yin, 5); 91*0e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(zin, 6); 92*0e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor))); 93*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 94*0e6b6b59SJacob Faibussowitsch } 95*0e6b6b59SJacob Faibussowitsch 96*0e6b6b59SJacob Faibussowitsch // one last intermediate function to check n, and log flops for everything 97*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename F, typename... T> 98*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) { 99*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 100*0e6b6b59SJacob Faibussowitsch PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n); 101*0e6b6b59SJacob Faibussowitsch if (PetscLikely(n)) { 102*0e6b6b59SJacob Faibussowitsch PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...)); 103*0e6b6b59SJacob Faibussowitsch PetscCall(PetscLogGpuFlops(n)); 104*0e6b6b59SJacob Faibussowitsch } 105*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 106*0e6b6b59SJacob Faibussowitsch } 107*0e6b6b59SJacob Faibussowitsch 108*0e6b6b59SJacob Faibussowitsch // serves as setup to the real implementation above 109*0e6b6b59SJacob Faibussowitsch template <DeviceType T, typename F, typename... Args> 110*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) { 111*0e6b6b59SJacob Faibussowitsch typename Interface<T>::cupmStream_t stream; 112*0e6b6b59SJacob Faibussowitsch 113*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 114*0e6b6b59SJacob Faibussowitsch static_assert(sizeof...(Args) <= 3, ""); 115*0e6b6b59SJacob Faibussowitsch PetscValidDeviceContext(dctx, 1); 116*0e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 117*0e6b6b59SJacob Faibussowitsch PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...)); 118*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 119*0e6b6b59SJacob Faibussowitsch } 120*0e6b6b59SJacob Faibussowitsch 121*0e6b6b59SJacob Faibussowitsch #define PetscCallCUPM_(...) \ 122*0e6b6b59SJacob Faibussowitsch do { \ 123*0e6b6b59SJacob Faibussowitsch using interface = Interface<DT>; \ 124*0e6b6b59SJacob Faibussowitsch using cupmError_t = typename interface::cupmError_t; \ 125*0e6b6b59SJacob Faibussowitsch const auto cupmName = []() { return interface::cupmName(); }; \ 126*0e6b6b59SJacob Faibussowitsch const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \ 127*0e6b6b59SJacob Faibussowitsch const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \ 128*0e6b6b59SJacob Faibussowitsch const auto cupmSuccess = interface::cupmSuccess; \ 129*0e6b6b59SJacob Faibussowitsch PetscCallCUPM(__VA_ARGS__); \ 130*0e6b6b59SJacob Faibussowitsch } while (0) 131*0e6b6b59SJacob Faibussowitsch 132*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T> 133*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) { 134*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 135*0e6b6b59SJacob Faibussowitsch PetscValidPointer(val, 4); 136*0e6b6b59SJacob Faibussowitsch if (n) { 137*0e6b6b59SJacob Faibussowitsch const auto size = n * sizeof(T); 138*0e6b6b59SJacob Faibussowitsch 139*0e6b6b59SJacob Faibussowitsch PetscValidDevicePointer(ptr, 3); 140*0e6b6b59SJacob Faibussowitsch if (*val == T{0}) { 141*0e6b6b59SJacob Faibussowitsch PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream)); 142*0e6b6b59SJacob Faibussowitsch } else { 143*0e6b6b59SJacob Faibussowitsch const auto xptr = thrust::device_pointer_cast(ptr); 144*0e6b6b59SJacob Faibussowitsch 145*0e6b6b59SJacob Faibussowitsch PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val)); 146*0e6b6b59SJacob Faibussowitsch if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) { 147*0e6b6b59SJacob Faibussowitsch PetscCall(PetscLogCpuToGpuScalar(size)); 148*0e6b6b59SJacob Faibussowitsch } else { 149*0e6b6b59SJacob Faibussowitsch PetscCall(PetscLogCpuToGpu(size)); 150*0e6b6b59SJacob Faibussowitsch } 151*0e6b6b59SJacob Faibussowitsch } 152*0e6b6b59SJacob Faibussowitsch } 153*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 154*0e6b6b59SJacob Faibussowitsch } 155*0e6b6b59SJacob Faibussowitsch 156*0e6b6b59SJacob Faibussowitsch #undef PetscCallCUPM_ 157*0e6b6b59SJacob Faibussowitsch #undef PetscValidDevicePointer 158*0e6b6b59SJacob Faibussowitsch 159*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T> 160*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) { 161*0e6b6b59SJacob Faibussowitsch typename Interface<DT>::cupmStream_t stream; 162*0e6b6b59SJacob Faibussowitsch 163*0e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 164*0e6b6b59SJacob Faibussowitsch PetscValidDeviceContext(dctx, 1); 165*0e6b6b59SJacob Faibussowitsch PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 166*0e6b6b59SJacob Faibussowitsch PetscCall(ThrustSet(stream, n, ptr, val)); 167*0e6b6b59SJacob Faibussowitsch PetscFunctionReturn(0); 168*0e6b6b59SJacob Faibussowitsch } 169*0e6b6b59SJacob Faibussowitsch 170*0e6b6b59SJacob Faibussowitsch } // namespace impl 171*0e6b6b59SJacob Faibussowitsch 172*0e6b6b59SJacob Faibussowitsch } // namespace cupm 173*0e6b6b59SJacob Faibussowitsch 174*0e6b6b59SJacob Faibussowitsch } // namespace device 175*0e6b6b59SJacob Faibussowitsch 176*0e6b6b59SJacob Faibussowitsch } // namespace Petsc 177*0e6b6b59SJacob Faibussowitsch 178*0e6b6b59SJacob Faibussowitsch #endif // __cplusplus 179*0e6b6b59SJacob Faibussowitsch 180*0e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPM_THRUST_UTILITY_HPP 181