xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision fbf9dbe564678ed6eff1806adbc4c4f01b9743f4)
1 #ifndef PETSC_CUPMSTREAM_HPP
2 #define PETSC_CUPMSTREAM_HPP
3 
4 #include <petsc/private/cupminterface.hpp>
5 
6 #include "../segmentedmempool.hpp"
7 #include "cupmevent.hpp"
8 
9 namespace Petsc
10 {
11 
12 namespace device
13 {
14 
15 namespace cupm
16 {
17 
18 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
19 // identify separate cupm streams. This is so that the memory pool can accelerate allocation
20 // calls as it can just pass back a pointer to memory that was used on the same
21 // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
22 // Address of the objects does not suffice since cupmStreams are very likely internally reused.
23 
24 template <DeviceType T>
25 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
26   using crtp_base_type = StreamBase<CUPMStream<T>>;
27   friend crtp_base_type;
28 
29 public:
30   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31 
32   using stream_type = cupmStream_t;
33   using id_type     = typename crtp_base_type::id_type;
34   using event_type  = CUPMEvent<T>;
35   using flag_type   = unsigned int;
36 
37   CUPMStream() noexcept = default;
38 
39   PetscErrorCode destroy() noexcept;
40   PetscErrorCode create(flag_type) noexcept;
41   PetscErrorCode change_type(PetscStreamType) noexcept;
42 
43 private:
44   stream_type stream_{};
45   id_type     id_ = new_id_();
46 
47   PETSC_NODISCARD static id_type new_id_() noexcept;
48 
49   // CRTP implementations
50   PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
51   PETSC_NODISCARD id_type            get_id_() const noexcept;
52   PetscErrorCode                     record_event_(event_type &) const noexcept;
53   PetscErrorCode                     wait_for_(event_type &) const noexcept;
54 };
55 
56 template <DeviceType T>
57 inline PetscErrorCode CUPMStream<T>::destroy() noexcept
58 {
59   PetscFunctionBegin;
60   if (stream_) {
61     PetscCallCUPM(cupmStreamDestroy(stream_));
62     stream_ = cupmStream_t{};
63     id_     = 0;
64   }
65   PetscFunctionReturn(PETSC_SUCCESS);
66 }
67 
68 template <DeviceType T>
69 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
70 {
71   PetscFunctionBegin;
72   if (stream_) {
73     if (PetscDefined(USE_DEBUG)) {
74       flag_type current_flags;
75 
76       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
77       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
78     }
79     PetscFunctionReturn(PETSC_SUCCESS);
80   }
81   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
82   id_ = new_id_();
83   PetscFunctionReturn(PETSC_SUCCESS);
84 }
85 
86 template <DeviceType T>
87 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
88 {
89   PetscFunctionBegin;
90   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
91     PetscCall(destroy());
92   } else {
93     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
94 
95     if (stream_) {
96       flag_type flag;
97 
98       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
99       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
100       PetscCall(destroy());
101     }
102     PetscCall(create(preferred));
103   }
104   PetscFunctionReturn(PETSC_SUCCESS);
105 }
106 
107 template <DeviceType T>
108 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
109 {
110   static id_type id = 0;
111   return id++;
112 }
113 
114 // CRTP implementations
115 template <DeviceType T>
116 inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
117 {
118   return stream_;
119 }
120 
121 template <DeviceType T>
122 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
123 {
124   return id_;
125 }
126 
127 template <DeviceType T>
128 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
129 {
130   PetscFunctionBegin;
131   PetscCall(event.record(stream_));
132   PetscFunctionReturn(PETSC_SUCCESS);
133 }
134 
135 template <DeviceType T>
136 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
137 {
138   PetscFunctionBegin;
139   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
143 } // namespace cupm
144 
145 } // namespace device
146 
147 } // namespace Petsc
148 
149 #endif // PETSC_CUPMSTREAM_HPP
150