1 #ifndef PETSC_CUPMSTREAM_HPP 2 #define PETSC_CUPMSTREAM_HPP 3 4 #include <petsc/private/cupminterface.hpp> 5 6 #include "../segmentedmempool.hpp" 7 #include "cupmevent.hpp" 8 9 namespace Petsc 10 { 11 12 namespace device 13 { 14 15 namespace cupm 16 { 17 18 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely 19 // identify separate cupm streams. This is so that the memory pool can accelerate allocation 20 // calls as it can just pass back a pointer to memory that was used on the same 21 // stream. Otherwise it must either serialize with another stream or allocate a new chunk. 22 // Address of the objects does not suffice since cupmStreams are very likely internally reused. 23 24 template <DeviceType T> 25 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> { 26 using crtp_base_type = StreamBase<CUPMStream<T>>; 27 friend crtp_base_type; 28 29 public: 30 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); 31 32 using stream_type = cupmStream_t; 33 using id_type = typename crtp_base_type::id_type; 34 using event_type = CUPMEvent<T>; 35 using flag_type = unsigned int; 36 37 CUPMStream() noexcept = default; 38 39 PetscErrorCode destroy() noexcept; 40 PetscErrorCode create(flag_type) noexcept; 41 PetscErrorCode change_type(PetscStreamType) noexcept; 42 43 private: 44 stream_type stream_{}; 45 id_type id_ = new_id_(); 46 47 PETSC_NODISCARD static id_type new_id_() noexcept; 48 49 // CRTP implementations 50 PETSC_NODISCARD const stream_type &get_stream_() const noexcept; 51 PETSC_NODISCARD id_type get_id_() const noexcept; 52 PetscErrorCode record_event_(event_type &) const noexcept; 53 PetscErrorCode wait_for_(event_type &) const noexcept; 54 }; 55 56 template <DeviceType T> 57 inline PetscErrorCode CUPMStream<T>::destroy() noexcept 58 { 59 PetscFunctionBegin; 60 if (stream_) { 61 PetscCallCUPM(cupmStreamDestroy(stream_)); 62 stream_ = cupmStream_t{}; 63 id_ = 0; 64 } 65 PetscFunctionReturn(PETSC_SUCCESS); 66 } 67 68 template <DeviceType T> 69 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept 70 { 71 PetscFunctionBegin; 72 if (stream_) { 73 if (PetscDefined(USE_DEBUG)) { 74 flag_type current_flags; 75 76 PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); 77 PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); 78 } 79 PetscFunctionReturn(PETSC_SUCCESS); 80 } 81 PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); 82 id_ = new_id_(); 83 PetscFunctionReturn(PETSC_SUCCESS); 84 } 85 86 template <DeviceType T> 87 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept 88 { 89 PetscFunctionBegin; 90 if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) { 91 PetscCall(destroy()); 92 } else { 93 const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking; 94 95 if (stream_) { 96 flag_type flag; 97 98 PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); 99 if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS); 100 PetscCall(destroy()); 101 } 102 PetscCall(create(preferred)); 103 } 104 PetscFunctionReturn(PETSC_SUCCESS); 105 } 106 107 template <DeviceType T> 108 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept 109 { 110 static id_type id = 0; 111 return id++; 112 } 113 114 // CRTP implementations 115 template <DeviceType T> 116 inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept 117 { 118 return stream_; 119 } 120 121 template <DeviceType T> 122 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept 123 { 124 return id_; 125 } 126 127 template <DeviceType T> 128 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept 129 { 130 PetscFunctionBegin; 131 PetscCall(event.record(stream_)); 132 PetscFunctionReturn(PETSC_SUCCESS); 133 } 134 135 template <DeviceType T> 136 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept 137 { 138 PetscFunctionBegin; 139 PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); 140 PetscFunctionReturn(PETSC_SUCCESS); 141 } 142 143 } // namespace cupm 144 145 } // namespace device 146 147 } // namespace Petsc 148 149 #endif // PETSC_CUPMSTREAM_HPP 150