xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision 31d4707089da71ebcff1fd00fbb3a11b50a9f3d1)
10e6b6b59SJacob Faibussowitsch #ifndef PETSC_CUPMSTREAM_HPP
20e6b6b59SJacob Faibussowitsch #define PETSC_CUPMSTREAM_HPP
30e6b6b59SJacob Faibussowitsch 
40e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
50e6b6b59SJacob Faibussowitsch 
60e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
70e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
80e6b6b59SJacob Faibussowitsch 
90e6b6b59SJacob Faibussowitsch #if defined(__cplusplus)
10d71ae5a4SJacob Faibussowitsch namespace Petsc
11d71ae5a4SJacob Faibussowitsch {
120e6b6b59SJacob Faibussowitsch 
13d71ae5a4SJacob Faibussowitsch namespace device
14d71ae5a4SJacob Faibussowitsch {
150e6b6b59SJacob Faibussowitsch 
16d71ae5a4SJacob Faibussowitsch namespace cupm
17d71ae5a4SJacob Faibussowitsch {
180e6b6b59SJacob Faibussowitsch 
190e6b6b59SJacob Faibussowitsch // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
200e6b6b59SJacob Faibussowitsch // identify separate cupm streams. This is so that the memory pool can accelerate allocation
210e6b6b59SJacob Faibussowitsch // calls as it can just pass back a pointer to memory that was used on the same
220e6b6b59SJacob Faibussowitsch // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
230e6b6b59SJacob Faibussowitsch // Address of the objects does not suffice since cupmStreams are very likely internally reused.
240e6b6b59SJacob Faibussowitsch 
250e6b6b59SJacob Faibussowitsch template <DeviceType T>
260e6b6b59SJacob Faibussowitsch class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
270e6b6b59SJacob Faibussowitsch   using crtp_base_type = StreamBase<CUPMStream<T>>;
280e6b6b59SJacob Faibussowitsch   friend crtp_base_type;
290e6b6b59SJacob Faibussowitsch 
300e6b6b59SJacob Faibussowitsch public:
3196a4b4d9SJacob Faibussowitsch   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
320e6b6b59SJacob Faibussowitsch 
330e6b6b59SJacob Faibussowitsch   using stream_type = cupmStream_t;
340e6b6b59SJacob Faibussowitsch   using id_type     = typename crtp_base_type::id_type;
350e6b6b59SJacob Faibussowitsch   using event_type  = CUPMEvent<T>;
360e6b6b59SJacob Faibussowitsch   using flag_type   = unsigned int;
370e6b6b59SJacob Faibussowitsch 
380e6b6b59SJacob Faibussowitsch   CUPMStream() noexcept = default;
390e6b6b59SJacob Faibussowitsch 
40089fb57cSJacob Faibussowitsch   PetscErrorCode destroy() noexcept;
41089fb57cSJacob Faibussowitsch   PetscErrorCode create(flag_type) noexcept;
42089fb57cSJacob Faibussowitsch   PetscErrorCode change_type(PetscStreamType) noexcept;
430e6b6b59SJacob Faibussowitsch 
440e6b6b59SJacob Faibussowitsch private:
450e6b6b59SJacob Faibussowitsch   stream_type stream_{};
460e6b6b59SJacob Faibussowitsch   id_type     id_ = new_id_();
470e6b6b59SJacob Faibussowitsch 
480e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD static id_type new_id_() noexcept;
490e6b6b59SJacob Faibussowitsch 
500e6b6b59SJacob Faibussowitsch   // CRTP implementations
51*31d47070SJunchao Zhang   PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
520e6b6b59SJacob Faibussowitsch   PETSC_NODISCARD id_type            get_id_() const noexcept;
53089fb57cSJacob Faibussowitsch   PetscErrorCode                     record_event_(event_type &) const noexcept;
54089fb57cSJacob Faibussowitsch   PetscErrorCode                     wait_for_(event_type &) const noexcept;
550e6b6b59SJacob Faibussowitsch };
560e6b6b59SJacob Faibussowitsch 
570e6b6b59SJacob Faibussowitsch template <DeviceType T>
58d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::destroy() noexcept
59d71ae5a4SJacob Faibussowitsch {
600e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
610e6b6b59SJacob Faibussowitsch   if (stream_) {
620e6b6b59SJacob Faibussowitsch     PetscCallCUPM(cupmStreamDestroy(stream_));
630e6b6b59SJacob Faibussowitsch     stream_ = cupmStream_t{};
640e6b6b59SJacob Faibussowitsch     id_     = 0;
650e6b6b59SJacob Faibussowitsch   }
663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
670e6b6b59SJacob Faibussowitsch }
680e6b6b59SJacob Faibussowitsch 
690e6b6b59SJacob Faibussowitsch template <DeviceType T>
70d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
71d71ae5a4SJacob Faibussowitsch {
720e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
730e6b6b59SJacob Faibussowitsch   if (stream_) {
740e6b6b59SJacob Faibussowitsch     if (PetscDefined(USE_DEBUG)) {
750e6b6b59SJacob Faibussowitsch       flag_type current_flags;
760e6b6b59SJacob Faibussowitsch 
770e6b6b59SJacob Faibussowitsch       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
780e6b6b59SJacob Faibussowitsch       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
790e6b6b59SJacob Faibussowitsch     }
803ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
810e6b6b59SJacob Faibussowitsch   }
820e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
830e6b6b59SJacob Faibussowitsch   id_ = new_id_();
843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
850e6b6b59SJacob Faibussowitsch }
860e6b6b59SJacob Faibussowitsch 
870e6b6b59SJacob Faibussowitsch template <DeviceType T>
88d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
89d71ae5a4SJacob Faibussowitsch {
900e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
910e6b6b59SJacob Faibussowitsch   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
920e6b6b59SJacob Faibussowitsch     PetscCall(destroy());
930e6b6b59SJacob Faibussowitsch   } else {
940e6b6b59SJacob Faibussowitsch     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
950e6b6b59SJacob Faibussowitsch 
960e6b6b59SJacob Faibussowitsch     if (stream_) {
970e6b6b59SJacob Faibussowitsch       flag_type flag;
980e6b6b59SJacob Faibussowitsch 
990e6b6b59SJacob Faibussowitsch       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
1003ba16761SJacob Faibussowitsch       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
1016b619d28SSuyash Tandon       PetscCall(destroy());
1020e6b6b59SJacob Faibussowitsch     }
1030e6b6b59SJacob Faibussowitsch     PetscCall(create(preferred));
1040e6b6b59SJacob Faibussowitsch   }
1053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1060e6b6b59SJacob Faibussowitsch }
1070e6b6b59SJacob Faibussowitsch 
1080e6b6b59SJacob Faibussowitsch template <DeviceType T>
109d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
110d71ae5a4SJacob Faibussowitsch {
1110e6b6b59SJacob Faibussowitsch   static id_type id = 0;
1120e6b6b59SJacob Faibussowitsch   return id++;
1130e6b6b59SJacob Faibussowitsch }
1140e6b6b59SJacob Faibussowitsch 
1150e6b6b59SJacob Faibussowitsch // CRTP implementations
1160e6b6b59SJacob Faibussowitsch template <DeviceType T>
117*31d47070SJunchao Zhang inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
118d71ae5a4SJacob Faibussowitsch {
1190e6b6b59SJacob Faibussowitsch   return stream_;
1200e6b6b59SJacob Faibussowitsch }
1210e6b6b59SJacob Faibussowitsch 
1220e6b6b59SJacob Faibussowitsch template <DeviceType T>
123d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
124d71ae5a4SJacob Faibussowitsch {
1250e6b6b59SJacob Faibussowitsch   return id_;
1260e6b6b59SJacob Faibussowitsch }
1270e6b6b59SJacob Faibussowitsch 
1280e6b6b59SJacob Faibussowitsch template <DeviceType T>
129d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
130d71ae5a4SJacob Faibussowitsch {
1310e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1320e6b6b59SJacob Faibussowitsch   PetscCall(event.record(stream_));
1333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1340e6b6b59SJacob Faibussowitsch }
1350e6b6b59SJacob Faibussowitsch 
1360e6b6b59SJacob Faibussowitsch template <DeviceType T>
137d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
138d71ae5a4SJacob Faibussowitsch {
1390e6b6b59SJacob Faibussowitsch   PetscFunctionBegin;
1400e6b6b59SJacob Faibussowitsch   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
1413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1420e6b6b59SJacob Faibussowitsch }
1430e6b6b59SJacob Faibussowitsch 
1440e6b6b59SJacob Faibussowitsch } // namespace cupm
1450e6b6b59SJacob Faibussowitsch 
1460e6b6b59SJacob Faibussowitsch } // namespace device
1470e6b6b59SJacob Faibussowitsch 
1480e6b6b59SJacob Faibussowitsch } // namespace Petsc
1490e6b6b59SJacob Faibussowitsch #endif // __cplusplus
1500e6b6b59SJacob Faibussowitsch 
1510e6b6b59SJacob Faibussowitsch #endif // PETSC_CUPMSTREAM_HPP
152