1 #ifndef CUPMALLOCATOR_HPP 2 #define CUPMALLOCATOR_HPP 3 4 #if defined(__cplusplus) 5 #include <petsc/private/cpp/object_pool.hpp> 6 7 #include "../segmentedmempool.hpp" 8 #include "cupmthrustutility.hpp" 9 10 #include <thrust/device_ptr.h> 11 #include <thrust/fill.h> 12 13 #include <limits> // std::numeric_limits 14 15 namespace Petsc 16 { 17 18 namespace device 19 { 20 21 namespace cupm 22 { 23 24 // ========================================================================================== 25 // CUPM Host Allocator 26 // ========================================================================================== 27 28 template <DeviceType T, typename PetscType = char> 29 class HostAllocator; 30 31 // Allocator class to allocate pinned host memory for use with device 32 template <DeviceType T, typename PetscType> 33 class HostAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> { 34 public: 35 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); 36 using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>; 37 using real_value_type = typename base_type::real_value_type; 38 using size_type = typename base_type::size_type; 39 using value_type = typename base_type::value_type; 40 41 template <typename U> 42 static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept; 43 template <typename U> 44 static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept; 45 template <typename U> 46 static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept; 47 }; 48 49 template <DeviceType T, typename P> 50 template <typename U> 51 inline PetscErrorCode HostAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *) noexcept 52 { 53 PetscFunctionBegin; 54 PetscCall(PetscCUPMMallocHost(ptr, n)); 55 PetscFunctionReturn(PETSC_SUCCESS); 56 } 57 58 template <DeviceType T, typename P> 59 template <typename U> 60 inline PetscErrorCode HostAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *) noexcept 61 { 62 PetscFunctionBegin; 63 PetscCallCUPM(cupmFreeHost(ptr)); 64 PetscFunctionReturn(PETSC_SUCCESS); 65 } 66 67 template <DeviceType T, typename P> 68 template <typename U> 69 inline PetscErrorCode HostAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept 70 { 71 PetscFunctionBegin; 72 PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyHostToHost, stream->get_stream(), true)); 73 PetscFunctionReturn(PETSC_SUCCESS); 74 } 75 76 // ========================================================================================== 77 // CUPM Device Allocator 78 // ========================================================================================== 79 80 template <DeviceType T, typename PetscType = char> 81 class DeviceAllocator; 82 83 template <DeviceType T, typename PetscType> 84 class DeviceAllocator : public memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>, impl::Interface<T> { 85 public: 86 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); 87 using base_type = memory::impl::SegmentedMemoryPoolAllocatorBase<PetscType>; 88 using real_value_type = typename base_type::real_value_type; 89 using size_type = typename base_type::size_type; 90 using value_type = typename base_type::value_type; 91 92 template <typename U> 93 static PetscErrorCode allocate(value_type **, size_type, const StreamBase<U> *) noexcept; 94 template <typename U> 95 static PetscErrorCode deallocate(value_type *, const StreamBase<U> *) noexcept; 96 template <typename U> 97 static PetscErrorCode zero(value_type *, size_type, const StreamBase<U> *) noexcept; 98 template <typename U> 99 static PetscErrorCode uninitialized_copy(value_type *, const value_type *, size_type, const StreamBase<U> *) noexcept; 100 template <typename U> 101 static PetscErrorCode set_canary(value_type *, size_type, const StreamBase<U> *) noexcept; 102 }; 103 104 template <DeviceType T, typename P> 105 template <typename U> 106 inline PetscErrorCode DeviceAllocator<T, P>::allocate(value_type **ptr, size_type n, const StreamBase<U> *stream) noexcept 107 { 108 PetscFunctionBegin; 109 PetscCall(PetscCUPMMallocAsync(ptr, n, stream->get_stream())); 110 PetscFunctionReturn(PETSC_SUCCESS); 111 } 112 113 template <DeviceType T, typename P> 114 template <typename U> 115 inline PetscErrorCode DeviceAllocator<T, P>::deallocate(value_type *ptr, const StreamBase<U> *stream) noexcept 116 { 117 PetscFunctionBegin; 118 PetscCallCUPM(cupmFreeAsync(ptr, stream->get_stream())); 119 PetscFunctionReturn(PETSC_SUCCESS); 120 } 121 122 template <DeviceType T, typename P> 123 template <typename U> 124 inline PetscErrorCode DeviceAllocator<T, P>::zero(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept 125 { 126 PetscFunctionBegin; 127 PetscCall(PetscCUPMMemsetAsync(ptr, 0, n, stream->get_stream(), true)); 128 PetscFunctionReturn(PETSC_SUCCESS); 129 } 130 131 template <DeviceType T, typename P> 132 template <typename U> 133 inline PetscErrorCode DeviceAllocator<T, P>::uninitialized_copy(value_type *dest, const value_type *src, size_type n, const StreamBase<U> *stream) noexcept 134 { 135 PetscFunctionBegin; 136 PetscCall(PetscCUPMMemcpyAsync(dest, src, n, cupmMemcpyDeviceToDevice, stream->get_stream(), true)); 137 PetscFunctionReturn(PETSC_SUCCESS); 138 } 139 140 template <DeviceType T, typename P> 141 template <typename U> 142 inline PetscErrorCode DeviceAllocator<T, P>::set_canary(value_type *ptr, size_type n, const StreamBase<U> *stream) noexcept 143 { 144 using limit_t = std::numeric_limits<real_value_type>; 145 const value_type canary = limit_t::has_signaling_NaN ? limit_t::signaling_NaN() : limit_t::max(); 146 const auto xptr = thrust::device_pointer_cast(ptr); 147 148 PetscFunctionBegin; 149 PetscCallThrust(THRUST_CALL(thrust::fill, stream->get_stream(), xptr, xptr + n, canary)); 150 PetscFunctionReturn(PETSC_SUCCESS); 151 } 152 153 } // namespace cupm 154 155 } // namespace device 156 157 } // namespace Petsc 158 159 #endif // __cplusplus 160 161 #endif // CUPMALLOCATOR_HPP 162