xref: /petsc/src/sys/objects/device/impls/cupm/cupmallocator.hpp (revision d71ae5a4db6382e7f06317b8d368875286fe9008)
1 #ifndef CUPMALLOCATOR_HPP
2 #define CUPMALLOCATOR_HPP
3 
4 #if defined(__cplusplus)
5   #include <petsc/private/cpp/object_pool.hpp>
6 
7   #include "../segmentedmempool.hpp"
8   #include "cupmthrustutility.hpp"
9 
10   #include <limits> // std::numeric_limits
11 
12 namespace Petsc
13 {
14 
15 namespace device
16 {
17 
18 namespace cupm
19 {
20 
21 // ==========================================================================================
22 // CUPM Host Allocator
23 // ==========================================================================================
24 
25 template <DeviceType T, typename PetscType = char>
26 class HostAllocator;
27 
28 // Allocator class to allocate pinned host memory for use with device
29 template <DeviceType T, typename PetscType>
30 class HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
31 public:
32   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
33   using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
34   using typename base_type::real_value_type;
35   using typename base_type::size_type;
36   using typename base_type::value_type;
37 
38   template <typename U>
39   PETSC_NODISCARD static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
40   template <typename U>
41   PETSC_NODISCARD static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
42   template <typename U>
43   PETSC_NODISCARD static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
44 };
45 
46 template <DeviceType T, typename P>
47 template <typename U>
48 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept
49 {
50   PetscFunctionBegin;
51   PetscCall(PetscCUPMMallocHost(ptr, n));
52   PetscFunctionReturn(0);
53 }
54 
55 template <DeviceType T, typename P>
56 template <typename U>
57 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept
58 {
59   PetscFunctionBegin;
60   PetscCallCUPM(cupmFreeHost(ptr));
61   PetscFunctionReturn(0);
62 }
63 
64 template <DeviceType T, typename P>
65 template <typename U>
66 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
67 {
68   PetscFunctionBegin;
69   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true));
70   PetscFunctionReturn(0);
71 }
72 
73 // ==========================================================================================
74 // CUPM Device Allocator
75 // ==========================================================================================
76 
77 template <DeviceType T, typename PetscType = char>
78 class DeviceAllocator;
79 
80 template <DeviceType T, typename PetscType>
81 class DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
82 public:
83   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
84   using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
85   using typename base_type::real_value_type;
86   using typename base_type::size_type;
87   using typename base_type::value_type;
88 
89   template <typename U>
90   PETSC_NODISCARD static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
91   template <typename U>
92   PETSC_NODISCARD static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
93   template <typename U>
94   PETSC_NODISCARD static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept;
95   template <typename U>
96   PETSC_NODISCARD static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
97   template <typename U>
98   PETSC_NODISCARD static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept;
99 };
100 
101 template <DeviceType T, typename P>
102 template <typename U>
103 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept
104 {
105   PetscFunctionBegin;
106   PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream()));
107   PetscFunctionReturn(0);
108 }
109 
110 template <DeviceType T, typename P>
111 template <typename U>
112 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept
113 {
114   PetscFunctionBegin;
115   PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream()));
116   PetscFunctionReturn(0);
117 }
118 
119 template <DeviceType T, typename P>
120 template <typename U>
121 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
122 {
123   PetscFunctionBegin;
124   PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true));
125   PetscFunctionReturn(0);
126 }
127 
128 template <DeviceType T, typename P>
129 template <typename U>
130 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
131 {
132   PetscFunctionBegin;
133   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true));
134   PetscFunctionReturn(0);
135 }
136 
137 template <DeviceType T, typename P>
138 template <typename U>
139 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
140 {
141   using limit_t           = std::numeric_limits<real_value_type>;
142   const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max();
143 
144   PetscFunctionBegin;
145   PetscCall(impl::ThrustSet<T>(stream->get_stream(), n, ptr, &canary));
146   PetscFunctionReturn(0);
147 }
148 
149 } // namespace cupm
150 
151 } // namespace device
152 
153 } // namespace Petsc
154 
155 #endif // __cplusplus
156 
157 #endif // CUPMALLOCATOR_HPP
158