xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1 #ifndef PETSC_CUPMSTREAM_HPP
2 #define PETSC_CUPMSTREAM_HPP
3 
4 #include <petsc/private/cupminterface.hpp>
5 
6 #include "../segmentedmempool.hpp"
7 #include "cupmevent.hpp"
8 
9 #if defined(__cplusplus)
10 namespace Petsc
11 {
12 
13 namespace device
14 {
15 
16 namespace cupm
17 {
18 
19 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
20 // identify separate cupm streams. This is so that the memory pool can accelerate allocation
21 // calls as it can just pass back a pointer to memory that was used on the same
22 // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
23 // Address of the objects does not suffice since cupmStreams are very likely internally reused.
24 
25 template <DeviceType T>
26 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
27   using crtp_base_type = StreamBase<CUPMStream<T>>;
28   friend crtp_base_type;
29 
30 public:
31   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
32 
33   using stream_type = cupmStream_t;
34   using id_type     = typename crtp_base_type::id_type;
35   using event_type  = CUPMEvent<T>;
36   using flag_type   = unsigned int;
37 
38   CUPMStream() noexcept = default;
39 
40   PETSC_NODISCARD PetscErrorCode destroy() noexcept;
41   PETSC_NODISCARD PetscErrorCode create(flag_type) noexcept;
42   PETSC_NODISCARD PetscErrorCode change_type(PetscStreamType) noexcept;
43 
44 private:
45   stream_type stream_{};
46   id_type     id_ = new_id_();
47 
48   PETSC_NODISCARD static id_type new_id_() noexcept;
49 
50   // CRTP implementations
51   PETSC_NODISCARD stream_type    get_stream_() const noexcept;
52   PETSC_NODISCARD id_type        get_id_() const noexcept;
53   PETSC_NODISCARD PetscErrorCode record_event_(event_type &) const noexcept;
54   PETSC_NODISCARD PetscErrorCode wait_for_(event_type &) const noexcept;
55 };
56 
57 template <DeviceType T>
58 inline PetscErrorCode CUPMStream<T>::destroy() noexcept
59 {
60   PetscFunctionBegin;
61   if (stream_) {
62     PetscCallCUPM(cupmStreamDestroy(stream_));
63     stream_ = cupmStream_t{};
64     id_     = 0;
65   }
66   PetscFunctionReturn(0);
67 }
68 
69 template <DeviceType T>
70 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
71 {
72   PetscFunctionBegin;
73   if (stream_) {
74     if (PetscDefined(USE_DEBUG)) {
75       flag_type current_flags;
76 
77       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
78       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
79     }
80     PetscFunctionReturn(0);
81   }
82   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
83   id_ = new_id_();
84   PetscFunctionReturn(0);
85 }
86 
87 template <DeviceType T>
88 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
89 {
90   PetscFunctionBegin;
91   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
92     PetscCall(destroy());
93   } else {
94     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
95 
96     if (stream_) {
97       flag_type flag;
98 
99       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
100       if ((flag != preferred) || (cupmStreamQuery(stream_) != cupmSuccess)) PetscCall(destroy());
101     }
102     PetscCall(create(preferred));
103   }
104   PetscFunctionReturn(0);
105 }
106 
107 template <DeviceType T>
108 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
109 {
110   static id_type id = 0;
111   return id++;
112 }
113 
114 // CRTP implementations
115 template <DeviceType T>
116 inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept
117 {
118   return stream_;
119 }
120 
121 template <DeviceType T>
122 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
123 {
124   return id_;
125 }
126 
127 template <DeviceType T>
128 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
129 {
130   PetscFunctionBegin;
131   PetscCall(event.record(stream_));
132   PetscFunctionReturn(0);
133 }
134 
135 template <DeviceType T>
136 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
137 {
138   PetscFunctionBegin;
139   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
140   PetscFunctionReturn(0);
141 }
142 
143 } // namespace cupm
144 
145 } // namespace device
146 
147 } // namespace Petsc
148 #endif // __cplusplus
149 
150 #endif // PETSC_CUPMSTREAM_HPP
151