xref: /petsc/src/sys/objects/device/impls/cupm/kernels.hpp (revision fff3f5fae336d1caaec33911beb15ba4efbb662b)
1 #ifndef PETSC_DEVICE_CUPM_KERNELS_HPP
2 #define PETSC_DEVICE_CUPM_KERNELS_HPP
3 
4 #include <petscdevice_cupm.h>
5 
6 #if defined(__cplusplus)
7 
8 namespace Petsc
9 {
10 
11 namespace device
12 {
13 
14 namespace cupm
15 {
16 
17 namespace kernels
18 {
19 
20 namespace util
21 {
22 
23 template <typename SizeType, typename T>
24 PETSC_DEVICE_INLINE_DECL static void grid_stride_1D(const SizeType size, T &&func) noexcept
25 {
26   for (SizeType i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) func(i);
27   return;
28 }
29 
30 } // namespace util
31 
32 } // namespace kernels
33 
34 namespace functors
35 {
36 
37 template <typename T>
38 class plus_equals {
39 public:
40   using value_type = T;
41 
42   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
43 
44   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val + v_; }
45 
46 private:
47   value_type v_;
48 };
49 
50 template <typename T>
51 class times_equals {
52 public:
53   using value_type = T;
54 
55   PETSC_HOSTDEVICE_DECL constexpr explicit times_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
56 
57   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val * v_; }
58 
59 private:
60   value_type v_;
61 };
62 
63 template <typename T>
64 class axpy {
65 public:
66   using value_type = T;
67 
68   PETSC_HOSTDEVICE_DECL constexpr explicit axpy(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
69 
70   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &x, const value_type &y) const noexcept { return v_ * x + y; }
71 
72 private:
73   value_type v_;
74 };
75 
76 namespace
77 {
78 
79 template <typename T>
80 PETSC_HOSTDEVICE_INLINE_DECL constexpr plus_equals<T> make_plus_equals(const T &v) noexcept
81 {
82   return plus_equals<T>{v};
83 }
84 
85 template <typename T>
86 PETSC_HOSTDEVICE_INLINE_DECL constexpr times_equals<T> make_times_equals(const T &v) noexcept
87 {
88   return times_equals<T>{v};
89 }
90 
91 template <typename T>
92 PETSC_HOSTDEVICE_INLINE_DECL constexpr axpy<T> make_axpy(const T &v) noexcept
93 {
94   return axpy<T>{v};
95 }
96 
97 } // anonymous namespace
98 
99 } // namespace functors
100 
101 } // namespace cupm
102 
103 } // namespace device
104 
105 } // namespace Petsc
106 
107 #endif // __cplusplus
108 
109 #endif // PETSC_DEVICE_CUPM_KERNELS_HPP
110