1a4963045SJacob Faibussowitsch #pragma once
20e6b6b59SJacob Faibussowitsch
30e6b6b59SJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
40e6b6b59SJacob Faibussowitsch
50e6b6b59SJacob Faibussowitsch #include "../segmentedmempool.hpp"
60e6b6b59SJacob Faibussowitsch #include "cupmevent.hpp"
70e6b6b59SJacob Faibussowitsch
8d71ae5a4SJacob Faibussowitsch namespace Petsc
9d71ae5a4SJacob Faibussowitsch {
100e6b6b59SJacob Faibussowitsch
11d71ae5a4SJacob Faibussowitsch namespace device
12d71ae5a4SJacob Faibussowitsch {
130e6b6b59SJacob Faibussowitsch
14d71ae5a4SJacob Faibussowitsch namespace cupm
15d71ae5a4SJacob Faibussowitsch {
160e6b6b59SJacob Faibussowitsch
170e6b6b59SJacob Faibussowitsch // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
180e6b6b59SJacob Faibussowitsch // identify separate cupm streams. This is so that the memory pool can accelerate allocation
190e6b6b59SJacob Faibussowitsch // calls as it can just pass back a pointer to memory that was used on the same
200e6b6b59SJacob Faibussowitsch // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
210e6b6b59SJacob Faibussowitsch // Address of the objects does not suffice since cupmStreams are very likely internally reused.
220e6b6b59SJacob Faibussowitsch
230e6b6b59SJacob Faibussowitsch template <DeviceType T>
24*85f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
250e6b6b59SJacob Faibussowitsch using crtp_base_type = StreamBase<CUPMStream<T>>;
260e6b6b59SJacob Faibussowitsch friend crtp_base_type;
270e6b6b59SJacob Faibussowitsch
280e6b6b59SJacob Faibussowitsch public:
2996a4b4d9SJacob Faibussowitsch PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
300e6b6b59SJacob Faibussowitsch
310e6b6b59SJacob Faibussowitsch using stream_type = cupmStream_t;
320e6b6b59SJacob Faibussowitsch using id_type = typename crtp_base_type::id_type;
330e6b6b59SJacob Faibussowitsch using event_type = CUPMEvent<T>;
340e6b6b59SJacob Faibussowitsch using flag_type = unsigned int;
350e6b6b59SJacob Faibussowitsch
360e6b6b59SJacob Faibussowitsch CUPMStream() noexcept = default;
370e6b6b59SJacob Faibussowitsch
38089fb57cSJacob Faibussowitsch PetscErrorCode destroy() noexcept;
39089fb57cSJacob Faibussowitsch PetscErrorCode create(flag_type) noexcept;
40089fb57cSJacob Faibussowitsch PetscErrorCode change_type(PetscStreamType) noexcept;
410e6b6b59SJacob Faibussowitsch
420e6b6b59SJacob Faibussowitsch private:
430e6b6b59SJacob Faibussowitsch stream_type stream_{};
440e6b6b59SJacob Faibussowitsch id_type id_ = new_id_();
450e6b6b59SJacob Faibussowitsch
460e6b6b59SJacob Faibussowitsch PETSC_NODISCARD static id_type new_id_() noexcept;
470e6b6b59SJacob Faibussowitsch
480e6b6b59SJacob Faibussowitsch // CRTP implementations
4931d47070SJunchao Zhang PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
500e6b6b59SJacob Faibussowitsch PETSC_NODISCARD id_type get_id_() const noexcept;
51089fb57cSJacob Faibussowitsch PetscErrorCode record_event_(event_type &) const noexcept;
52089fb57cSJacob Faibussowitsch PetscErrorCode wait_for_(event_type &) const noexcept;
530e6b6b59SJacob Faibussowitsch };
540e6b6b59SJacob Faibussowitsch
550e6b6b59SJacob Faibussowitsch template <DeviceType T>
destroy()56d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::destroy() noexcept
57d71ae5a4SJacob Faibussowitsch {
580e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
590e6b6b59SJacob Faibussowitsch if (stream_) {
600e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamDestroy(stream_));
610e6b6b59SJacob Faibussowitsch stream_ = cupmStream_t{};
620e6b6b59SJacob Faibussowitsch id_ = 0;
630e6b6b59SJacob Faibussowitsch }
643ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
650e6b6b59SJacob Faibussowitsch }
660e6b6b59SJacob Faibussowitsch
670e6b6b59SJacob Faibussowitsch template <DeviceType T>
create(flag_type flags)68d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
69d71ae5a4SJacob Faibussowitsch {
700e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
710e6b6b59SJacob Faibussowitsch if (stream_) {
720e6b6b59SJacob Faibussowitsch if (PetscDefined(USE_DEBUG)) {
730e6b6b59SJacob Faibussowitsch flag_type current_flags;
740e6b6b59SJacob Faibussowitsch
750e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags));
760e6b6b59SJacob Faibussowitsch PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
770e6b6b59SJacob Faibussowitsch }
783ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
790e6b6b59SJacob Faibussowitsch }
800e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
810e6b6b59SJacob Faibussowitsch id_ = new_id_();
823ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
830e6b6b59SJacob Faibussowitsch }
840e6b6b59SJacob Faibussowitsch
850e6b6b59SJacob Faibussowitsch template <DeviceType T>
change_type(PetscStreamType newtype)86d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
87d71ae5a4SJacob Faibussowitsch {
880e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
89d9acb416SHong Zhang if (newtype == PETSC_STREAM_DEFAULT || newtype == PETSC_STREAM_DEFAULT_WITH_BARRIER) {
900e6b6b59SJacob Faibussowitsch PetscCall(destroy());
91d9acb416SHong Zhang } else { // change to a nonblokcing stream
92d9acb416SHong Zhang const flag_type preferred = cupmStreamNonBlocking;
930e6b6b59SJacob Faibussowitsch
940e6b6b59SJacob Faibussowitsch if (stream_) {
950e6b6b59SJacob Faibussowitsch flag_type flag;
960e6b6b59SJacob Faibussowitsch
970e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
983ba16761SJacob Faibussowitsch if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
996b619d28SSuyash Tandon PetscCall(destroy());
1000e6b6b59SJacob Faibussowitsch }
1010e6b6b59SJacob Faibussowitsch PetscCall(create(preferred));
1020e6b6b59SJacob Faibussowitsch }
1033ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1040e6b6b59SJacob Faibussowitsch }
1050e6b6b59SJacob Faibussowitsch
1060e6b6b59SJacob Faibussowitsch template <DeviceType T>
new_id_()107d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
108d71ae5a4SJacob Faibussowitsch {
1090e6b6b59SJacob Faibussowitsch static id_type id = 0;
1100e6b6b59SJacob Faibussowitsch return id++;
1110e6b6b59SJacob Faibussowitsch }
1120e6b6b59SJacob Faibussowitsch
1130e6b6b59SJacob Faibussowitsch // CRTP implementations
1140e6b6b59SJacob Faibussowitsch template <DeviceType T>
get_stream_() const11531d47070SJunchao Zhang inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
116d71ae5a4SJacob Faibussowitsch {
1170e6b6b59SJacob Faibussowitsch return stream_;
1180e6b6b59SJacob Faibussowitsch }
1190e6b6b59SJacob Faibussowitsch
1200e6b6b59SJacob Faibussowitsch template <DeviceType T>
get_id_() const121d71ae5a4SJacob Faibussowitsch inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
122d71ae5a4SJacob Faibussowitsch {
1230e6b6b59SJacob Faibussowitsch return id_;
1240e6b6b59SJacob Faibussowitsch }
1250e6b6b59SJacob Faibussowitsch
1260e6b6b59SJacob Faibussowitsch template <DeviceType T>
record_event_(event_type & event) const127d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
128d71ae5a4SJacob Faibussowitsch {
1290e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
1300e6b6b59SJacob Faibussowitsch PetscCall(event.record(stream_));
1313ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1320e6b6b59SJacob Faibussowitsch }
1330e6b6b59SJacob Faibussowitsch
1340e6b6b59SJacob Faibussowitsch template <DeviceType T>
wait_for_(event_type & event) const135d71ae5a4SJacob Faibussowitsch inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
136d71ae5a4SJacob Faibussowitsch {
1370e6b6b59SJacob Faibussowitsch PetscFunctionBegin;
1380e6b6b59SJacob Faibussowitsch PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
1393ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1400e6b6b59SJacob Faibussowitsch }
1410e6b6b59SJacob Faibussowitsch
1420e6b6b59SJacob Faibussowitsch } // namespace cupm
1430e6b6b59SJacob Faibussowitsch
1440e6b6b59SJacob Faibussowitsch } // namespace device
1450e6b6b59SJacob Faibussowitsch
1460e6b6b59SJacob Faibussowitsch } // namespace Petsc
147