xref: /petsc/src/sys/objects/device/impls/cupm/cupmallocator.hpp (revision 750b007cd8d816cecd9de99077bb0a703b4cf61a)
1 #ifndef CUPMALLOCATOR_HPP
2 #define CUPMALLOCATOR_HPP
3 
4 #if defined(__cplusplus)
5 #include <petsc/private/cpp/object_pool.hpp>
6 
7 #include "../segmentedmempool.hpp"
8 #include "cupmthrustutility.hpp"
9 
10 #include <limits> // std::numeric_limits
11 
12 namespace Petsc {
13 
14 namespace device {
15 
16 namespace cupm {
17 
18 // ==========================================================================================
19 // CUPM Host Allocator
20 // ==========================================================================================
21 
// Forward declaration of HostAllocator. The default argument PetscType = char is supplied
// here, so HostAllocator<T> names HostAllocator<T, char>.
22 template <DeviceType T, typename PetscType = char>
23 class HostAllocator;
24 
25 // Allocator class to allocate pinned host memory for use with device.
//
// Derives publicly from memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType> and, because
// no access specifier is written for the second base of a `class`, privately from
// impl::Interface<T>. Every operation declared below is a static, noexcept member template
// that returns a PetscErrorCode and takes a `const StreamBase<U> *` as its last parameter.
26 template <DeviceType T, typename PetscType>
27 class HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
28 public:
  // NOTE(review): this macro is defined elsewhere; presumably it introduces interface_type as
  // an alias for impl::Interface<T> and makes its members (e.g. cupmFreeHost) usable unqualified
  // inside this class -- confirm against the macro definition.
29   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
30   using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
  // Re-export the dependent member types of base_type so they can be named without qualification.
31   using typename base_type::real_value_type;
32   using typename base_type::size_type;
33   using typename base_type::value_type;
34 
  // Forwards to PetscCUPMMallocHost(); the stream argument is ignored by the definition.
35   template <typename U>
36   PETSC_NODISCARD static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
  // Forwards to cupmFreeHost(); the stream argument is ignored by the definition.
37   template <typename U>
38   PETSC_NODISCARD static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
  // Forwards (dest, src, count) to PetscCUPMMemcpyAsync() with cupmMemcpyHostToHost on the
  // stream obtained from the last argument.
39   template <typename U>
40   PETSC_NODISCARD static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
41 };
42 
// HostAllocator::allocate
//
// Forwards ptr and n unchanged to PetscCUPMMallocHost(), wrapped in PetscCall(). The
// StreamBase<U> parameter is unnamed and never used.
//
// Parameters:
//   ptr - passed straight to PetscCUPMMallocHost(); presumably receives the address of the new
//         allocation -- confirm against that routine
//   n   - passed straight to PetscCUPMMallocHost(); NOTE(review): whether this counts
//         value_type elements or bytes is decided by that routine and is not visible here
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed. PetscCall() is
// presumably what propagates a nonzero error code to the caller -- confirm.
43 template <DeviceType T, typename P>
44 template <typename U>
45 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept {
46   PetscFunctionBegin;
47   PetscCall(PetscCUPMMallocHost(ptr, n));
48   PetscFunctionReturn(0);
49 }
50 
// HostAllocator::deallocate
//
// Passes ptr to cupmFreeHost() and checks the result with PetscCallCUPM(). The StreamBase<U>
// parameter is unnamed and never used, so the call does not depend on any stream.
//
// Parameters:
//   ptr - pointer handed to cupmFreeHost(); presumably one previously produced by allocate()
//         -- confirm against callers
//
// Returns: 0 via PetscFunctionReturn() once the PetscCallCUPM() line has completed.
51 template <DeviceType T, typename P>
52 template <typename U>
53 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept {
54   PetscFunctionBegin;
55   PetscCallCUPM(cupmFreeHost(ptr));
56   PetscFunctionReturn(0);
57 }
58 
// HostAllocator::uninitialized_copy
//
// Forwards dest, src and n to PetscCUPMMemcpyAsync() with copy kind cupmMemcpyHostToHost, on the
// stream returned by stream->get_stream().
//
// Parameters:
//   dest   - destination pointer, forwarded unchanged
//   src    - source pointer, forwarded unchanged
//   n      - count forwarded unchanged; NOTE(review): units (elements vs. bytes) are decided by
//            PetscCUPMMemcpyAsync() and are not visible here
//   stream - dereferenced unconditionally, so it must not be null
//
// NOTE(review): the meaning of the trailing `true` argument is defined by PetscCUPMMemcpyAsync()
// and cannot be determined from this file -- confirm before changing it.
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed.
59 template <DeviceType T, typename P>
60 template <typename U>
61 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept {
62   PetscFunctionBegin;
63   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true));
64   PetscFunctionReturn(0);
65 }
66 
67 // ==========================================================================================
68 // CUPM Device Allocator
69 // ==========================================================================================
70 
// Forward declaration of DeviceAllocator. The default argument PetscType = char is supplied
// here, so DeviceAllocator<T> names DeviceAllocator<T, char>.
71 template <DeviceType T, typename PetscType = char>
72 class DeviceAllocator;
73 
// Allocator whose operations are all issued against the stream obtained from the StreamBase<U>
// argument (see the definitions that follow the class).
//
// Derives publicly from memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType> and, because
// no access specifier is written for the second base of a `class`, privately from
// impl::Interface<T>. Every operation declared below is a static, noexcept member template
// that returns a PetscErrorCode and takes a `const StreamBase<U> *` as its last parameter.
74 template <DeviceType T, typename PetscType>
75 class DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
76 public:
  // NOTE(review): this macro is defined elsewhere; presumably it introduces interface_type as
  // an alias for impl::Interface<T> and makes its members (e.g. cupmFreeAsync) usable
  // unqualified inside this class -- confirm against the macro definition.
