1 #pragma once
2
3 #include <petsc/private/cupminterface.hpp>
4
5 #include "../segmentedmempool.hpp"
6 #include "cupmevent.hpp"
7
8 namespace Petsc
9 {
10
11 namespace device
12 {
13
14 namespace cupm
15 {
16
// A bare wrapper around a cupmStream_t. It exists because we need to uniquely identify
// separate cupm streams, so that the memory pool can accelerate allocation calls by simply
// handing back a pointer to memory that was previously used on the same stream. Otherwise
// it must either serialize with another stream or allocate a new chunk. The address of the
// stream object does not suffice as an identifier, since cupmStreams are very likely reused
// internally.
22
23 template <DeviceType T>
24 class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
25 using crtp_base_type = StreamBase<CUPMStream<T>>;
26 friend crtp_base_type;
27
28 public:
29 PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
30
31 using stream_type = cupmStream_t;
32 using id_type = typename crtp_base_type::id_type;
33 using event_type = CUPMEvent<T>;
34 using flag_type = unsigned int;
35
36 CUPMStream() noexcept = default;
37
38 PetscErrorCode destroy() noexcept;
39 PetscErrorCode create(flag_type) noexcept;
40 PetscErrorCode change_type(PetscStreamType) noexcept;
41
42 private:
43 stream_type stream_{};
44 id_type id_ = new_id_();
45
46 PETSC_NODISCARD static id_type new_id_() noexcept;
47
48 // CRTP implementations
49 PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
50 PETSC_NODISCARD id_type get_id_() const noexcept;
51 PetscErrorCode record_event_(event_type &) const noexcept;
52 PetscErrorCode wait_for_(event_type &) const noexcept;
53 };
54
55 template <DeviceType T>
destroy()56 inline PetscErrorCode CUPMStream<T>::destroy() noexcept
57 {
58 PetscFunctionBegin;
59 if (stream_) {
60 PetscCallCUPM(cupmStreamDestroy(stream_));
61 stream_ = cupmStream_t{};
62 id_ = 0;
63 }
64 PetscFunctionReturn(PETSC_SUCCESS);
65 }
66
67 template <DeviceType T>
create(flag_type flags)68 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
69 {
70 PetscFunctionBegin;
71 if (stream_) {
72 if (PetscDefined(USE_DEBUG)) {
73 flag_type current_flags;
74
75 PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags));
76 PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
77 }
78 PetscFunctionReturn(PETSC_SUCCESS);
79 }
80 PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
81 id_ = new_id_();
82 PetscFunctionReturn(PETSC_SUCCESS);
83 }
84
85 template <DeviceType T>
change_type(PetscStreamType newtype)86 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
87 {
88 PetscFunctionBegin;
89 if (newtype == PETSC_STREAM_DEFAULT || newtype == PETSC_STREAM_DEFAULT_WITH_BARRIER) {
90 PetscCall(destroy());
91 } else { // change to a nonblokcing stream
92 const flag_type preferred = cupmStreamNonBlocking;
93
94 if (stream_) {
95 flag_type flag;
96
97 PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
98 if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
99 PetscCall(destroy());
100 }
101 PetscCall(create(preferred));
102 }
103 PetscFunctionReturn(PETSC_SUCCESS);
104 }
105
106 template <DeviceType T>
new_id_()107 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
108 {
109 static id_type id = 0;
110 return id++;
111 }
112
113 // CRTP implementations
114 template <DeviceType T>
get_stream_() const115 inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
116 {
117 return stream_;
118 }
119
120 template <DeviceType T>
get_id_() const121 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
122 {
123 return id_;
124 }
125
126 template <DeviceType T>
record_event_(event_type & event) const127 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
128 {
129 PetscFunctionBegin;
130 PetscCall(event.record(stream_));
131 PetscFunctionReturn(PETSC_SUCCESS);
132 }
133
134 template <DeviceType T>
wait_for_(event_type & event) const135 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
136 {
137 PetscFunctionBegin;
138 PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
139 PetscFunctionReturn(PETSC_SUCCESS);
140 }
141
142 } // namespace cupm
143
144 } // namespace device
145
146 } // namespace Petsc
147