// xref: /petsc/src/sys/objects/device/impls/cupm/kernels.hpp (revision 0619917b5a674bb687c64e7daba2ab22be99af31)
1 #ifndef PETSC_DEVICE_CUPM_KERNELS_HPP
2 #define PETSC_DEVICE_CUPM_KERNELS_HPP
3 
4 #include <petscdevice_cupm.h>
5 
6 namespace Petsc
7 {
8 
9 namespace device
10 {
11 
12 namespace cupm
13 {
14 
15 namespace kernels
16 {
17 
18 namespace util
19 {
20 
21 template <typename SizeType, typename T>
22 PETSC_DEVICE_INLINE_DECL static void grid_stride_1D(const SizeType size, T &&func) noexcept
23 {
24   for (SizeType i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) func(i);
25   return;
26 }
27 
28 } // namespace util
29 
30 } // namespace kernels
31 
32 namespace functors
33 {
34 
35 template <typename T>
36 class plus_equals {
37 public:
38   using value_type = T;
39 
40   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
41 
42   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val + v_; }
43 
44 private:
45   value_type v_;
46 };
47 
48 template <typename T>
49 class times_equals {
50 public:
51   using value_type = T;
52 
53   PETSC_HOSTDEVICE_DECL constexpr explicit times_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
54 
55   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val * v_; }
56 
57 private:
58   value_type v_;
59 };
60 
61 template <typename T>
62 class axpy {
63 public:
64   using value_type = T;
65 
66   PETSC_HOSTDEVICE_DECL constexpr explicit axpy(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
67 
68   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &x, const value_type &y) const noexcept { return v_ * x + y; }
69 
70 private:
71   value_type v_;
72 };
73 
74 namespace
75 {
76 
77 template <typename T>
78 PETSC_HOSTDEVICE_INLINE_DECL constexpr plus_equals<T> make_plus_equals(const T &v) noexcept
79 {
80   return plus_equals<T>{v};
81 }
82 
83 template <typename T>
84 PETSC_HOSTDEVICE_INLINE_DECL constexpr times_equals<T> make_times_equals(const T &v) noexcept
85 {
86   return times_equals<T>{v};
87 }
88 
89 template <typename T>
90 PETSC_HOSTDEVICE_INLINE_DECL constexpr axpy<T> make_axpy(const T &v) noexcept
91 {
92   return axpy<T>{v};
93 }
94 
95 } // anonymous namespace
96 
97 } // namespace functors
98 
99 } // namespace cupm
100 
101 } // namespace device
102 
103 } // namespace Petsc
104 
105 #endif // PETSC_DEVICE_CUPM_KERNELS_HPP
106