xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision 0e6b6b5985dd9b1172860d21fb88bd3966bf7c54)
1*0e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2*0e6b6b59SJacob Faibussowitsch #define PETSC_CUPM_THRUST_UTILITY_HPP
3*0e6b6b59SJacob Faibussowitsch 
4*0e6b6b59SJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
5*0e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
6*0e6b6b59SJacob Faibussowitsch 
7*0e6b6b59SJacob Faibussowitsch #if defined(__cplusplus)
8*0e6b6b59SJacob Faibussowitsch #include <thrust/device_ptr.h>
9*0e6b6b59SJacob Faibussowitsch #include <thrust/transform.h>
10*0e6b6b59SJacob Faibussowitsch 
11*0e6b6b59SJacob Faibussowitsch namespace Petsc {
12*0e6b6b59SJacob Faibussowitsch 
13*0e6b6b59SJacob Faibussowitsch namespace device {
14*0e6b6b59SJacob Faibussowitsch 
15*0e6b6b59SJacob Faibussowitsch namespace cupm {
16*0e6b6b59SJacob Faibussowitsch 
17*0e6b6b59SJacob Faibussowitsch namespace impl {
18*0e6b6b59SJacob Faibussowitsch 
19*0e6b6b59SJacob Faibussowitsch #if PetscDefined(USING_NVCC)
20*0e6b6b59SJacob Faibussowitsch #if !defined(THRUST_VERSION)
21*0e6b6b59SJacob Faibussowitsch #error "THRUST_VERSION not defined!"
22*0e6b6b59SJacob Faibussowitsch #endif
23*0e6b6b59SJacob Faibussowitsch #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
24*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
25*0e6b6b59SJacob Faibussowitsch #else
26*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
27*0e6b6b59SJacob Faibussowitsch #endif
28*0e6b6b59SJacob Faibussowitsch #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
29*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
30*0e6b6b59SJacob Faibussowitsch #else
31*0e6b6b59SJacob Faibussowitsch #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
32*0e6b6b59SJacob Faibussowitsch #endif
33*0e6b6b59SJacob Faibussowitsch 
34*0e6b6b59SJacob Faibussowitsch namespace detail {
35*0e6b6b59SJacob Faibussowitsch 
36*0e6b6b59SJacob Faibussowitsch struct PetscLogGpuTimer {
37*0e6b6b59SJacob Faibussowitsch   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
38*0e6b6b59SJacob Faibussowitsch   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
39*0e6b6b59SJacob Faibussowitsch };
40*0e6b6b59SJacob Faibussowitsch 
41*0e6b6b59SJacob Faibussowitsch struct private_tag { };
42*0e6b6b59SJacob Faibussowitsch 
43*0e6b6b59SJacob Faibussowitsch } // namespace detail
44*0e6b6b59SJacob Faibussowitsch 
45*0e6b6b59SJacob Faibussowitsch #define THRUST_CALL(...) \
46*0e6b6b59SJacob Faibussowitsch   [&] { \
47*0e6b6b59SJacob Faibussowitsch     const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
48*0e6b6b59SJacob Faibussowitsch     return thrust_call_par_on(__VA_ARGS__); \
49*0e6b6b59SJacob Faibussowitsch   }()
50*0e6b6b59SJacob Faibussowitsch 
51*0e6b6b59SJacob Faibussowitsch #define PetscCallThrust(...) \
52*0e6b6b59SJacob Faibussowitsch   do { \
53*0e6b6b59SJacob Faibussowitsch     try { \
54*0e6b6b59SJacob Faibussowitsch       __VA_ARGS__; \
55*0e6b6b59SJacob Faibussowitsch     } catch (const thrust::system_error &ex) { SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); } \
56*0e6b6b59SJacob Faibussowitsch   } while (0)
57*0e6b6b59SJacob Faibussowitsch 
58*0e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator>
59*0e6b6b59SJacob Faibussowitsch struct shift_operator {
60*0e6b6b59SJacob Faibussowitsch   const T *const       s;
61*0e6b6b59SJacob Faibussowitsch   const BinaryOperator op;
62*0e6b6b59SJacob Faibussowitsch 
63*0e6b6b59SJacob Faibussowitsch   PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
64*0e6b6b59SJacob Faibussowitsch };
65*0e6b6b59SJacob Faibussowitsch 
66*0e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator>
67*0e6b6b59SJacob Faibussowitsch static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});
68*0e6b6b59SJacob Faibussowitsch 
69*0e6b6b59SJacob Faibussowitsch #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno);
70*0e6b6b59SJacob Faibussowitsch 
71*0e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 2 argument version
72*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T>
73*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr)) {
74*0e6b6b59SJacob Faibussowitsch   const auto xptr   = thrust::device_pointer_cast(xinout);
75*0e6b6b59SJacob Faibussowitsch   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
76*0e6b6b59SJacob Faibussowitsch 
77*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
78*0e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(xinout, 4);
79*0e6b6b59SJacob Faibussowitsch   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
80*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
81*0e6b6b59SJacob Faibussowitsch }
82*0e6b6b59SJacob Faibussowitsch 
83*0e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 3 argument version
84*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T>
85*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin)) {
86*0e6b6b59SJacob Faibussowitsch   const auto xptr = thrust::device_pointer_cast(xin);
87*0e6b6b59SJacob Faibussowitsch 
88*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
89*0e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(xin, 4);
90*0e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(yin, 5);
91*0e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(zin, 6);
92*0e6b6b59SJacob Faibussowitsch   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
93*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
94*0e6b6b59SJacob Faibussowitsch }
95*0e6b6b59SJacob Faibussowitsch 
96*0e6b6b59SJacob Faibussowitsch // one last intermediate function to check n, and log flops for everything
97*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename F, typename... T>
98*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest)) {
99*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
100*0e6b6b59SJacob Faibussowitsch   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
101*0e6b6b59SJacob Faibussowitsch   if (PetscLikely(n)) {
102*0e6b6b59SJacob Faibussowitsch     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
103*0e6b6b59SJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(n));
104*0e6b6b59SJacob Faibussowitsch   }
105*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
106*0e6b6b59SJacob Faibussowitsch }
107*0e6b6b59SJacob Faibussowitsch 
108*0e6b6b59SJacob Faibussowitsch // serves as setup to the real implementation above
109*0e6b6b59SJacob Faibussowitsch template <DeviceType T, typename F, typename... Args>
110*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest)) {
111*0e6b6b59SJacob Faibussowitsch   typename Interface<T>::cupmStream_t stream;
112*0e6b6b59SJacob Faibussowitsch 
113*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
114*0e6b6b59SJacob Faibussowitsch   static_assert(sizeof...(Args) <= 3, "");
115*0e6b6b59SJacob Faibussowitsch   PetscValidDeviceContext(dctx, 1);
116*0e6b6b59SJacob Faibussowitsch   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
117*0e6b6b59SJacob Faibussowitsch   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
118*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
119*0e6b6b59SJacob Faibussowitsch }
120*0e6b6b59SJacob Faibussowitsch 
121*0e6b6b59SJacob Faibussowitsch #define PetscCallCUPM_(...) \
122*0e6b6b59SJacob Faibussowitsch   do { \
123*0e6b6b59SJacob Faibussowitsch     using interface               = Interface<DT>; \
124*0e6b6b59SJacob Faibussowitsch     using cupmError_t             = typename interface::cupmError_t; \
125*0e6b6b59SJacob Faibussowitsch     const auto cupmName           = []() { return interface::cupmName(); }; \
126*0e6b6b59SJacob Faibussowitsch     const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
127*0e6b6b59SJacob Faibussowitsch     const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
128*0e6b6b59SJacob Faibussowitsch     const auto cupmSuccess        = interface::cupmSuccess; \
129*0e6b6b59SJacob Faibussowitsch     PetscCallCUPM(__VA_ARGS__); \
130*0e6b6b59SJacob Faibussowitsch   } while (0)
131*0e6b6b59SJacob Faibussowitsch 
132*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T>
133*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val)) {
134*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
135*0e6b6b59SJacob Faibussowitsch   PetscValidPointer(val, 4);
136*0e6b6b59SJacob Faibussowitsch   if (n) {
137*0e6b6b59SJacob Faibussowitsch     const auto size = n * sizeof(T);
138*0e6b6b59SJacob Faibussowitsch 
139*0e6b6b59SJacob Faibussowitsch     PetscValidDevicePointer(ptr, 3);
140*0e6b6b59SJacob Faibussowitsch     if (*val == T{0}) {
141*0e6b6b59SJacob Faibussowitsch       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
142*0e6b6b59SJacob Faibussowitsch     } else {
143*0e6b6b59SJacob Faibussowitsch       const auto xptr = thrust::device_pointer_cast(ptr);
144*0e6b6b59SJacob Faibussowitsch 
145*0e6b6b59SJacob Faibussowitsch       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
146*0e6b6b59SJacob Faibussowitsch       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
147*0e6b6b59SJacob Faibussowitsch         PetscCall(PetscLogCpuToGpuScalar(size));
148*0e6b6b59SJacob Faibussowitsch       } else {
149*0e6b6b59SJacob Faibussowitsch         PetscCall(PetscLogCpuToGpu(size));
150*0e6b6b59SJacob Faibussowitsch       }
151*0e6b6b59SJacob Faibussowitsch     }
152*0e6b6b59SJacob Faibussowitsch   }
153*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
154*0e6b6b59SJacob Faibussowitsch }
155*0e6b6b59SJacob Faibussowitsch 
156*0e6b6b59SJacob Faibussowitsch #undef PetscCallCUPM_
157*0e6b6b59SJacob Faibussowitsch #undef PetscValidDevicePointer
158*0e6b6b59SJacob Faibussowitsch 
159*0e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T>
160*0e6b6b59SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val)) {
161*0e6b6b59SJacob Faibussowitsch   typename Interface<DT>::cupmStream_t stream;
162*0e6b6b59SJacob Faibussowitsch 
163*0e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
164*0e6b6b59SJacob Faibussowitsch   PetscValidDeviceContext(dctx, 1);
165*0e6b6b59SJacob Faibussowitsch   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
166*0e6b6b59SJacob Faibussowitsch   PetscCall(ThrustSet(stream, n, ptr, val));
167*0e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
168*0e6b6b59SJacob Faibussowitsch }
169*0e6b6b59SJacob Faibussowitsch 
170*0e6b6b59SJacob Faibussowitsch } // namespace impl
171*0e6b6b59SJacob Faibussowitsch 
172*0e6b6b59SJacob Faibussowitsch } // namespace cupm
173*0e6b6b59SJacob Faibussowitsch 
174*0e6b6b59SJacob Faibussowitsch } // namespace device
175*0e6b6b59SJacob Faibussowitsch 
176*0e6b6b59SJacob Faibussowitsch } // namespace Petsc
177*0e6b6b59SJacob Faibussowitsch 
178*0e6b6b59SJacob Faibussowitsch #endif // __cplusplus
179*0e6b6b59SJacob Faibussowitsch 
180*0e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPM_THRUST_UTILITY_HPP
181