1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP 2 #define PETSC_CUPM_THRUST_UTILITY_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupminterface.hpp> 6 7 #if defined(__cplusplus) 8 #include <thrust/device_ptr.h> 9 #include <thrust/transform.h> 10 11 namespace Petsc 12 { 13 14 namespace device 15 { 16 17 namespace cupm 18 { 19 20 namespace impl 21 { 22 23 #if PetscDefined(USING_NVCC) 24 #if !defined(THRUST_VERSION) 25 #error "THRUST_VERSION not defined!" 26 #endif 27 #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600) 28 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__) 29 #else 30 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__) 31 #endif 32 #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync 33 #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__) 34 #else 35 #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__) 36 #endif 37 38 namespace detail 39 { 40 41 struct PetscLogGpuTimer { 42 PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); } 43 ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); } 44 }; 45 46 struct private_tag { }; 47 48 } // namespace detail 49 50 #define THRUST_CALL(...) \ 51 [&] { \ 52 const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \ 53 return thrust_call_par_on(__VA_ARGS__); \ 54 }() 55 56 #define PetscCallThrust(...) \ 57 do { \ 58 try { \ 59 __VA_ARGS__; \ 60 } catch (const thrust::system_error &ex) { \ 61 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \ 62 } \ 63 } while (0) 64 65 #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno); 66 67 // actual implementation that calls thrust, 2 argument version 68 template <DeviceType DT, typename FunctorType, typename T> 69 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) 70 { 71 const auto xptr = thrust::device_pointer_cast(xinout); 72 const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr; 73 74 PetscFunctionBegin; 75 PetscValidDevicePointer(xinout, 4); 76 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor))); 77 PetscFunctionReturn(PETSC_SUCCESS); 78 } 79 80 // actual implementation that calls thrust, 3 argument version 81 template <DeviceType DT, typename FunctorType, typename T> 82 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) 83 { 84 const auto xptr = thrust::device_pointer_cast(xin); 85 86 PetscFunctionBegin; 87 PetscValidDevicePointer(xin, 4); 88 PetscValidDevicePointer(yin, 5); 89 PetscValidDevicePointer(zin, 6); 90 PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor))); 91 PetscFunctionReturn(PETSC_SUCCESS); 92 } 93 94 // one last intermediate function to check n, and log flops for everything 95 template <DeviceType DT, typename F, typename... T> 96 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) 97 { 98 PetscFunctionBegin; 99 PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n); 100 if (PetscLikely(n)) { 101 PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...)); 102 PetscCall(PetscLogGpuFlops(n)); 103 } 104 PetscFunctionReturn(PETSC_SUCCESS); 105 } 106 107 // serves as setup to the real implementation above 108 template <DeviceType T, typename F, typename... Args> 109 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) 110 { 111 typename Interface<T>::cupmStream_t stream; 112 113 PetscFunctionBegin; 114 static_assert(sizeof...(Args) <= 3, ""); 115 PetscValidDeviceContext(dctx, 1); 116 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 117 PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...)); 118 PetscFunctionReturn(PETSC_SUCCESS); 119 } 120 121 #define PetscCallCUPM_(...) \ 122 do { \ 123 using interface = Interface<DT>; \ 124 using cupmError_t = typename interface::cupmError_t; \ 125 const auto cupmName = []() { return interface::cupmName(); }; \ 126 const auto cupmGetErrorName = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \ 127 const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \ 128 const auto cupmSuccess = interface::cupmSuccess; \ 129 PetscCallCUPM(__VA_ARGS__); \ 130 } while (0) 131 132 template <DeviceType DT, typename T> 133 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) 134 { 135 PetscFunctionBegin; 136 PetscValidPointer(val, 4); 137 if (n) { 138 const auto size = n * sizeof(T); 139 140 PetscValidDevicePointer(ptr, 3); 141 if (*val == T{0}) { 142 PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream)); 143 } else { 144 const auto xptr = thrust::device_pointer_cast(ptr); 145 146 PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val)); 147 if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) { 148 PetscCall(PetscLogCpuToGpuScalar(size)); 149 } else { 150 PetscCall(PetscLogCpuToGpu(size)); 151 } 152 } 153 } 154 PetscFunctionReturn(PETSC_SUCCESS); 155 } 156 157 #undef PetscCallCUPM_ 158 #undef PetscValidDevicePointer 159 160 template <DeviceType DT, typename T> 161 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) 162 { 163 typename Interface<DT>::cupmStream_t stream; 164 165 PetscFunctionBegin; 166 PetscValidDeviceContext(dctx, 1); 167 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream)); 168 PetscCall(ThrustSet(stream, n, ptr, val)); 169 PetscFunctionReturn(PETSC_SUCCESS); 170 } 171 172 } // namespace impl 173 174 } // namespace cupm 175 176 } // namespace device 177 178 } // namespace Petsc 179 180 #endif // __cplusplus 181 182 #endif // PETSC_CUPM_THRUST_UTILITY_HPP 183