xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision 8bf4d53a7d073e2c076708235a5d9f72c132c691)
1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2 #define PETSC_CUPM_THRUST_UTILITY_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupminterface.hpp>
6 
7 #if defined(__cplusplus)
#include <type_traits>
#include <utility>

#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/transform.h>
10 
11 namespace Petsc {
12 
13 namespace device {
14 
15 namespace cupm {
16 
17 namespace impl {
18 
19 #if PetscDefined(USING_NVCC)
20 #if !defined(THRUST_VERSION)
21 #error "THRUST_VERSION not defined!"
22 #endif
23 #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
24 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
25 #else
26 #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
27 #endif
28 #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
29 #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
30 #else
31 #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
32 #endif
33 
34 namespace detail {
35 
36 struct PetscLogGpuTimer {
37   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
38   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
39 };
40 
41 struct private_tag { };
42 
43 } // namespace detail
44 
// THRUST_CALL(func, stream, args...) evaluates thrust_call_par_on(func, stream, args...) inside an
// immediately-invoked lambda. A detail::PetscLogGpuTimer lives for exactly the duration of the Thrust
// call. The macro yields whatever the algorithm returns.
#define THRUST_CALL(...) \
  [&]() { \
    const ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer petsc_thrust_timer_{}; \
    return thrust_call_par_on(__VA_ARGS__); \
  }()
50 
// PetscCallThrust(expr) evaluates `expr`. A thrust::system_error thrown by it is converted into a
// PETSc error of class PETSC_ERR_LIB whose message includes the exception's what() text. Other
// exception types are not caught. It expands to SETERRQ, so it is meant for functions returning
// PetscErrorCode (all uses in this file are).
#define PetscCallThrust(...) \
  do { \
    try { \
      { \
        __VA_ARGS__; \
      } \
    } catch (const thrust::system_error &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
    } \
  } while (0)
57 
58 template <typename T, typename BinaryOperator>
59 struct shift_operator {
60   const T *const       s;
61   const BinaryOperator op;
62 
63   PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
64 };
65 
66 template <typename T, typename BinaryOperator>
67 static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});
68 
// PetscValidDevicePointer(ptr, argno): raise PETSC_ERR_POINTER if `ptr` is null. The message names
// the pointer expression and its 1-based argument number. Only a null check is performed; the
// pointer is not verified to reference device memory. The expansion deliberately has no trailing
// semicolon: each use site writes its own, so `if (c) PetscValidDevicePointer(p, 3); else ...`
// stays well-formed (the old trailing ';' produced an extra empty statement).
#define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno)
70 
71 // actual implementation that calls thrust, 2 argument version
72 template <DeviceType DT, typename FunctorType, typename T>
73 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) {
74   const auto xptr   = thrust::device_pointer_cast(xinout);
75   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
76 
77   PetscFunctionBegin;
78   PetscValidDevicePointer(xinout, 4);
79   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
80   PetscFunctionReturn(0);
81 }
82 
83 // actual implementation that calls thrust, 3 argument version
84 template <DeviceType DT, typename FunctorType, typename T>
85 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) {
86   const auto xptr = thrust::device_pointer_cast(xin);
87 
88   PetscFunctionBegin;
89   PetscValidDevicePointer(xin, 4);
90   PetscValidDevicePointer(yin, 5);
91   PetscValidDevicePointer(zin, 6);
92   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
93   PetscFunctionReturn(0);
94 }
95 
96 // one last intermediate function to check n, and log flops for everything
97 template <DeviceType DT, typename F, typename... T>
98 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) {
99   PetscFunctionBegin;
100   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
101   if (PetscLikely(n)) {
102     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
103     PetscCall(PetscLogGpuFlops(n));
104   }
105   PetscFunctionReturn(0);
106 }
107 
108 // serves as setup to the real implementation above
109 template <DeviceType T, typename F, typename... Args>
110 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) {
111   typename Interface<T>::cupmStream_t stream;
112 
113   PetscFunctionBegin;
114   static_assert(sizeof...(Args) <= 3, "");
115   PetscValidDeviceContext(dctx, 1);
116   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
117   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
118   PetscFunctionReturn(0);
119 }
120 
// PetscCallCUPM_(expr): PetscCallCUPM() usable from free function templates that have a
// `DeviceType DT` in scope. It declares local stand-ins forwarding to Interface<DT>:
//   cupmError_t, cupmName(), cupmGetErrorName(), cupmGetErrorString(), cupmSuccess.
// These are presumably the names PetscCallCUPM() expects to find unqualified (it appears written
// for use inside Interface-derived classes). NOTE(review): confirm against cupminterface.hpp.
// The type alias is deliberately NOT named `interface`: Windows' COM headers
// (<objbase.h>/<combaseapi.h>) `#define interface`, which would break this expansion.
#define PetscCallCUPM_(...) \
  do { \
    using cupm_interface_t        = Interface<DT>; \
    using cupmError_t             = typename cupm_interface_t::cupmError_t; \
    const auto cupmName           = []() { return cupm_interface_t::cupmName(); }; \
    const auto cupmGetErrorName   = [](cupmError_t e) { return cupm_interface_t::cupmGetErrorName(e); }; \
    const auto cupmGetErrorString = [](cupmError_t e) { return cupm_interface_t::cupmGetErrorString(e); }; \
    const auto cupmSuccess        = cupm_interface_t::cupmSuccess; \
    PetscCallCUPM(__VA_ARGS__); \
  } while (0)
131 
132 template <DeviceType DT, typename T>
133 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) {
134   PetscFunctionBegin;
135   PetscValidPointer(val, 4);
136   if (n) {
137     const auto size = n * sizeof(T);
138 
139     PetscValidDevicePointer(ptr, 3);
140     if (*val == T{0}) {
141       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
142     } else {
143       const auto xptr = thrust::device_pointer_cast(ptr);
144 
145       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
146       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
147         PetscCall(PetscLogCpuToGpuScalar(size));
148       } else {
149         PetscCall(PetscLogCpuToGpu(size));
150       }
151     }
152   }
153   PetscFunctionReturn(0);
154 }
155 
156 #undef PetscCallCUPM_
157 #undef PetscValidDevicePointer
158 
159 template <DeviceType DT, typename T>
160 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) {
161   typename Interface<DT>::cupmStream_t stream;
162 
163   PetscFunctionBegin;
164   PetscValidDeviceContext(dctx, 1);
165   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
166   PetscCall(ThrustSet(stream, n, ptr, val));
167   PetscFunctionReturn(0);
168 }
169 
170 } // namespace impl
171 
172 } // namespace cupm
173 
174 } // namespace device
175 
176 } // namespace Petsc
177 
178 #endif // __cplusplus
179 
180 #endif // PETSC_CUPM_THRUST_UTILITY_HPP
181