1 #ifndef PETSC_DEVICE_CUPM_KERNELS_HPP 2 #define PETSC_DEVICE_CUPM_KERNELS_HPP 3 4 #include <petscdevice_cupm.h> 5 6 namespace Petsc 7 { 8 9 namespace device 10 { 11 12 namespace cupm 13 { 14 15 namespace kernels 16 { 17 18 namespace util 19 { 20 21 template <typename SizeType, typename T> 22 PETSC_DEVICE_INLINE_DECL static void grid_stride_1D(const SizeType size, T &&func) noexcept 23 { 24 for (SizeType i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) func(i); 25 return; 26 } 27 28 } // namespace util 29 30 } // namespace kernels 31 32 namespace functors 33 { 34 35 template <typename T> 36 class plus_equals { 37 public: 38 using value_type = T; 39 40 PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { } 41 42 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val + v_; } 43 44 private: 45 value_type v_; 46 }; 47 48 template <typename T> 49 class times_equals { 50 public: 51 using value_type = T; 52 53 PETSC_HOSTDEVICE_DECL constexpr explicit times_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { } 54 55 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val * v_; } 56 57 private: 58 value_type v_; 59 }; 60 61 template <typename T> 62 class axpy { 63 public: 64 using value_type = T; 65 66 PETSC_HOSTDEVICE_DECL constexpr explicit axpy(value_type v = value_type{}) noexcept : v_{std::move(v)} { } 67 68 PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &x, const value_type &y) const noexcept { return v_ * x + y; } 69 70 private: 71 value_type v_; 72 }; 73 74 namespace 75 { 76 77 template <typename T> 78 PETSC_HOSTDEVICE_INLINE_DECL constexpr plus_equals<T> make_plus_equals(const T &v) noexcept 79 { 80 return plus_equals<T>{v}; 81 } 82 83 template <typename T> 84 PETSC_HOSTDEVICE_INLINE_DECL constexpr times_equals<T> make_times_equals(const T &v) noexcept 85 { 86 return times_equals<T>{v}; 87 } 88 89 template <typename T> 90 PETSC_HOSTDEVICE_INLINE_DECL constexpr axpy<T> make_axpy(const T &v) noexcept 91 { 92 return axpy<T>{v}; 93 } 94 95 } // anonymous namespace 96 97 } // namespace functors 98 99 } // namespace cupm 100 101 } // namespace device 102 103 } // namespace Petsc 104 105 #endif // PETSC_DEVICE_CUPM_KERNELS_HPP 106