xref: /petsc/src/sys/objects/device/impls/cupm/kernels.hpp (revision 9dd11ecf0918283bb567d8b33a92f53ac4ea7840)
1*a4963045SJacob Faibussowitsch #pragma once
26d54fb17SJacob Faibussowitsch 
36d54fb17SJacob Faibussowitsch #include <petscdevice_cupm.h>
46d54fb17SJacob Faibussowitsch 
56d54fb17SJacob Faibussowitsch namespace Petsc
66d54fb17SJacob Faibussowitsch {
76d54fb17SJacob Faibussowitsch 
86d54fb17SJacob Faibussowitsch namespace device
96d54fb17SJacob Faibussowitsch {
106d54fb17SJacob Faibussowitsch 
116d54fb17SJacob Faibussowitsch namespace cupm
126d54fb17SJacob Faibussowitsch {
136d54fb17SJacob Faibussowitsch 
146d54fb17SJacob Faibussowitsch namespace kernels
156d54fb17SJacob Faibussowitsch {
166d54fb17SJacob Faibussowitsch 
176d54fb17SJacob Faibussowitsch namespace util
186d54fb17SJacob Faibussowitsch {
196d54fb17SJacob Faibussowitsch 
206d54fb17SJacob Faibussowitsch template <typename SizeType, typename T>
grid_stride_1D(const SizeType size,T && func)216d54fb17SJacob Faibussowitsch PETSC_DEVICE_INLINE_DECL static void grid_stride_1D(const SizeType size, T &&func) noexcept
226d54fb17SJacob Faibussowitsch {
236d54fb17SJacob Faibussowitsch   for (SizeType i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) func(i);
246d54fb17SJacob Faibussowitsch   return;
256d54fb17SJacob Faibussowitsch }
266d54fb17SJacob Faibussowitsch 
276d54fb17SJacob Faibussowitsch } // namespace util
286d54fb17SJacob Faibussowitsch 
296d54fb17SJacob Faibussowitsch } // namespace kernels
306d54fb17SJacob Faibussowitsch 
3190585354SJacob Faibussowitsch namespace functors
3290585354SJacob Faibussowitsch {
3390585354SJacob Faibussowitsch 
3490585354SJacob Faibussowitsch template <typename T>
3590585354SJacob Faibussowitsch class plus_equals {
3690585354SJacob Faibussowitsch public:
3790585354SJacob Faibussowitsch   using value_type = T;
3890585354SJacob Faibussowitsch 
plus_equals(value_type v=value_type{})3990585354SJacob Faibussowitsch   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
4090585354SJacob Faibussowitsch 
operator ()(const value_type & val) const4190585354SJacob Faibussowitsch   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val + v_; }
4290585354SJacob Faibussowitsch 
4390585354SJacob Faibussowitsch private:
4490585354SJacob Faibussowitsch   value_type v_;
4590585354SJacob Faibussowitsch };
4690585354SJacob Faibussowitsch 
472ea277ceSJacob Faibussowitsch template <typename T>
482ea277ceSJacob Faibussowitsch class times_equals {
492ea277ceSJacob Faibussowitsch public:
502ea277ceSJacob Faibussowitsch   using value_type = T;
512ea277ceSJacob Faibussowitsch 
times_equals(value_type v=value_type{})522ea277ceSJacob Faibussowitsch   PETSC_HOSTDEVICE_DECL constexpr explicit times_equals(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
532ea277ceSJacob Faibussowitsch 
operator ()(const value_type & val) const542ea277ceSJacob Faibussowitsch   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &val) const noexcept { return val * v_; }
552ea277ceSJacob Faibussowitsch 
562ea277ceSJacob Faibussowitsch private:
572ea277ceSJacob Faibussowitsch   value_type v_;
582ea277ceSJacob Faibussowitsch };
592ea277ceSJacob Faibussowitsch 
60025e0618SJacob Faibussowitsch template <typename T>
61025e0618SJacob Faibussowitsch class axpy {
62025e0618SJacob Faibussowitsch public:
63025e0618SJacob Faibussowitsch   using value_type = T;
64025e0618SJacob Faibussowitsch 
axpy(value_type v=value_type{})65025e0618SJacob Faibussowitsch   PETSC_HOSTDEVICE_DECL constexpr explicit axpy(value_type v = value_type{}) noexcept : v_{std::move(v)} { }
66025e0618SJacob Faibussowitsch 
operator ()(const value_type & x,const value_type & y) const67025e0618SJacob Faibussowitsch   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr value_type operator()(const value_type &x, const value_type &y) const noexcept { return v_ * x + y; }
68025e0618SJacob Faibussowitsch 
69025e0618SJacob Faibussowitsch private:
70025e0618SJacob Faibussowitsch   value_type v_;
71025e0618SJacob Faibussowitsch };
72025e0618SJacob Faibussowitsch 
7390585354SJacob Faibussowitsch namespace
7490585354SJacob Faibussowitsch {
7590585354SJacob Faibussowitsch 
7690585354SJacob Faibussowitsch template <typename T>
make_plus_equals(const T & v)7790585354SJacob Faibussowitsch PETSC_HOSTDEVICE_INLINE_DECL constexpr plus_equals<T> make_plus_equals(const T &v) noexcept
7890585354SJacob Faibussowitsch {
7990585354SJacob Faibussowitsch   return plus_equals<T>{v};
8090585354SJacob Faibussowitsch }
8190585354SJacob Faibussowitsch 
822ea277ceSJacob Faibussowitsch template <typename T>
make_times_equals(const T & v)832ea277ceSJacob Faibussowitsch PETSC_HOSTDEVICE_INLINE_DECL constexpr times_equals<T> make_times_equals(const T &v) noexcept
842ea277ceSJacob Faibussowitsch {
852ea277ceSJacob Faibussowitsch   return times_equals<T>{v};
862ea277ceSJacob Faibussowitsch }
872ea277ceSJacob Faibussowitsch 
88025e0618SJacob Faibussowitsch template <typename T>
make_axpy(const T & v)89025e0618SJacob Faibussowitsch PETSC_HOSTDEVICE_INLINE_DECL constexpr axpy<T> make_axpy(const T &v) noexcept
90025e0618SJacob Faibussowitsch {
91025e0618SJacob Faibussowitsch   return axpy<T>{v};
92025e0618SJacob Faibussowitsch }
93025e0618SJacob Faibussowitsch 
9490585354SJacob Faibussowitsch } // anonymous namespace
9590585354SJacob Faibussowitsch 
9690585354SJacob Faibussowitsch } // namespace functors
9790585354SJacob Faibussowitsch 
986d54fb17SJacob Faibussowitsch } // namespace cupm
996d54fb17SJacob Faibussowitsch 
1006d54fb17SJacob Faibussowitsch } // namespace device
1016d54fb17SJacob Faibussowitsch 
1026d54fb17SJacob Faibussowitsch } // namespace Petsc
103