xref: /petsc/src/sys/objects/device/impls/cupm/cupmallocator.hpp (revision 1cc06b555e92f8ec64db10330b8bbd830e5bc876)
1 #ifndef CUPMALLOCATOR_HPP
2 #define CUPMALLOCATOR_HPP
3 
4 #include <petsc/private/cpp/object_pool.hpp>
5 
6 #include "../segmentedmempool.hpp"
7 #include "cupmthrustutility.hpp"
8 
9 #include <thrust/device_ptr.h>
10 #include <thrust/fill.h>
11 
12 #include <limits> // std::numeric_limits
13 
14 namespace Petsc
15 {
16 
17 namespace device
18 {
19 
20 namespace cupm
21 {
22 
23 // ==========================================================================================
24 // CUPM Host Allocator
25 // ==========================================================================================
26 
27 template <DeviceType T, typename PetscType = char>
28 class HostAllocator;
29 
30 // Allocator class to allocate pinned host memory for use with device
31 template <DeviceType T, typename PetscType>
32 class HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
33 public:
34   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
35   using base_type       = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
36   using real_value_type = typename base_type::real_value_type;
37   using size_type       = typename base_type::size_type;
38   using value_type      = typename base_type::value_type;
39 
40   template <typename U>
41   static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
42   template <typename U>
43   static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
44   template <typename U>
45   static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
46 };
47 
48 template <DeviceType T, typename P>
49 template <typename U>
50 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept
51 {
52   PetscFunctionBegin;
53   PetscCall(PetscCUPMMallocHost(ptr, n));
54   PetscFunctionReturn(PETSC_SUCCESS);
55 }
56 
57 template <DeviceType T, typename P>
58 template <typename U>
59 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept
60 {
61   PetscFunctionBegin;
62   PetscCallCUPM(cupmFreeHost(ptr));
63   PetscFunctionReturn(PETSC_SUCCESS);
64 }
65 
66 template <DeviceType T, typename P>
67 template <typename U>
68 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
69 {
70   PetscFunctionBegin;
71   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
75 // ==========================================================================================
76 // CUPM Device Allocator
77 // ==========================================================================================
78 
79 template <DeviceType T, typename PetscType = char>
80 class DeviceAllocator;
81 
82 template <DeviceType T, typename PetscType>
83 class DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
84 public:
85   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
86   using base_type       = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
87   using real_value_type = typename base_type::real_value_type;
88   using size_type       = typename base_type::size_type;
89   using value_type      = typename base_type::value_type;
90 
91   template <typename U>
92   static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
93   template <typename U>
94   static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
95   template <typename U>
96   static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept;
97   template <typename U>
98   static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
99   template <typename U>
100   static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept;
101 };
102 
103 template <DeviceType T, typename P>
104 template <typename U>
105 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept
106 {
107   PetscFunctionBegin;
108   PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream()));
109   PetscFunctionReturn(PETSC_SUCCESS);
110 }
111 
112 template <DeviceType T, typename P>
113 template <typename U>
114 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept
115 {
116   PetscFunctionBegin;
117   PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream()));
118   PetscFunctionReturn(PETSC_SUCCESS);
119 }
120 
121 template <DeviceType T, typename P>
122 template <typename U>
123 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
124 {
125   PetscFunctionBegin;
126   PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true));
127   PetscFunctionReturn(PETSC_SUCCESS);
128 }
129 
130 template <DeviceType T, typename P>
131 template <typename U>
132 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
133 {
134   PetscFunctionBegin;
135   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true));
136   PetscFunctionReturn(PETSC_SUCCESS);
137 }
138 
139 template <DeviceType T, typename P>
140 template <typename U>
141 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
142 {
143   using limit_t           = std::numeric_limits<real_value_type>;
144   const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max();
145   const auto       xptr   = thrust::device_pointer_cast(ptr);
146 
147   PetscFunctionBegin;
148   PetscCallThrust(THRUST_CALL(thrust::fill, stream->get_stream(), xptr, xptr + n, canary));
149   PetscFunctionReturn(PETSC_SUCCESS);
150 }
151 
152 } // namespace cupm
153 
154 } // namespace device
155 
156 } // namespace Petsc
157 
158 #endif // CUPMALLOCATOR_HPP
159