xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 #ifndef PETSC_CUPMSTREAM_HPP
2 #define PETSC_CUPMSTREAM_HPP
3 
4 #include <petsc/private/cupminterface.hpp>
5 
6 #include "../segmentedmempool.hpp"
7 #include "cupmevent.hpp"
8 
9 #if defined(__cplusplus)
10 namespace Petsc
11 {
12 
13 namespace device
14 {
15 
16 namespace cupm
17 {
18 
19 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
20 // identify separate cupm streams. This is so that the memory pool can accelerate allocation
21 // calls as it can just pass back a pointer to memory that was used on the same
22 // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
23 // Address of the objects does not suffice since cupmStreams are very likely internally reused.
24 
25 template <DeviceType T>
26 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
27   using crtp_base_type = StreamBase<CUPMStream<T>>;
28   friend crtp_base_type;
29 
30 public:
31   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
32 
33   using stream_type = cupmStream_t;
34   using id_type     = typename crtp_base_type::id_type;
35   using event_type  = CUPMEvent<T>;
36   using flag_type   = unsigned int;
37 
38   CUPMStream() noexcept = default;
39 
40   PetscErrorCode destroy() noexcept;
41   PetscErrorCode create(flag_type) noexcept;
42   PetscErrorCode change_type(PetscStreamType) noexcept;
43 
44 private:
45   stream_type stream_{};
46   id_type     id_ = new_id_();
47 
48   PETSC_NODISCARD static id_type new_id_() noexcept;
49 
50   // CRTP implementations
51   PETSC_NODISCARD stream_type get_stream_() const noexcept;
52   PETSC_NODISCARD id_type     get_id_() const noexcept;
53   PetscErrorCode              record_event_(event_type &) const noexcept;
54   PetscErrorCode              wait_for_(event_type &) const noexcept;
55 };
56 
57 template <DeviceType T>
58 inline PetscErrorCode CUPMStream<T>::destroy() noexcept
59 {
60   PetscFunctionBegin;
61   if (stream_) {
62     PetscCallCUPM(cupmStreamDestroy(stream_));
63     stream_ = cupmStream_t{};
64     id_     = 0;
65   }
66   PetscFunctionReturn(PETSC_SUCCESS);
67 }
68 
69 template <DeviceType T>
70 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
71 {
72   PetscFunctionBegin;
73   if (stream_) {
74     if (PetscDefined(USE_DEBUG)) {
75       flag_type current_flags;
76 
77       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
78       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
79     }
80     PetscFunctionReturn(PETSC_SUCCESS);
81   }
82   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
83   id_ = new_id_();
84   PetscFunctionReturn(PETSC_SUCCESS);
85 }
86 
87 template <DeviceType T>
88 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
89 {
90   PetscFunctionBegin;
91   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
92     PetscCall(destroy());
93   } else {
94     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
95 
96     if (stream_) {
97       flag_type flag;
98 
99       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
100       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
101       PetscCall(destroy());
102     }
103     PetscCall(create(preferred));
104   }
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 template <DeviceType T>
109 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
110 {
111   static id_type id = 0;
112   return id++;
113 }
114 
115 // CRTP implementations
116 template <DeviceType T>
117 inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept
118 {
119   return stream_;
120 }
121 
122 template <DeviceType T>
123 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
124 {
125   return id_;
126 }
127 
128 template <DeviceType T>
129 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
130 {
131   PetscFunctionBegin;
132   PetscCall(event.record(stream_));
133   PetscFunctionReturn(PETSC_SUCCESS);
134 }
135 
136 template <DeviceType T>
137 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
138 {
139   PetscFunctionBegin;
140   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
141   PetscFunctionReturn(PETSC_SUCCESS);
142 }
143 
144 } // namespace cupm
145 
146 } // namespace device
147 
148 } // namespace Petsc
149 #endif // __cplusplus
150 
151 #endif // PETSC_CUPMSTREAM_HPP
152