xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision 66af8762ec03dbef0e079729eb2a1734a35ed7ff)
1 #pragma once
2 
3 #include <petsc/private/cupminterface.hpp>
4 
5 #include "../segmentedmempool.hpp"
6 #include "cupmevent.hpp"
7 
8 namespace Petsc
9 {
10 
11 namespace device
12 {
13 
14 namespace cupm
15 {
16 
// A thin wrapper around a cupmStream_t. It exists because we need to uniquely identify
// separate cupm streams: the memory pool can then accelerate allocation calls by handing back
// memory that was last used on the same stream. Otherwise it must either serialize with another
// stream or allocate a new chunk. The address (handle value) of a stream does not suffice as an
// identifier, since cupm streams are very likely reused internally by the runtime.
22 
23 template <DeviceType T>
24 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
25   using crtp_base_type = StreamBase<CUPMStream<T>>;
26   friend crtp_base_type;
27 
28 public:
29   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
30 
31   using stream_type = cupmStream_t;
32   using id_type     = typename crtp_base_type::id_type;
33   using event_type  = CUPMEvent<T>;
34   using flag_type   = unsigned int;
35 
36   CUPMStream() noexcept = default;
37 
38   PetscErrorCode destroy() noexcept;
39   PetscErrorCode create(flag_type) noexcept;
40   PetscErrorCode change_type(PetscStreamType) noexcept;
41 
42 private:
43   stream_type stream_{};
44   id_type     id_ = new_id_();
45 
46   PETSC_NODISCARD static id_type new_id_() noexcept;
47 
48   // CRTP implementations
49   PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
50   PETSC_NODISCARD id_type            get_id_() const noexcept;
51   PetscErrorCode                     record_event_(event_type &) const noexcept;
52   PetscErrorCode                     wait_for_(event_type &) const noexcept;
53 };
54 
55 template <DeviceType T>
56 inline PetscErrorCode CUPMStream<T>::destroy() noexcept
57 {
58   PetscFunctionBegin;
59   if (stream_) {
60     PetscCallCUPM(cupmStreamDestroy(stream_));
61     stream_ = cupmStream_t{};
62     id_     = 0;
63   }
64   PetscFunctionReturn(PETSC_SUCCESS);
65 }
66 
67 template <DeviceType T>
68 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
69 {
70   PetscFunctionBegin;
71   if (stream_) {
72     if (PetscDefined(USE_DEBUG)) {
73       flag_type current_flags;
74 
75       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
76       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
77     }
78     PetscFunctionReturn(PETSC_SUCCESS);
79   }
80   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
81   id_ = new_id_();
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
85 template <DeviceType T>
86 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
87 {
88   PetscFunctionBegin;
89   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
90     PetscCall(destroy());
91   } else {
92     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
93 
94     if (stream_) {
95       flag_type flag;
96 
97       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
98       if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
99       PetscCall(destroy());
100     }
101     PetscCall(create(preferred));
102   }
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105 
106 template <DeviceType T>
107 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
108 {
109   static id_type id = 0;
110   return id++;
111 }
112 
113 // CRTP implementations
114 template <DeviceType T>
115 inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
116 {
117   return stream_;
118 }
119 
120 template <DeviceType T>
121 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
122 {
123   return id_;
124 }
125 
126 template <DeviceType T>
127 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
128 {
129   PetscFunctionBegin;
130   PetscCall(event.record(stream_));
131   PetscFunctionReturn(PETSC_SUCCESS);
132 }
133 
134 template <DeviceType T>
135 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
136 {
137   PetscFunctionBegin;
138   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
139   PetscFunctionReturn(PETSC_SUCCESS);
140 }
141 
142 } // namespace cupm
143 
144 } // namespace device
145 
146 } // namespace Petsc
147