1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP 2 #define PETSC_CUPM_THRUST_UTILITY_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupminterface.hpp> 6 7 #if defined(__cplusplus) 8 #include <thrust/device_ptr.h> 9 #include <thrust/transform.h> 10 11 namespace Petsc 12 { 13 14 namespace device 15 { 16 17 namespace cupm 18 { 19 20 namespace impl 21 { 22 23 #if PetscDefined(USING_NVCC) 24 #if !defined(THRUST_VERSION) 25 #error "THRUST_VERSION not defined!" 26 #endif 27 #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600) 28 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__) 29 #else 30 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__) 31 #endif 32 #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync 33 #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__) 34 #else 35 #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__) 36 #endif 37 38 namespace detail 39 { 40 41 struct PetscLogGpuTimer { 42 PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); } 43 ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); } 44 }; 45 46 struct private_tag { }; 47 48 } // namespace detail 49 50 #define THRUST_CALL(...) \ 51 [&] { \ 52 const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \ 53 return thrust_call_par_on(__VA_ARGS__); \ 54 }() 55 56 #define PetscCallThrust(...) \ 57 do { \ 58 try { \ 59 __VA_ARGS__; \ 60 } catch (const thrust::system_error &ex) { \ 61 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \ 62 } \ 63 } while (0) 64 65 template <typename T, typename BinaryOperator> 66 struct shift_operator { 67 const T *const s; 68 const BinaryOperator op; 69 70 PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s)) 71 }; 72 73 template <typename T, typename BinaryOperator> 74 static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)}); 75 76 #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno); 77 78 // actual implementation that calls thrust, 2 argument version 79 template <DeviceType DT, typename FunctorType, typename T> 80 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) 81 { 82 const auto xptr = thrust::device_pointer_cast(xinout); 83 const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr; 84 85 PetscFunctionBegin; 86 PetscValidDevicePointer(xinout, 4); 87 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor))); 88 PetscFunctionReturn(0); 89 } 90 91 // actual implementation that calls thrust, 3 argument version 92 template <DeviceType DT, typename FunctorType, typename T> 93 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) 94 { 95 const auto xptr = thrust::device_pointer_cast(xin); 96 97 PetscFunctionBegin; 98 PetscValidDevicePointer(xin, 4); 99 PetscValidDevicePointer(yin, 5); 100 PetscValidDevicePointer(zin, 6); 101 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor))); 102 PetscFunctionReturn(0); 103 } 104 105 // one last intermediate function to check n, and log flops for everything 106 template <DeviceType DT, typename F, typename... T> 107 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) 108 { 109 PetscFunctionBegin; 110 PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n); 111 if (PetscLikely(n)) { 112 PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...)); 113 PetscCall(PetscLogGpuFlops(n)); 114 } 115 PetscFunctionReturn(0); 116 } 117 118 // serves as setup to the real implementation above 119 template <DeviceType T, typename F, typename... Args> 120 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) 121 { 122 typename Interface<T>::cupmStream_t stream; 123 124 PetscFunctionBegin; 125 static_assert(sizeof...(Args) <= 3, ""); 126 PetscValidDeviceContext(dctx, 1); 127 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 128 PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...)); 129 PetscFunctionReturn(0); 130 } 131 132 #define PetscCallCUPM_(...) \ 133 do { \ 134 using interface = Interface<DT>; \ 135 using cupmError_t = typename interface::cupmError_t; \ 136 const auto cupmName = []() { return interface::cupmName(); }; \ 137 const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \ 138 const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \ 139 const auto cupmSuccess = interface::cupmSuccess; \ 140 PetscCallCUPM(__VA_ARGS__); \ 141 } while (0) 142 143 template <DeviceType DT, typename T> 144 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) 145 { 146 PetscFunctionBegin; 147 PetscValidPointer(val, 4); 148 if (n) { 149 const auto size = n * sizeof(T); 150 151 PetscValidDevicePointer(ptr, 3); 152 if (*val == T{0}) { 153 PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream)); 154 } else { 155 const auto xptr = thrust::device_pointer_cast(ptr); 156 157 PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val)); 158 if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) { 159 PetscCall(PetscLogCpuToGpuScalar(size)); 160 } else { 161 PetscCall(PetscLogCpuToGpu(size)); 162 } 163 } 164 } 165 PetscFunctionReturn(0); 166 } 167 168 #undef PetscCallCUPM_ 169 #undef PetscValidDevicePointer 170 171 template <DeviceType DT, typename T> 172 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) 173 { 174 typename Interface<DT>::cupmStream_t stream; 175 176 PetscFunctionBegin; 177 PetscValidDeviceContext(dctx, 1); 178 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 179 PetscCall(ThrustSet(stream, n, ptr, val)); 180 PetscFunctionReturn(0); 181 } 182 183 } // namespace impl 184 185 } // namespace cupm 186 187 } // namespace device 188 189 } // namespace Petsc 190 191 #endif // __cplusplus 192 193 #endif // PETSC_CUPM_THRUST_UTILITY_HPP 194