xref: /petsc/src/sys/objects/device/impls/cupm/cupmstream.hpp (revision cf906205f1df435953e2a75894dc7795d3ff0aed)
1 #ifndef PETSC_CUPMSTREAM_HPP
2 #define PETSC_CUPMSTREAM_HPP
3 
4 #include <petsc/private/cupminterface.hpp>
5 
6 #include "../segmentedmempool.hpp"
7 #include "cupmevent.hpp"
8 
9 #if defined(__cplusplus)
10 namespace Petsc {
11 
12 namespace device {
13 
14 namespace cupm {
15 
16 // A bare wrapper around a cupmStream_t. The reason it exists is because we need to uniquely
17 // identify separate cupm streams. This is so that the memory pool can accelerate allocation
18 // calls as it can just pass back a pointer to memory that was used on the same
19 // stream. Otherwise it must either serialize with another stream or allocate a new chunk.
20 // Address of the objects does not suffice since cupmStreams are very likely internally reused.
21 
22 template <DeviceType T>
23 class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
24   using crtp_base_type = StreamBase<CUPMStream<T>>;
25   friend crtp_base_type;
26 
27 public:
28   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, T);
29 
30   using stream_type = cupmStream_t;
31   using id_type     = typename crtp_base_type::id_type;
32   using event_type  = CUPMEvent<T>;
33   using flag_type   = unsigned int;
34 
35   CUPMStream() noexcept = default;
36 
37   PETSC_NODISCARD PetscErrorCode destroy() noexcept;
38   PETSC_NODISCARD PetscErrorCode create(flag_type) noexcept;
39   PETSC_NODISCARD PetscErrorCode change_type(PetscStreamType) noexcept;
40 
41 private:
42   stream_type stream_{};
43   id_type     id_ = new_id_();
44 
45   PETSC_NODISCARD static id_type new_id_() noexcept;
46 
47   // CRTP implementations
48   PETSC_NODISCARD stream_type    get_stream_() const noexcept;
49   PETSC_NODISCARD id_type        get_id_() const noexcept;
50   PETSC_NODISCARD PetscErrorCode record_event_(event_type &) const noexcept;
51   PETSC_NODISCARD PetscErrorCode wait_for_(event_type &) const noexcept;
52 };
53 
54 template <DeviceType T>
55 inline PetscErrorCode CUPMStream<T>::destroy() noexcept {
56   PetscFunctionBegin;
57   if (stream_) {
58     PetscCallCUPM(cupmStreamDestroy(stream_));
59     stream_ = cupmStream_t{};
60     id_     = 0;
61   }
62   PetscFunctionReturn(0);
63 }
64 
65 template <DeviceType T>
66 inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept {
67   PetscFunctionBegin;
68   if (stream_) {
69     if (PetscDefined(USE_DEBUG)) {
70       flag_type current_flags;
71 
72       PetscCallCUPM(cupmStreamGetFlags(stream_, &current_flags));
73       PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
74     }
75     PetscFunctionReturn(0);
76   }
77   PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
78   id_ = new_id_();
79   PetscFunctionReturn(0);
80 }
81 
82 template <DeviceType T>
83 inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept {
84   PetscFunctionBegin;
85   if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
86     PetscCall(destroy());
87   } else {
88     const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
89 
90     if (stream_) {
91       flag_type flag;
92 
93       PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
94       if ((flag != preferred) || (cupmStreamQuery(stream_) != cupmSuccess)) PetscCall(destroy());
95     }
96     PetscCall(create(preferred));
97   }
98   PetscFunctionReturn(0);
99 }
100 
101 template <DeviceType T>
102 inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept {
103   static id_type id = 0;
104   return id++;
105 }
106 
107 // CRTP implementations
108 template <DeviceType T>
109 inline typename CUPMStream<T>::stream_type CUPMStream<T>::get_stream_() const noexcept {
110   return stream_;
111 }
112 
113 template <DeviceType T>
114 inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept {
115   return id_;
116 }
117 
118 template <DeviceType T>
119 inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept {
120   PetscFunctionBegin;
121   PetscCall(event.record(stream_));
122   PetscFunctionReturn(0);
123 }
124 
125 template <DeviceType T>
126 inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept {
127   PetscFunctionBegin;
128   PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
129   PetscFunctionReturn(0);
130 }
131 
132 } // namespace cupm
133 
134 } // namespace device
135 
136 } // namespace Petsc
137 #endif // __cplusplus
138 
139 #endif // PETSC_CUPMSTREAM_HPP
140