1 #ifndef PETSC_CUPMSTREAM_HPP 2 #define PETSC_CUPMSTREAM_HPP 3 4 #include <petsc/private/cupminterface.hpp> 5 6 #include "../segmentedmempool.hpp" 7 #include "cupmevent.hpp" 8 9 #if defined(__cplusplus) 10 namespace Petsc { 11 12 namespace device { 13 14 namespace cupm { 15 16 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely 17 // identify separate cupm streams. This is so that the memory pool can accelerate allocation 18 // calls as it can just pass back a pointer to memory that was used on the same 19 // stream. Otherwise it must either serialize with another stream or allocate a new chunk. 20 // Address of the objects does not suffice since cupmStreams are very likely internally reused. 21 22 template <DeviceType T> 23 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> { 24 using crtp_base_type = StreamBase<CUPMStream<T>>; 25 friend crtp_base_type; 26 27 public: 28 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T); 29 30 using stream_type = cupmStream_t; 31 using id_type = typename crtp_base_type::id_type; 32 using event_type = CUPMEvent<T>; 33 using flag_type = unsigned int; 34 35 CUPMStream() noexcept = default; 36 37 PETSC_NODISCARD PetscErrorCode destroy() noexcept; 38 PETSC_NODISCARD PetscErrorCode create(flag_type) noexcept; 39 PETSC_NODISCARD PetscErrorCode change_type(PetscStreamType) noexcept; 40 41 private: 42 stream_type stream_{}; 43 id_type id_ = new_id_(); 44 45 PETSC_NODISCARD static id_type new_id_() noexcept; 46 47 // CRTP implementations 48 PETSC_NODISCARD stream_type get_stream_() const noexcept; 49 PETSC_NODISCARD id_type get_id_() const noexcept; 50 PETSC_NODISCARD PetscErrorCode record_event_(event_type &) const noexcept; 51 PETSC_NODISCARD PetscErrorCode wait_for_(event_type &) const noexcept; 52 }; 53 54 template <DeviceType T> 55 inline PetscErrorCode CUPMStream<T>::destroy() noexcept { 56 PetscFunctionBegin; 57 if (stream_) { 58 PetscCallCUPM(cupmStreamDestroy(stream_)); 59 stream_ = cupmStream_t{}; 60 id_ = 0; 61 } 62 PetscFunctionReturn(0); 63 } 64 65 template <DeviceType T> 66 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept { 67 PetscFunctionBegin; 68 if (stream_) { 69 if (PetscDefined(USE_DEBUG)) { 70 flag_type current_flags; 71 72 PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); 73 PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); 74 } 75 PetscFunctionReturn(0); 76 } 77 PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); 78 id_ = new_id_(); 79 PetscFunctionReturn(0); 80 } 81 82 template <DeviceType T> 83 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept { 84 PetscFunctionBegin; 85 if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) { 86 PetscCall(destroy()); 87 } else { 88 const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking; 89 90 if (stream_) { 91 flag_type flag; 92 93 PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); 94 if ((flag != preferred) || (cupmStreamQuery(stream_) != cupmSuccess)) PetscCall(destroy()); 95 } 96 PetscCall(create(preferred)); 97 } 98 PetscFunctionReturn(0); 99 } 100 101 template <DeviceType T> 102 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept { 103 static id_type id = 0; 104 return id++; 105 } 106 107 // CRTP implementations 108 template <DeviceType T> 109 inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept { 110 return stream_; 111 } 112 113 template <DeviceType T> 114 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept { 115 return id_; 116 } 117 118 template <DeviceType T> 119 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept { 120 PetscFunctionBegin; 121 PetscCall(event.record(stream_)); 122 PetscFunctionReturn(0); 123 } 124 125 template <DeviceType T> 126 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept { 127 PetscFunctionBegin; 128 PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); 129 PetscFunctionReturn(0); 130 } 131 132 } // namespace cupm 133 134 } // namespace device 135 136 } // namespace Petsc 137 #endif // __cplusplus 138 139 #endif // PETSC_CUPMSTREAM_HPP 140