xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2 #define PETSC_CUPM_THRUST_UTILITY_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupminterface.hpp>
6 
7 #if defined(__cplusplus)
8   #include <thrust/device_ptr.h>
9   #include <thrust/transform.h>
10 
11 namespace Petsc
12 {
13 
14 namespace device
15 {
16 
17 namespace cupm
18 {
19 
20 namespace impl
21 {
22 
23   #if PetscDefined(USING_NVCC)
24     #if !defined(THRUST_VERSION)
25       #error "THRUST_VERSION not defined!"
26     #endif
27     #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
28       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
29     #else
30       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
31     #endif
32   #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
33     #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
34   #else
35     #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
36   #endif
37 
38 namespace detail
39 {
40 
41 struct PetscLogGpuTimer {
42   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
43   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
44 };
45 
46 struct private_tag { };
47 
48 } // namespace detail
49 
50   #define THRUST_CALL(...) \
51     [&] { \
52       const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
53       return thrust_call_par_on(__VA_ARGS__); \
54     }()
55 
56   #define PetscCallThrust(...) \
57     do { \
58       try { \
59         __VA_ARGS__; \
60       } catch (const thrust::system_error &ex) { \
61         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
62       } \
63     } while (0)
64 
65   #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno);
66 
67 // actual implementation that calls thrust, 2 argument version
68 template <DeviceType DT, typename FunctorType, typename T>
69 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
70 {
71   const auto xptr   = thrust::device_pointer_cast(xinout);
72   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
73 
74   PetscFunctionBegin;
75   PetscValidDevicePointer(xinout, 4);
76   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
77   PetscFunctionReturn(PETSC_SUCCESS);
78 }
79 
80 // actual implementation that calls thrust, 3 argument version
81 template <DeviceType DT, typename FunctorType, typename T>
82 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
83 {
84   const auto xptr = thrust::device_pointer_cast(xin);
85 
86   PetscFunctionBegin;
87   PetscValidDevicePointer(xin, 4);
88   PetscValidDevicePointer(yin, 5);
89   PetscValidDevicePointer(zin, 6);
90   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
91   PetscFunctionReturn(PETSC_SUCCESS);
92 }
93 
94 // one last intermediate function to check n, and log flops for everything
95 template <DeviceType DT, typename F, typename... T>
96 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
97 {
98   PetscFunctionBegin;
99   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
100   if (PetscLikely(n)) {
101     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
102     PetscCall(PetscLogGpuFlops(n));
103   }
104   PetscFunctionReturn(PETSC_SUCCESS);
105 }
106 
107 // serves as setup to the real implementation above
108 template <DeviceType T, typename F, typename... Args>
109 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
110 {
111   typename Interface<T>::cupmStream_t stream;
112 
113   PetscFunctionBegin;
114   static_assert(sizeof...(Args) <= 3, "");
115   PetscValidDeviceContext(dctx, 1);
116   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
117   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
118   PetscFunctionReturn(PETSC_SUCCESS);
119 }
120 
121   #define PetscCallCUPM_(...) \
122     do { \
123       using interface               = Interface<DT>; \
124       using cupmError_t             = typename interface::cupmError_t; \
125       const auto cupmName           = []() { return interface::cupmName(); }; \
126       const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
127       const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
128       const auto cupmSuccess        = interface::cupmSuccess; \
129       PetscCallCUPM(__VA_ARGS__); \
130     } while (0)
131 
132 template <DeviceType DT, typename T>
133 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
134 {
135   PetscFunctionBegin;
136   PetscValidPointer(val, 4);
137   if (n) {
138     const auto size = n * sizeof(T);
139 
140     PetscValidDevicePointer(ptr, 3);
141     if (*val == T{0}) {
142       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
143     } else {
144       const auto xptr = thrust::device_pointer_cast(ptr);
145 
146       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
147       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
148         PetscCall(PetscLogCpuToGpuScalar(size));
149       } else {
150         PetscCall(PetscLogCpuToGpu(size));
151       }
152     }
153   }
154   PetscFunctionReturn(PETSC_SUCCESS);
155 }
156 
157   #undef PetscCallCUPM_
158   #undef PetscValidDevicePointer
159 
160 template <DeviceType DT, typename T>
161 PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
162 {
163   typename Interface<DT>::cupmStream_t stream;
164 
165   PetscFunctionBegin;
166   PetscValidDeviceContext(dctx, 1);
167   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
168   PetscCall(ThrustSet(stream, n, ptr, val));
169   PetscFunctionReturn(PETSC_SUCCESS);
170 }
171 
172 } // namespace impl
173 
174 } // namespace cupm
175 
176 } // namespace device
177 
178 } // namespace Petsc
179 
180 #endif // __cplusplus
181 
182 #endif // PETSC_CUPM_THRUST_UTILITY_HPP
183