xref: /petsc/src/sys/objects/device/impls/cupm/cupmthrustutility.hpp (revision d71ae5a4db6382e7f06317b8d368875286fe9008)
10e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
20e6b6b59SJacob Faibussowitsch #define PETSC_CUPM_THRUST_UTILITY_HPP
30e6b6b59SJacob Faibussowitsch 
40e6b6b59SJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
50e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
60e6b6b59SJacob Faibussowitsch 
70e6b6b59SJacob Faibussowitsch #if defined(__cplusplus)
80e6b6b59SJacob Faibussowitsch   #include <thrust/device_ptr.h>
90e6b6b59SJacob Faibussowitsch   #include <thrust/transform.h>
100e6b6b59SJacob Faibussowitsch 
11*d71ae5a4SJacob Faibussowitsch namespace Petsc
12*d71ae5a4SJacob Faibussowitsch {
130e6b6b59SJacob Faibussowitsch 
14*d71ae5a4SJacob Faibussowitsch namespace device
15*d71ae5a4SJacob Faibussowitsch {
160e6b6b59SJacob Faibussowitsch 
17*d71ae5a4SJacob Faibussowitsch namespace cupm
18*d71ae5a4SJacob Faibussowitsch {
190e6b6b59SJacob Faibussowitsch 
20*d71ae5a4SJacob Faibussowitsch namespace impl
21*d71ae5a4SJacob Faibussowitsch {
220e6b6b59SJacob Faibussowitsch 
230e6b6b59SJacob Faibussowitsch   #if PetscDefined(USING_NVCC)
240e6b6b59SJacob Faibussowitsch     #if !defined(THRUST_VERSION)
250e6b6b59SJacob Faibussowitsch       #error "THRUST_VERSION not defined!"
260e6b6b59SJacob Faibussowitsch     #endif
270e6b6b59SJacob Faibussowitsch     #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
280e6b6b59SJacob Faibussowitsch       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
290e6b6b59SJacob Faibussowitsch     #else
300e6b6b59SJacob Faibussowitsch       #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
310e6b6b59SJacob Faibussowitsch     #endif
320e6b6b59SJacob Faibussowitsch   #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
330e6b6b59SJacob Faibussowitsch     #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
340e6b6b59SJacob Faibussowitsch   #else
350e6b6b59SJacob Faibussowitsch     #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
360e6b6b59SJacob Faibussowitsch   #endif
370e6b6b59SJacob Faibussowitsch 
38*d71ae5a4SJacob Faibussowitsch namespace detail
39*d71ae5a4SJacob Faibussowitsch {
400e6b6b59SJacob Faibussowitsch 
410e6b6b59SJacob Faibussowitsch struct PetscLogGpuTimer {
420e6b6b59SJacob Faibussowitsch   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
430e6b6b59SJacob Faibussowitsch   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
440e6b6b59SJacob Faibussowitsch };
450e6b6b59SJacob Faibussowitsch 
460e6b6b59SJacob Faibussowitsch struct private_tag { };
470e6b6b59SJacob Faibussowitsch 
480e6b6b59SJacob Faibussowitsch } // namespace detail
490e6b6b59SJacob Faibussowitsch 
500e6b6b59SJacob Faibussowitsch   #define THRUST_CALL(...) \
510e6b6b59SJacob Faibussowitsch     [&] { \
520e6b6b59SJacob Faibussowitsch       const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
530e6b6b59SJacob Faibussowitsch       return thrust_call_par_on(__VA_ARGS__); \
540e6b6b59SJacob Faibussowitsch     }()
550e6b6b59SJacob Faibussowitsch 
560e6b6b59SJacob Faibussowitsch   #define PetscCallThrust(...) \
570e6b6b59SJacob Faibussowitsch     do { \
580e6b6b59SJacob Faibussowitsch       try { \
590e6b6b59SJacob Faibussowitsch         __VA_ARGS__; \
60*d71ae5a4SJacob Faibussowitsch       } catch (const thrust::system_error &ex) { \
61*d71ae5a4SJacob Faibussowitsch         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
62*d71ae5a4SJacob Faibussowitsch       } \
630e6b6b59SJacob Faibussowitsch     } while (0)
640e6b6b59SJacob Faibussowitsch 
650e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator>
660e6b6b59SJacob Faibussowitsch struct shift_operator {
670e6b6b59SJacob Faibussowitsch   const T *const       s;
680e6b6b59SJacob Faibussowitsch   const BinaryOperator op;
690e6b6b59SJacob Faibussowitsch 
700e6b6b59SJacob Faibussowitsch   PETSC_HOSTDEVICE_DECL PETSC_FORCEINLINE auto operator()(T x) const PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(op(std::move(x), *s))
710e6b6b59SJacob Faibussowitsch };
720e6b6b59SJacob Faibussowitsch 
730e6b6b59SJacob Faibussowitsch template <typename T, typename BinaryOperator>
740e6b6b59SJacob Faibussowitsch static inline auto make_shift_operator(T *s, BinaryOperator &&op) PETSC_DECLTYPE_NOEXCEPT_AUTO_RETURNS(shift_operator<T, BinaryOperator>{s, std::forward<BinaryOperator>(op)});
750e6b6b59SJacob Faibussowitsch 
760e6b6b59SJacob Faibussowitsch   #define PetscValidDevicePointer(ptr, argno) PetscAssert(ptr, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Null device pointer for " PetscStringize(ptr) " Argument #%d", argno);
770e6b6b59SJacob Faibussowitsch 
780e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 2 argument version
790e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T>
80*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
81*d71ae5a4SJacob Faibussowitsch {
820e6b6b59SJacob Faibussowitsch   const auto xptr   = thrust::device_pointer_cast(xinout);
830e6b6b59SJacob Faibussowitsch   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
840e6b6b59SJacob Faibussowitsch 
850e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
860e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(xinout, 4);
870e6b6b59SJacob Faibussowitsch   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
880e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
890e6b6b59SJacob Faibussowitsch }
900e6b6b59SJacob Faibussowitsch 
910e6b6b59SJacob Faibussowitsch // actual implementation that calls thrust, 3 argument version
920e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename FunctorType, typename T>
93*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
94*d71ae5a4SJacob Faibussowitsch {
950e6b6b59SJacob Faibussowitsch   const auto xptr = thrust::device_pointer_cast(xin);
960e6b6b59SJacob Faibussowitsch 
970e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
980e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(xin, 4);
990e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(yin, 5);
1000e6b6b59SJacob Faibussowitsch   PetscValidDevicePointer(zin, 6);
1010e6b6b59SJacob Faibussowitsch   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
1020e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
1030e6b6b59SJacob Faibussowitsch }
1040e6b6b59SJacob Faibussowitsch 
1050e6b6b59SJacob Faibussowitsch // one last intermediate function to check n, and log flops for everything
1060e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename F, typename... T>
107*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
108*d71ae5a4SJacob Faibussowitsch {
1090e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1100e6b6b59SJacob Faibussowitsch   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
1110e6b6b59SJacob Faibussowitsch   if (PetscLikely(n)) {
1120e6b6b59SJacob Faibussowitsch     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
1130e6b6b59SJacob Faibussowitsch     PetscCall(PetscLogGpuFlops(n));
1140e6b6b59SJacob Faibussowitsch   }
1150e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
1160e6b6b59SJacob Faibussowitsch }
1170e6b6b59SJacob Faibussowitsch 
1180e6b6b59SJacob Faibussowitsch // serves as setup to the real implementation above
1190e6b6b59SJacob Faibussowitsch template <DeviceType T, typename F, typename... Args>
120*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
121*d71ae5a4SJacob Faibussowitsch {
1220e6b6b59SJacob Faibussowitsch   typename Interface<T>::cupmStream_t stream;
1230e6b6b59SJacob Faibussowitsch 
1240e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1250e6b6b59SJacob Faibussowitsch   static_assert(sizeof...(Args) <= 3, "");
1260e6b6b59SJacob Faibussowitsch   PetscValidDeviceContext(dctx, 1);
1270e6b6b59SJacob Faibussowitsch   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
1280e6b6b59SJacob Faibussowitsch   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
1290e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
1300e6b6b59SJacob Faibussowitsch }
1310e6b6b59SJacob Faibussowitsch 
1320e6b6b59SJacob Faibussowitsch   #define PetscCallCUPM_(...) \
1330e6b6b59SJacob Faibussowitsch     do { \
1340e6b6b59SJacob Faibussowitsch       using interface               = Interface<DT>; \
1350e6b6b59SJacob Faibussowitsch       using cupmError_t             = typename interface::cupmError_t; \
1360e6b6b59SJacob Faibussowitsch       const auto cupmName           = []() { return interface::cupmName(); }; \
1370e6b6b59SJacob Faibussowitsch       const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
1380e6b6b59SJacob Faibussowitsch       const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
1390e6b6b59SJacob Faibussowitsch       const auto cupmSuccess        = interface::cupmSuccess; \
1400e6b6b59SJacob Faibussowitsch       PetscCallCUPM(__VA_ARGS__); \
1410e6b6b59SJacob Faibussowitsch     } while (0)
1420e6b6b59SJacob Faibussowitsch 
1430e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T>
144*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
145*d71ae5a4SJacob Faibussowitsch {
1460e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1470e6b6b59SJacob Faibussowitsch   PetscValidPointer(val, 4);
1480e6b6b59SJacob Faibussowitsch   if (n) {
1490e6b6b59SJacob Faibussowitsch     const auto size = n * sizeof(T);
1500e6b6b59SJacob Faibussowitsch 
1510e6b6b59SJacob Faibussowitsch     PetscValidDevicePointer(ptr, 3);
1520e6b6b59SJacob Faibussowitsch     if (*val == T{0}) {
1530e6b6b59SJacob Faibussowitsch       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
1540e6b6b59SJacob Faibussowitsch     } else {
1550e6b6b59SJacob Faibussowitsch       const auto xptr = thrust::device_pointer_cast(ptr);
1560e6b6b59SJacob Faibussowitsch 
1570e6b6b59SJacob Faibussowitsch       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
1580e6b6b59SJacob Faibussowitsch       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
1590e6b6b59SJacob Faibussowitsch         PetscCall(PetscLogCpuToGpuScalar(size));
1600e6b6b59SJacob Faibussowitsch       } else {
1610e6b6b59SJacob Faibussowitsch         PetscCall(PetscLogCpuToGpu(size));
1620e6b6b59SJacob Faibussowitsch       }
1630e6b6b59SJacob Faibussowitsch     }
1640e6b6b59SJacob Faibussowitsch   }
1650e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
1660e6b6b59SJacob Faibussowitsch }
1670e6b6b59SJacob Faibussowitsch 
1680e6b6b59SJacob Faibussowitsch   #undef PetscCallCUPM_
1690e6b6b59SJacob Faibussowitsch   #undef PetscValidDevicePointer
1700e6b6b59SJacob Faibussowitsch 
1710e6b6b59SJacob Faibussowitsch template <DeviceType DT, typename T>
172*d71ae5a4SJacob Faibussowitsch PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
173*d71ae5a4SJacob Faibussowitsch {
1740e6b6b59SJacob Faibussowitsch   typename Interface<DT>::cupmStream_t stream;
1750e6b6b59SJacob Faibussowitsch 
1760e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1770e6b6b59SJacob Faibussowitsch   PetscValidDeviceContext(dctx, 1);
1780e6b6b59SJacob Faibussowitsch   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
1790e6b6b59SJacob Faibussowitsch   PetscCall(ThrustSet(stream, n, ptr, val));
1800e6b6b59SJacob Faibussowitsch   PetscFunctionReturn(0);
1810e6b6b59SJacob Faibussowitsch }
1820e6b6b59SJacob Faibussowitsch 
1830e6b6b59SJacob Faibussowitsch } // namespace impl
1840e6b6b59SJacob Faibussowitsch 
1850e6b6b59SJacob Faibussowitsch } // namespace cupm
1860e6b6b59SJacob Faibussowitsch 
1870e6b6b59SJacob Faibussowitsch } // namespace device
1880e6b6b59SJacob Faibussowitsch 
1890e6b6b59SJacob Faibussowitsch } // namespace Petsc
1900e6b6b59SJacob Faibussowitsch 
1910e6b6b59SJacob Faibussowitsch #endif // __cplusplus
1920e6b6b59SJacob Faibussowitsch 
1930e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPM_THRUST_UTILITY_HPP
194