10e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPMSTREAM_HPP 20e6b6b59SJacob Faibussowitsch #define PETSC_CUPMSTREAM_HPP 30e6b6b59SJacob Faibussowitsch 40e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp> 50e6b6b59SJacob Faibussowitsch 60e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp" 70e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp" 80e6b6b59SJacob Faibussowitsch 90e6b6b59SJacob Faibussowitsch #if defined(__cplusplus) 10d71ae5a4SJacob Faibussowitsch namespace Petsc 11d71ae5a4SJacob Faibussowitsch { 120e6b6b59SJacob Faibussowitsch 13d71ae5a4SJacob Faibussowitsch namespace device 14d71ae5a4SJacob Faibussowitsch { 150e6b6b59SJacob Faibussowitsch 16d71ae5a4SJacob Faibussowitsch namespace cupm 17d71ae5a4SJacob Faibussowitsch { 180e6b6b59SJacob Faibussowitsch 190e6b6b59SJacob Faibussowitsch // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely 200e6b6b59SJacob Faibussowitsch // identify separate cupm streams. This is so that the memory pool can accelerate allocation 210e6b6b59SJacob Faibussowitsch // calls as it can just pass back a pointer to memory that was used on the same 220e6b6b59SJacob Faibussowitsch // stream. Otherwise it must either serialize with another stream or allocate a new chunk. 230e6b6b59SJacob Faibussowitsch // Address of the objects does not suffice since cupmStreams are very likely internally reused. 240e6b6b59SJacob Faibussowitsch 250e6b6b59SJacob Faibussowitsch template <DeviceType T> 260e6b6b59SJacob Faibussowitsch class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> { 270e6b6b59SJacob Faibussowitsch using crtp_base_type = StreamBase<CUPMStream<T>>; 280e6b6b59SJacob Faibussowitsch friend crtp_base_type; 290e6b6b59SJacob Faibussowitsch 300e6b6b59SJacob Faibussowitsch public: 3196a4b4d9SJacob Faibussowitsch PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); 320e6b6b59SJacob Faibussowitsch 330e6b6b59SJacob Faibussowitsch using stream_type = cupmStream_t; 340e6b6b59SJacob Faibussowitsch using id_type = typename crtp_base_type::id_type; 350e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>; 360e6b6b59SJacob Faibussowitsch using flag_type = unsigned int; 370e6b6b59SJacob Faibussowitsch 380e6b6b59SJacob Faibussowitsch CUPMStream() noexcept = default; 390e6b6b59SJacob Faibussowitsch 40089fb57cSJacob Faibussowitsch PetscErrorCode destroy() noexcept; 41089fb57cSJacob Faibussowitsch PetscErrorCode create(flag_type) noexcept; 42089fb57cSJacob Faibussowitsch PetscErrorCode change_type(PetscStreamType) noexcept; 430e6b6b59SJacob Faibussowitsch 440e6b6b59SJacob Faibussowitsch private: 450e6b6b59SJacob Faibussowitsch stream_type stream_{}; 460e6b6b59SJacob Faibussowitsch id_type id_ = new_id_(); 470e6b6b59SJacob Faibussowitsch 480e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static id_type new_id_() noexcept; 490e6b6b59SJacob Faibussowitsch 500e6b6b59SJacob Faibussowitsch // CRTP implementations 51*31d47070SJunchao Zhang PETSC_NODISCARD const stream_type &get_stream_() const noexcept; 520e6b6b59SJacob Faibussowitsch PETSC_NODISCARD id_type get_id_() const noexcept; 53089fb57cSJacob Faibussowitsch PetscErrorCode record_event_(event_type &) const noexcept; 54089fb57cSJacob Faibussowitsch PetscErrorCode wait_for_(event_type &) const noexcept; 550e6b6b59SJacob Faibussowitsch }; 560e6b6b59SJacob Faibussowitsch 570e6b6b59SJacob Faibussowitsch template <DeviceType T> 58d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::destroy() noexcept 59d71ae5a4SJacob Faibussowitsch { 600e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 610e6b6b59SJacob Faibussowitsch if (stream_) { 620e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamDestroy(stream_)); 630e6b6b59SJacob Faibussowitsch stream_ = cupmStream_t{}; 640e6b6b59SJacob Faibussowitsch id_ = 0; 650e6b6b59SJacob Faibussowitsch } 663ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 670e6b6b59SJacob Faibussowitsch } 680e6b6b59SJacob Faibussowitsch 690e6b6b59SJacob Faibussowitsch template <DeviceType T> 70d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept 71d71ae5a4SJacob Faibussowitsch { 720e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 730e6b6b59SJacob Faibussowitsch if (stream_) { 740e6b6b59SJacob Faibussowitsch if (PetscDefined(USE_DEBUG)) { 750e6b6b59SJacob Faibussowitsch flag_type current_flags; 760e6b6b59SJacob Faibussowitsch 770e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); 780e6b6b59SJacob Faibussowitsch PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); 790e6b6b59SJacob Faibussowitsch } 803ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 810e6b6b59SJacob Faibussowitsch } 820e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); 830e6b6b59SJacob Faibussowitsch id_ = new_id_(); 843ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 850e6b6b59SJacob Faibussowitsch } 860e6b6b59SJacob Faibussowitsch 870e6b6b59SJacob Faibussowitsch template <DeviceType T> 88d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept 89d71ae5a4SJacob Faibussowitsch { 900e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 910e6b6b59SJacob Faibussowitsch if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) { 920e6b6b59SJacob Faibussowitsch PetscCall(destroy()); 930e6b6b59SJacob Faibussowitsch } else { 940e6b6b59SJacob Faibussowitsch const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking; 950e6b6b59SJacob Faibussowitsch 960e6b6b59SJacob Faibussowitsch if (stream_) { 970e6b6b59SJacob Faibussowitsch flag_type flag; 980e6b6b59SJacob Faibussowitsch 990e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); 1003ba16761SJacob Faibussowitsch if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS); 1016b619d28SSuyash Tandon PetscCall(destroy()); 1020e6b6b59SJacob Faibussowitsch } 1030e6b6b59SJacob Faibussowitsch PetscCall(create(preferred)); 1040e6b6b59SJacob Faibussowitsch } 1053ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1060e6b6b59SJacob Faibussowitsch } 1070e6b6b59SJacob Faibussowitsch 1080e6b6b59SJacob Faibussowitsch template <DeviceType T> 109d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept 110d71ae5a4SJacob Faibussowitsch { 1110e6b6b59SJacob Faibussowitsch static id_type id = 0; 1120e6b6b59SJacob Faibussowitsch return id++; 1130e6b6b59SJacob Faibussowitsch } 1140e6b6b59SJacob Faibussowitsch 1150e6b6b59SJacob Faibussowitsch // CRTP implementations 1160e6b6b59SJacob Faibussowitsch template <DeviceType T> 117*31d47070SJunchao Zhang inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept 118d71ae5a4SJacob Faibussowitsch { 1190e6b6b59SJacob Faibussowitsch return stream_; 1200e6b6b59SJacob Faibussowitsch } 1210e6b6b59SJacob Faibussowitsch 1220e6b6b59SJacob Faibussowitsch template <DeviceType T> 123d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept 124d71ae5a4SJacob Faibussowitsch { 1250e6b6b59SJacob Faibussowitsch return id_; 1260e6b6b59SJacob Faibussowitsch } 1270e6b6b59SJacob Faibussowitsch 1280e6b6b59SJacob Faibussowitsch template <DeviceType T> 129d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept 130d71ae5a4SJacob Faibussowitsch { 1310e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1320e6b6b59SJacob Faibussowitsch PetscCall(event.record(stream_)); 1333ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1340e6b6b59SJacob Faibussowitsch } 1350e6b6b59SJacob Faibussowitsch 1360e6b6b59SJacob Faibussowitsch template <DeviceType T> 137d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept 138d71ae5a4SJacob Faibussowitsch { 1390e6b6b59SJacob Faibussowitsch PetscFunctionBegin; 1400e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); 1413ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1420e6b6b59SJacob Faibussowitsch } 1430e6b6b59SJacob Faibussowitsch 1440e6b6b59SJacob Faibussowitsch } // namespace cupm 1450e6b6b59SJacob Faibussowitsch 1460e6b6b59SJacob Faibussowitsch } // namespace device 1470e6b6b59SJacob Faibussowitsch 1480e6b6b59SJacob Faibussowitsch } // namespace Petsc 1490e6b6b59SJacob Faibussowitsch #endif // __cplusplus 1500e6b6b59SJacob Faibussowitsch 1510e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPMSTREAM_HPP 152