xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision 862e4a309d45a165aaa4da0d704ba733429d833a)
1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2 #define PETSC_CUPM_THRUST_UTILITY_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupminterface.hpp>
6 
7 #if defined(__cplusplus)
  #include <thrust/device_ptr.h>
  #include <thrust/fill.h>
  #include <thrust/transform.h>

  #include <cstddef>
  #include <exception>
  #include <new>
  #include <type_traits>
  #include <utility>
10 
11 namespace Petsc
12 {
13 
14 namespace device
15 {
16 
17 namespace cupm
18 {
19 
20 namespace impl
21 {
22 
  // thrust_call_par_on(func, s, ...) calls the Thrust algorithm `func` with the remaining arguments,
  // prefixed by an execution policy bound to stream `s` when compiling with a CUDA or HIP compiler.
  #if PetscDefined(USING_NVCC)
    #if !defined(THRUST_VERSION)
      #error "THRUST_VERSION not defined!"
    #endif
    // thrust::cuda::par_nosync (Thrust >= 1.16.0, i.e. THRUST_VERSION >= 101600) asks Thrust to skip
    // its end-of-algorithm stream synchronization where possible. It is used only in non-debug builds;
    // debug builds keep the plain `par` policy (presumably so errors surface at the offending call --
    // confirm intent).
    #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
      #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
    #else
      #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
    #endif
  #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
    #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
  #else
    // Neither NVCC nor HCC: the stream argument `s` is discarded and `func` receives only the
    // remaining arguments. NOTE(review): which Thrust backend executes in this case is not visible here.
    #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
  #endif
37 
38 namespace detail
39 {
40 
41 struct PetscLogGpuTimer {
42   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
43   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
44 };
45 
46 struct private_tag { };
47 
48 } // namespace detail
49 
50   #define THRUST_CALL(...) \
51     [&] { \
52       const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
53       return thrust_call_par_on(__VA_ARGS__); \
54     }()
55 
56   #define PetscCallThrust(...) \
57     do { \
58       try { \
59         __VA_ARGS__; \
60       } catch (const thrust::system_error &ex) { \
61         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
62       } \
63     } while (0)
64 
// Unary functor that combines its argument with the value pointed to by `s`: returns op(x, *s).
//
// `s` is dereferenced on every invocation (not at construction), so the pointee must be readable
// wherever the functor executes -- presumably device memory when used with ThrustApplyPointwise();
// confirm at the call site.
template <typename T, typename BinaryOperator>
struct shift_operator {
  const T *const       s;  // pointer to the second operand of `op`
  const BinaryOperator op; // binary operation, applied as op(x, *s)

  PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
};
72 
73 template <typename T, typename BinaryOperator>
74 static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});
75 
76   #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno);
77 
78 // actual implementation that calls thrust, 2 argument version
79 template <DeviceType DT, typename FunctorType, typename T>
80 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
81 {
82   const auto xptr   = thrust::device_pointer_cast(xinout);
83   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
84 
85   PetscFunctionBegin;
86   PetscValidDevicePointer(xinout, 4);
87   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
88   PetscFunctionReturn(0);
89 }
90 
91 // actual implementation that calls thrust, 3 argument version
92 template <DeviceType DT, typename FunctorType, typename T>
93 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
94 {
95   const auto xptr = thrust::device_pointer_cast(xin);
96 
97   PetscFunctionBegin;
98   PetscValidDevicePointer(xin, 4);
99   PetscValidDevicePointer(yin, 5);
100   PetscValidDevicePointer(zin, 6);
101   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
102   PetscFunctionReturn(0);
103 }
104 
105 // one last intermediate function to check n, and log flops for everything
106 template <DeviceType DT, typename F, typename... T>
107 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
108 {
109   PetscFunctionBegin;
110   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
111   if (PetscLikely(n)) {
112     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
113     PetscCall(PetscLogGpuFlops(n));
114   }
115   PetscFunctionReturn(0);
116 }
117 
118 // serves as setup to the real implementation above
119 template <DeviceType T, typename F, typename... Args>
120 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
121 {
122   typename Interface<T>::cupmStream_t stream;
123 
124   PetscFunctionBegin;
125   static_assert(sizeof...(Args) <= 3, "");
126   PetscValidDeviceContext(dctx, 1);
127   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
128   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
129   PetscFunctionReturn(0);
130 }
131 
  // PetscCallCUPM_(expr): variant of PetscCallCUPM() for templates parameterized on `DeviceType DT`.
  // Inside a do { } while (0) scope it declares local aliases (cupmError_t, cupmName, cupmGetErrorName,
  // cupmGetErrorString, cupmSuccess) that forward to Interface<DT>, then expands PetscCallCUPM(expr).
  // Presumably these are the names PetscCallCUPM() looks up at its expansion site -- confirm against
  // its definition. Requires a `DT` template parameter to be in scope; it is #undef'd after ThrustSet().
  #define PetscCallCUPM_(...) \
    do { \
      using interface               = Interface<DT>; \
      using cupmError_t             = typename interface::cupmError_t; \
      const auto cupmName           = []() { return interface::cupmName(); }; \
      const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
      const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
      const auto cupmSuccess        = interface::cupmSuccess; \
      PetscCallCUPM(__VA_ARGS__); \
    } while (0)
142 
143 template <DeviceType DT, typename T>
144 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
145 {
146   PetscFunctionBegin;
147   PetscValidPointer(val, 4);
148   if (n) {
149     const auto size = n * sizeof(T);
150 
151     PetscValidDevicePointer(ptr, 3);
152     if (*val == T{0}) {
153       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
154     } else {
155       const auto xptr = thrust::device_pointer_cast(ptr);
156 
157       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
158       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
159         PetscCall(PetscLogCpuToGpuScalar(size));
160       } else {
161         PetscCall(PetscLogCpuToGpu(size));
162       }
163     }
164   }
165   PetscFunctionReturn(0);
166 }
167 
168   #undef PetscCallCUPM_
169   #undef PetscValidDevicePointer
170 
171 template <DeviceType DT, typename T>
172 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
173 {
174   typename Interface<DT>::cupmStream_t stream;
175 
176   PetscFunctionBegin;
177   PetscValidDeviceContext(dctx, 1);
178   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
179   PetscCall(ThrustSet(stream, n, ptr, val));
180   PetscFunctionReturn(0);
181 }
182 
183 } // namespace impl
184 
185 } // namespace cupm
186 
187 } // namespace device
188 
189 } // namespace Petsc
190 
191 #endif // __cplusplus
192 
193 #endif // PETSC_CUPM_THRUST_UTILITY_HPP
194