1 #ifndef PETSC_CUPMSTREAM_HPP 2 #define PETSC_CUPMSTREAM_HPP 3 4 #include <petsc/private/cupminterface.hpp> 5 6 #include "../segmentedmempool.hpp" 7 #include "cupmevent.hpp" 8 9 #if defined(__cplusplus) 10 namespace Petsc 11 { 12 13 namespace device 14 { 15 16 namespace cupm 17 { 18 19 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely 20 // identify separate cupm streams. This is so that the memory pool can accelerate allocation 21 // calls as it can just pass back a pointer to memory that was used on the same 22 // stream. Otherwise it must either serialize with another stream or allocate a new chunk. 23 // Address of the objects does not suffice since cupmStreams are very likely internally reused. 24 25 template <DeviceType T> 26 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> { 27 using crtp_base_type = StreamBase<CUPMStream<T>>; 28 friend crtp_base_type; 29 30 public: 31 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T); 32 33 using stream_type = cupmStream_t; 34 using id_type = typename crtp_base_type::id_type; 35 using event_type = CUPMEvent<T>; 36 using flag_type = unsigned int; 37 38 CUPMStream() noexcept = default; 39 40 PetscErrorCode destroy() noexcept; 41 PetscErrorCode create(flag_type) noexcept; 42 PetscErrorCode change_type(PetscStreamType) noexcept; 43 44 private: 45 stream_type stream_{}; 46 id_type id_ = new_id_(); 47 48 PETSC_NODISCARD static id_type new_id_() noexcept; 49 50 // CRTP implementations 51 PETSC_NODISCARD stream_type get_stream_() const noexcept; 52 PETSC_NODISCARD id_type get_id_() const noexcept; 53 PetscErrorCode record_event_(event_type &) const noexcept; 54 PetscErrorCode wait_for_(event_type &) const noexcept; 55 }; 56 57 template <DeviceType T> 58 inline PetscErrorCode CUPMStream<T>::destroy() noexcept 59 { 60 PetscFunctionBegin; 61 if (stream_) { 62 PetscCallCUPM(cupmStreamDestroy(stream_)); 63 stream_ = cupmStream_t{}; 64 id_ = 0; 65 } 66 PetscFunctionReturn(PETSC_SUCCESS); 67 } 68 69 template <DeviceType T> 70 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept 71 { 72 PetscFunctionBegin; 73 if (stream_) { 74 if (PetscDefined(USE_DEBUG)) { 75 flag_type current_flags; 76 77 PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); 78 PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); 79 } 80 PetscFunctionReturn(PETSC_SUCCESS); 81 } 82 PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); 83 id_ = new_id_(); 84 PetscFunctionReturn(PETSC_SUCCESS); 85 } 86 87 template <DeviceType T> 88 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept 89 { 90 PetscFunctionBegin; 91 if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) { 92 PetscCall(destroy()); 93 } else { 94 const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking; 95 96 if (stream_) { 97 flag_type flag; 98 99 PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); 100 if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS); 101 PetscCall(destroy()); 102 } 103 PetscCall(create(preferred)); 104 } 105 PetscFunctionReturn(PETSC_SUCCESS); 106 } 107 108 template <DeviceType T> 109 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept 110 { 111 static id_type id = 0; 112 return id++; 113 } 114 115 // CRTP implementations 116 template <DeviceType T> 117 inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept 118 { 119 return stream_; 120 } 121 122 template <DeviceType T> 123 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept 124 { 125 return id_; 126 } 127 128 template <DeviceType T> 129 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept 130 { 131 PetscFunctionBegin; 132 PetscCall(event.record(stream_)); 133 PetscFunctionReturn(PETSC_SUCCESS); 134 } 135 136 template <DeviceType T> 137 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept 138 { 139 PetscFunctionBegin; 140 PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); 141 PetscFunctionReturn(PETSC_SUCCESS); 142 } 143 144 } // namespace cupm 145 146 } // namespace device 147 148 } // namespace Petsc 149 #endif // __cplusplus 150 151 #endif // PETSC_CUPMSTREAM_HPP 152