1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP 2 #define PETSC_CUPM_THRUST_UTILITY_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupminterface.hpp> 6 7 #if defined(__cplusplus) 8 #include <thrust/device_ptr.h> 9 #include <thrust/transform.h> 10 11 namespace Petsc { 12 13 namespace device { 14 15 namespace cupm { 16 17 namespace impl { 18 19 #if PetscDefined(USING_NVCC) 20 #if !defined(THRUST_VERSION) 21 #error "THRUST_VERSION not defined!" 22 #endif 23 #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600) 24 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__) 25 #else 26 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__) 27 #endif 28 #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync 29 #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__) 30 #else 31 #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__) 32 #endif 33 34 namespace detail { 35 36 struct PetscLogGpuTimer { 37 PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); } 38 ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); } 39 }; 40 41 struct private_tag { }; 42 43 } // namespace detail 44 45 #define THRUST_CALL(...) \ 46 [&] { \ 47 const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \ 48 return thrust_call_par_on(__VA_ARGS__); \ 49 }() 50 51 #define PetscCallThrust(...) \ 52 do { \ 53 try { \ 54 __VA_ARGS__; \ 55 } catch (const thrust::system_error &ex) { SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); } \ 56 } while (0) 57 58 template <typename T, typename BinaryOperator> 59 struct shift_operator { 60 const T *const s; 61 const BinaryOperator op; 62 63 PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s)) 64 }; 65 66 template <typename T, typename BinaryOperator> 67 static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)}); 68 69 #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno); 70 71 // actual implementation that calls thrust, 2 argument version 72 template <DeviceType DT, typename FunctorType, typename T> 73 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) { 74 const auto xptr = thrust::device_pointer_cast(xinout); 75 const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr; 76 77 PetscFunctionBegin; 78 PetscValidDevicePointer(xinout, 4); 79 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor))); 80 PetscFunctionReturn(0); 81 } 82 83 // actual implementation that calls thrust, 3 argument version 84 template <DeviceType DT, typename FunctorType, typename T> 85 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) { 86 const auto xptr = thrust::device_pointer_cast(xin); 87 88 PetscFunctionBegin; 89 PetscValidDevicePointer(xin, 4); 90 PetscValidDevicePointer(yin, 5); 91 PetscValidDevicePointer(zin, 6); 92 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor))); 93 PetscFunctionReturn(0); 94 } 95 96 // one last intermediate function to check n, and log flops for everything 97 template <DeviceType DT, typename F, typename... T> 98 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) { 99 PetscFunctionBegin; 100 PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n); 101 if (PetscLikely(n)) { 102 PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...)); 103 PetscCall(PetscLogGpuFlops(n)); 104 } 105 PetscFunctionReturn(0); 106 } 107 108 // serves as setup to the real implementation above 109 template <DeviceType T, typename F, typename... Args> 110 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) { 111 typename Interface<T>::cupmStream_t stream; 112 113 PetscFunctionBegin; 114 static_assert(sizeof...(Args) <= 3, ""); 115 PetscValidDeviceContext(dctx, 1); 116 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 117 PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...)); 118 PetscFunctionReturn(0); 119 } 120 121 #define PetscCallCUPM_(...) \ 122 do { \ 123 using interface = Interface<DT>; \ 124 using cupmError_t = typename interface::cupmError_t; \ 125 const auto cupmName = []() { return interface::cupmName(); }; \ 126 const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \ 127 const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \ 128 const auto cupmSuccess = interface::cupmSuccess; \ 129 PetscCallCUPM(__VA_ARGS__); \ 130 } while (0) 131 132 template <DeviceType DT, typename T> 133 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) { 134 PetscFunctionBegin; 135 PetscValidPointer(val, 4); 136 if (n) { 137 const auto size = n * sizeof(T); 138 139 PetscValidDevicePointer(ptr, 3); 140 if (*val == T{0}) { 141 PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream)); 142 } else { 143 const auto xptr = thrust::device_pointer_cast(ptr); 144 145 PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val)); 146 if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) { 147 PetscCall(PetscLogCpuToGpuScalar(size)); 148 } else { 149 PetscCall(PetscLogCpuToGpu(size)); 150 } 151 } 152 } 153 PetscFunctionReturn(0); 154 } 155 156 #undef PetscCallCUPM_ 157 #undef PetscValidDevicePointer 158 159 template <DeviceType DT, typename T> 160 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) { 161 typename Interface<DT>::cupmStream_t stream; 162 163 PetscFunctionBegin; 164 PetscValidDeviceContext(dctx, 1); 165 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 166 PetscCall(ThrustSet(stream, n, ptr, val)); 167 PetscFunctionReturn(0); 168 } 169 170 } // namespace impl 171 172 } // namespace cupm 173 174 } // namespace device 175 176 } // namespace Petsc 177 178 #endif // __cplusplus 179 180 #endif // PETSC_CUPM_THRUST_UTILITY_HPP 181