77   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
78   using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
  // Re-export the dependent member types of base_type so they can be named without qualification.
79   using typename base_type::real_value_type;
80   using typename base_type::size_type;
81   using typename base_type::value_type;
82 
  // Forwards to PetscCUPMMallocAsync() on the given stream.
83   template <typename U>
84   PETSC_NODISCARD static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
  // Forwards to cupmFreeAsync() on the given stream.
85   template <typename U>
86   PETSC_NODISCARD static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
  // Forwards to PetscCUPMMemsetAsync() with a fill value of 0 on the given stream.
87   template <typename U>
88   PETSC_NODISCARD static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept;
  // Forwards (dest, src, count) to PetscCUPMMemcpyAsync() with cupmMemcpyDeviceToDevice.
89   template <typename U>
90   PETSC_NODISCARD static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
  // Fills the range with a sentinel (signaling NaN if available, otherwise the maximum value of
  // real_value_type) through impl::ThrustSet<T>().
91   template <typename U>
92   PETSC_NODISCARD static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept;
93 };
94 
// DeviceAllocator::allocate
//
// Forwards ptr and n to PetscCUPMMallocAsync() together with the stream returned by
// stream->get_stream(), wrapped in PetscCall().
//
// Parameters:
//   ptr    - passed straight to PetscCUPMMallocAsync(); presumably receives the address of the
//            new allocation -- confirm against that routine
//   n      - passed straight to PetscCUPMMallocAsync(); NOTE(review): whether this counts
//            value_type elements or bytes is decided by that routine and is not visible here
//   stream - dereferenced unconditionally, so it must not be null
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed.
95 template <DeviceType T, typename P>
96 template <typename U>
97 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept {
98   PetscFunctionBegin;
99   PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream()));
100   PetscFunctionReturn(0);
101 }
102 
// DeviceAllocator::deallocate
//
// Passes ptr and the stream returned by stream->get_stream() to cupmFreeAsync() and checks the
// result with PetscCallCUPM().
//
// Parameters:
//   ptr    - pointer handed to cupmFreeAsync(); presumably one previously produced by
//            allocate() -- confirm against callers
//   stream - dereferenced unconditionally, so it must not be null
//
// Returns: 0 via PetscFunctionReturn() once the PetscCallCUPM() line has completed.
103 template <DeviceType T, typename P>
104 template <typename U>
105 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept {
106   PetscFunctionBegin;
107   PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream()));
108   PetscFunctionReturn(0);
109 }
110 
// DeviceAllocator::zero
//
// Forwards ptr, a fill value of 0 and n to PetscCUPMMemsetAsync() on the stream returned by
// stream->get_stream().
//
// Parameters:
//   ptr    - start of the range, forwarded unchanged
//   n      - count forwarded unchanged; NOTE(review): units (elements vs. bytes) are decided by
//            PetscCUPMMemsetAsync() and are not visible here
//   stream - dereferenced unconditionally, so it must not be null
//
// NOTE(review): the meaning of the trailing `true` argument is defined by PetscCUPMMemsetAsync()
// and cannot be determined from this file -- confirm before changing it.
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed.
111 template <DeviceType T, typename P>
112 template <typename U>
113 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept {
114   PetscFunctionBegin;
115   PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true));
116   PetscFunctionReturn(0);
117 }
118 
// DeviceAllocator::uninitialized_copy
//
// Forwards dest, src and n to PetscCUPMMemcpyAsync() with copy kind cupmMemcpyDeviceToDevice, on
// the stream returned by stream->get_stream().
//
// Parameters:
//   dest   - destination pointer, forwarded unchanged
//   src    - source pointer, forwarded unchanged
//   n      - count forwarded unchanged; NOTE(review): units (elements vs. bytes) are decided by
//            PetscCUPMMemcpyAsync() and are not visible here
//   stream - dereferenced unconditionally, so it must not be null
//
// NOTE(review): the meaning of the trailing `true` argument is defined by PetscCUPMMemcpyAsync()
// and cannot be determined from this file -- confirm before changing it.
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed.
119 template <DeviceType T, typename P>
120 template <typename U>
121 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept {
122   PetscFunctionBegin;
123   PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true));
124   PetscFunctionReturn(0);
125 }
126 
// DeviceAllocator::set_canary
//
// Builds a sentinel value and hands it, together with ptr and n, to impl::ThrustSet<T>() on the
// stream returned by stream->get_stream(). The sentinel is the signaling NaN of real_value_type
// when std::numeric_limits reports one, and the largest finite real_value_type otherwise.
//
// Parameters:
//   ptr    - start of the range, forwarded unchanged to impl::ThrustSet<T>()
//   n      - count forwarded unchanged to impl::ThrustSet<T>(); presumably a number of
//            value_type elements -- confirm against ThrustSet
//   stream - dereferenced unconditionally, so it must not be null
//
// NOTE(review): &canary is the address of a function-local object. This is only safe if
// impl::ThrustSet<T>() has finished reading it by the time it returns -- confirm.
//
// Returns: 0 via PetscFunctionReturn() once the PetscCall() line has completed.
127 template <DeviceType T, typename P>
128 template <typename U>
129 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept {
130   using limit_t           = std::numeric_limits<real_value_type>;
131   const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max(); // both arms are real_value_type; the result is converted to value_type
132 
133   PetscFunctionBegin;
134   PetscCall(impl::ThrustSet<T>(stream->get_stream(), n, ptr, &canary));
135   PetscFunctionReturn(0);
136 }
137 
138 } // namespace cupm
139 
140 } // namespace device
141 
142 } // namespace Petsc
143 
144 #endif // __cplusplus
145 
146 #endif // CUPMALLOCATOR_HPP
147