1 #pragma once
2
3 #include <petsc/private/cpp/object_pool.hpp>
4
5 #include "../segmentedmempool.hpp"
6 #include "cupmthrustutility.hpp"
7
8 #include <thrust/device_ptr.h>
9 #include <thrust/fill.h>
10
11 #include <limits> // std::numeric_limits
12
13 namespace Petsc
14 {
15
16 namespace device
17 {
18
19 namespace cupm
20 {
21
22 // ==========================================================================================
23 // CUPM Host Allocator
24 // ==========================================================================================
25
26 template <DeviceType T, typename PetscType = char>
27 class HostAllocator;
28
29 // Allocator class to allocate pinned host memory for use with device
30 template <DeviceType T, typename PetscType>
31 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
32 public:
33 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
34 using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
35 using real_value_type = typename base_type::real_value_type;
36 using size_type = typename base_type::size_type;
37 using value_type = typename base_type::value_type;
38
39 template <typename U>
40 static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
41 template <typename U>
42 static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
43 template <typename U>
44 static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
45 };
46
47 template <DeviceType T, typename P>
48 template <typename U>
allocate(value_type ** ptr,size_type n,const StreamBase<U> *)49 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept
50 {
51 PetscFunctionBegin;
52 PetscCall(PetscCUPMMallocHost(ptr, n));
53 PetscFunctionReturn(PETSC_SUCCESS);
54 }
55
56 template <DeviceType T, typename P>
57 template <typename U>
deallocate(value_type * ptr,const StreamBase<U> *)58 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept
59 {
60 PetscFunctionBegin;
61 PetscCallCUPM(cupmFreeHost(ptr));
62 PetscFunctionReturn(PETSC_SUCCESS);
63 }
64
65 template <DeviceType T, typename P>
66 template <typename U>
uninitialized_copy(value_type * dest,const value_type * src,size_type n,const StreamBase<U> * stream)67 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
68 {
69 PetscFunctionBegin;
70 PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true));
71 PetscFunctionReturn(PETSC_SUCCESS);
72 }
73
74 // ==========================================================================================
75 // CUPM Device Allocator
76 // ==========================================================================================
77
78 template <DeviceType T, typename PetscType = char>
79 class DeviceAllocator;
80
81 template <DeviceType T, typename PetscType>
82 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> {
83 public:
84 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
85 using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>;
86 using real_value_type = typename base_type::real_value_type;
87 using size_type = typename base_type::size_type;
88 using value_type = typename base_type::value_type;
89
90 template <typename U>
91 static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept;
92 template <typename U>
93 static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept;
94 template <typename U>
95 static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept;
96 template <typename U>
97 static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept;
98 template <typename U>
99 static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept;
100 };
101
102 template <DeviceType T, typename P>
103 template <typename U>
allocate(value_type ** ptr,size_type n,const StreamBase<U> * stream)104 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept
105 {
106 PetscFunctionBegin;
107 PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream()));
108 PetscFunctionReturn(PETSC_SUCCESS);
109 }
110
111 template <DeviceType T, typename P>
112 template <typename U>
deallocate(value_type * ptr,const StreamBase<U> * stream)113 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept
114 {
115 PetscFunctionBegin;
116 PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream()));
117 PetscFunctionReturn(PETSC_SUCCESS);
118 }
119
120 template <DeviceType T, typename P>
121 template <typename U>
zero(value_type * ptr,size_type n,const StreamBase<U> * stream)122 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
123 {
124 PetscFunctionBegin;
125 PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true));
126 PetscFunctionReturn(PETSC_SUCCESS);
127 }
128
129 template <DeviceType T, typename P>
130 template <typename U>
uninitialized_copy(value_type * dest,const value_type * src,size_type n,const StreamBase<U> * stream)131 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept
132 {
133 PetscFunctionBegin;
134 PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true));
135 PetscFunctionReturn(PETSC_SUCCESS);
136 }
137
138 template <DeviceType T, typename P>
139 template <typename U>
set_canary(value_type * ptr,size_type n,const StreamBase<U> * stream)140 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept
141 {
142 using limit_t = std::numeric_limits<real_value_type>;
143 const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max();
144 const auto xptr = thrust::device_pointer_cast(ptr);
145
146 PetscFunctionBegin;
147 PetscCallThrust(THRUST_CALL(thrust::fill, stream->get_stream(), xptr, xptr + n, canary));
148 PetscFunctionReturn(PETSC_SUCCESS);
149 }
150
151 } // namespace cupm
152
153 } // namespace device
154
155 } // namespace Petsc
156