1 #pragma once 2 3 #include <petsc/private/cupminterface.hpp> 4 5 #include "../segmentedmempool.hpp" 6 #include "cupmevent.hpp" 7 8 namespace Petsc 9 { 10 11 namespace device 12 { 13 14 namespace cupm 15 { 16 17 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely 18 // identify separate cupm streams. This is so that the memory pool can accelerate allocation 19 // calls as it can just pass back a pointer to memory that was used on the same 20 // stream. Otherwise it must either serialize with another stream or allocate a new chunk. 21 // Address of the objects does not suffice since cupmStreams are very likely internally reused. 22 23 template <DeviceType T> 24 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> { 25 using crtp_base_type = StreamBase<CUPMStream<T>>; 26 friend crtp_base_type; 27 28 public: 29 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); 30 31 using stream_type = cupmStream_t; 32 using id_type = typename crtp_base_type::id_type; 33 using event_type = CUPMEvent<T>; 34 using flag_type = unsigned int; 35 36 CUPMStream() noexcept = default; 37 38 PetscErrorCode destroy() noexcept; 39 PetscErrorCode create(flag_type) noexcept; 40 PetscErrorCode change_type(PetscStreamType) noexcept; 41 42 private: 43 stream_type stream_{}; 44 id_type id_ = new_id_(); 45 46 PETSC_NODISCARD static id_type new_id_() noexcept; 47 48 // CRTP implementations 49 PETSC_NODISCARD const stream_type &get_stream_() const noexcept; 50 PETSC_NODISCARD id_type get_id_() const noexcept; 51 PetscErrorCode record_event_(event_type &) const noexcept; 52 PetscErrorCode wait_for_(event_type &) const noexcept; 53 }; 54 55 template <DeviceType T> 56 inline PetscErrorCode CUPMStream<T>::destroy() noexcept 57 { 58 PetscFunctionBegin; 59 if (stream_) { 60 PetscCallCUPM(cupmStreamDestroy(stream_)); 61 stream_ = cupmStream_t{}; 62 id_ = 0; 63 } 64 PetscFunctionReturn(PETSC_SUCCESS); 65 } 66 67 template <DeviceType T> 68 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept 69 { 70 PetscFunctionBegin; 71 if (stream_) { 72 if (PetscDefined(USE_DEBUG)) { 73 flag_type current_flags; 74 75 PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); 76 PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); 77 } 78 PetscFunctionReturn(PETSC_SUCCESS); 79 } 80 PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); 81 id_ = new_id_(); 82 PetscFunctionReturn(PETSC_SUCCESS); 83 } 84 85 template <DeviceType T> 86 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept 87 { 88 PetscFunctionBegin; 89 if (newtype == PETSC_STREAM_DEFAULT || newtype == PETSC_STREAM_DEFAULT_WITH_BARRIER) { 90 PetscCall(destroy()); 91 } else { // change to a nonblokcing stream 92 const flag_type preferred = cupmStreamNonBlocking; 93 94 if (stream_) { 95 flag_type flag; 96 97 PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); 98 if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS); 99 PetscCall(destroy()); 100 } 101 PetscCall(create(preferred)); 102 } 103 PetscFunctionReturn(PETSC_SUCCESS); 104 } 105 106 template <DeviceType T> 107 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept 108 { 109 static id_type id = 0; 110 return id++; 111 } 112 113 // CRTP implementations 114 template <DeviceType T> 115 inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept 116 { 117 return stream_; 118 } 119 120 template <DeviceType T> 121 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept 122 { 123 return id_; 124 } 125 126 template <DeviceType T> 127 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept 128 { 129 PetscFunctionBegin; 130 PetscCall(event.record(stream_)); 131 PetscFunctionReturn(PETSC_SUCCESS); 132 } 133 134 template <DeviceType T> 135 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept 136 { 137 PetscFunctionBegin; 138 PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); 139 PetscFunctionReturn(PETSC_SUCCESS); 140 } 141 142 } // namespace cupm 143 144 } // namespace device 145 146 } // namespace Petsc 147