xref: /petsc/src/sys/objects/device/impls/cupm/kernels.hpp (revision 1b37a2a7cc4a4fb30c3e967db1c694c0a1013f51)
1 #pragma once
2 
3 #include <petscdevice_cupm.h>
4 
5 namespace Petsc
6 {
7 
8 namespace device
9 {
10 
11 namespace cupm
12 {
13 
14 namespace kernels
15 {
16 
17 namespace util
18 {
19 
20 template <typename SizeType, typename T>
21 PETSC_DEVICE_INLINE_DECL static void grid_stride_1D(const SizeType size, T &&func) noexcept
22 {
23   for (SizeType i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) func(i);
24   return;
25 }
26 
27 } // namespace util
28 
29 } // namespace kernels
30 
31 namespace functors
32 {
33 
34 template <typename T>
35 class plus_equals {
36 public:
37   using value_type = T;
38 
39   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
40 
41   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val + v_; }
42 
43 private:
44   value_type v_;
45 };
46 
47 template <typename T>
48 class times_equals {
49 public:
50   using value_type = T;
51 
52   PETSC_HOSTDEVICE_DECL constexpr explicit times_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
53 
54   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val * v_; }
55 
56 private:
57   value_type v_;
58 };
59 
60 template <typename T>
61 class axpy {
62 public:
63   using value_type = T;
64 
65   PETSC_HOSTDEVICE_DECL constexpr explicit axpy(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
66 
67   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &x, const value_type &y) const noexcept { return v_ * x + y; }
68 
69 private:
70   value_type v_;
71 };
72 
73 namespace
74 {
75 
76 template <typename T>
77 PETSC_HOSTDEVICE_INLINE_DECL constexpr plus_equals<T> make_plus_equals(const T &v) noexcept
78 {
79   return plus_equals<T>{v};
80 }
81 
82 template <typename T>
83 PETSC_HOSTDEVICE_INLINE_DECL constexpr times_equals<T> make_times_equals(const T &v) noexcept
84 {
85   return times_equals<T>{v};
86 }
87 
88 template <typename T>
89 PETSC_HOSTDEVICE_INLINE_DECL constexpr axpy<T> make_axpy(const T &v) noexcept
90 {
91   return axpy<T>{v};
92 }
93 
94 } // anonymous namespace
95 
96 } // namespace functors
97 
98 } // namespace cupm
99 
100 } // namespace device
101 
102 } // namespace Petsc
103