xref: /petsc/src/sys/objects/device/impls/cupm/cupmallocator.hpp (revision 029efca733670abeac35b14fdbd8ea5bc1afb3ff)
1 #ifndef CUPMALLOCATOR_HPP
2 #define CUPMALLOCATOR_HPP
3 
4 #if defined(__cplusplus)
5   #include <petsc/private/cpp/object_pool.hpp>
6 
7   #include "../segmentedmempool.hpp"
8   #include "cupmthrustutility.hpp"
9 
10   #include <thrust/device_ptr.h>
11   #include <thrust/fill.h>
12 
13   #include <limits> // std::numeric_limits
14 
15 namespace Petsc
16 {
17 
18 namespace device
19 {
20 
21 namespace cupm
22 {
23 
24 // ==========================================================================================
25 // CUPM Host Allocator
26 // ==========================================================================================
27 
28 template <DeviceType T, typename PetscType = char>
29 class HostAllocator;
30 
31 // Allocator class to allocate pinned host memory for use with device
32 template <DeviceType T, typename PetscType>
33 class HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
34 public:
35   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
36   using base_type       = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
37   using real_value_type = typename base_type::real_value_type;
38   using size_type       = typename base_type::size_type;
39   using value_type      = typename base_type::value_type;
40 
41   template <typename U>
42   static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
43   template <typename U>
44   static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
45   template <typename U>
46   static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
47 };
48 
49 template <DeviceType T, typename P>
50 template <typename U>
51 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept
52 {
53   PetscFunctionBegin;
54   PetscCall(PetscCUPMMallocHost(ptr, n));
55   PetscFunctionReturn(PETSC_SUCCESS);
56 }
57 
58 template <DeviceType T, typename P>
59 template <typename U>
60 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept
61 {
62   PetscFunctionBegin;
63   PetscCallCUPM(cupmFreeHost(ptr));
64   PetscFunctionReturn(PETSC_SUCCESS);
65 }
66 
67 template <DeviceType T, typename P>
68 template <typename U>
69 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
70 {
71   PetscFunctionBegin;
72   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true));
73   PetscFunctionReturn(PETSC_SUCCESS);
74 }
75 
76 // ==========================================================================================
77 // CUPM Device Allocator
78 // ==========================================================================================
79 
80 template <DeviceType T, typename PetscType = char>
81 class DeviceAllocator;
82 
83 template <DeviceType T, typename PetscType>
84 class DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
85 public:
86   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
87   using base_type       = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
88   using real_value_type = typename base_type::real_value_type;
89   using size_type       = typename base_type::size_type;
90   using value_type      = typename base_type::value_type;
91 
92   template <typename U>
93   static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
94   template <typename U>
95   static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
96   template <typename U>
97   static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept;
98   template <typename U>
99   static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
100   template <typename U>
101   static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept;
102 };
103 
104 template <DeviceType T, typename P>
105 template <typename U>
106 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept
107 {
108   PetscFunctionBegin;
109   PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream()));
110   PetscFunctionReturn(PETSC_SUCCESS);
111 }
112 
113 template <DeviceType T, typename P>
114 template <typename U>
115 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept
116 {
117   PetscFunctionBegin;
118   PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream()));
119   PetscFunctionReturn(PETSC_SUCCESS);
120 }
121 
122 template <DeviceType T, typename P>
123 template <typename U>
124 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
125 {
126   PetscFunctionBegin;
127   PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true));
128   PetscFunctionReturn(PETSC_SUCCESS);
129 }
130 
131 template <DeviceType T, typename P>
132 template <typename U>
133 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
134 {
135   PetscFunctionBegin;
136   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true));
137   PetscFunctionReturn(PETSC_SUCCESS);
138 }
139 
140 template <DeviceType T, typename P>
141 template <typename U>
142 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
143 {
144   using limit_t           = std::numeric_limits<real_value_type>;
145   const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max();
146   const auto       xptr   = thrust::device_pointer_cast(ptr);
147 
148   PetscFunctionBegin;
149   PetscCallThrust(THRUST_CALL(thrust::fill, stream->get_stream(), xptr, xptr + n, canary));
150   PetscFunctionReturn(PETSC_SUCCESS);
151 }
152 
153 } // namespace cupm
154 
155 } // namespace device
156 
157 } // namespace Petsc
158 
159 #endif // __cplusplus
160 
161 #endif // CUPMALLOCATOR_HPP
162