#pragma once #include #include "../segmentedmempool.hpp" #include "cupmevent.hpp" namespace Petsc { namespace device { namespace cupm { // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely // identify separate cupm streams. This is so that the memory pool can accelerate allocation // calls as it can just pass back a pointer to memory that was used on the same // stream. Otherwise it must either serialize with another stream or allocate a new chunk. // Address of the objects does not suffice since cupmStreams are very likely internally reused. template class CUPMStream : public StreamBase>, impl::Interface { using crtp_base_type = StreamBase>; friend crtp_base_type; public: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); using stream_type = cupmStream_t; using id_type = typename crtp_base_type::id_type; using event_type = CUPMEvent; using flag_type = unsigned int; CUPMStream() noexcept = default; PetscErrorCode destroy() noexcept; PetscErrorCode create(flag_type) noexcept; PetscErrorCode change_type(PetscStreamType) noexcept; private: stream_type stream_{}; id_type id_ = new_id_(); PETSC_NODISCARD static id_type new_id_() noexcept; // CRTP implementations PETSC_NODISCARD const stream_type &get_stream_() const noexcept; PETSC_NODISCARD id_type get_id_() const noexcept; PetscErrorCode record_event_(event_type &) const noexcept; PetscErrorCode wait_for_(event_type &) const noexcept; }; template inline PetscErrorCode CUPMStream::destroy() noexcept { PetscFunctionBegin; if (stream_) { PetscCallCUPM(cupmStreamDestroy(stream_)); stream_ = cupmStream_t{}; id_ = 0; } PetscFunctionReturn(PETSC_SUCCESS); } template inline PetscErrorCode CUPMStream::create(flag_type flags) noexcept { PetscFunctionBegin; if (stream_) { if (PetscDefined(USE_DEBUG)) { flag_type current_flags; PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags)); PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_); } PetscFunctionReturn(PETSC_SUCCESS); } PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags)); id_ = new_id_(); PetscFunctionReturn(PETSC_SUCCESS); } template inline PetscErrorCode CUPMStream::change_type(PetscStreamType newtype) noexcept { PetscFunctionBegin; if (newtype == PETSC_STREAM_DEFAULT || newtype == PETSC_STREAM_DEFAULT_WITH_BARRIER) { PetscCall(destroy()); } else { // change to a nonblokcing stream const flag_type preferred = cupmStreamNonBlocking; if (stream_) { flag_type flag; PetscCallCUPM(cupmStreamGetFlags(stream_, &flag)); if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS); PetscCall(destroy()); } PetscCall(create(preferred)); } PetscFunctionReturn(PETSC_SUCCESS); } template inline typename CUPMStream::id_type CUPMStream::new_id_() noexcept { static id_type id = 0; return id++; } // CRTP implementations template inline const typename CUPMStream::stream_type &CUPMStream::get_stream_() const noexcept { return stream_; } template inline typename CUPMStream::id_type CUPMStream::get_id_() const noexcept { return id_; } template inline PetscErrorCode CUPMStream::record_event_(event_type &event) const noexcept { PetscFunctionBegin; PetscCall(event.record(stream_)); PetscFunctionReturn(PETSC_SUCCESS); } template inline PetscErrorCode CUPMStream::wait_for_(event_type &event) const noexcept { PetscFunctionBegin; PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0)); PetscFunctionReturn(PETSC_SUCCESS); } } // namespace cupm } // namespace device } // namespace Petsc