xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 7a101e5e7ba9859de4c800924a501d6ea3cd325c)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 #include <petsc/private/logimpl.h>
7 
8 #include <array>
9 
10 namespace Petsc
11 {
12 
13 namespace Device
14 {
15 
16 namespace CUPM
17 {
18 
19 namespace Impl
20 {
21 
// Forward declare the class template so the visibility attribute can be attached to it.
// NOTE(review): presumably PETSC_VISIBILITY_INTERNAL keeps the symbol out of the public ABI
// of the shared library -- confirm against its definition.
template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;
24 
25 template <DeviceType T>
26 class DeviceContext : Impl::BlasInterface<T>
27 {
28 public:
29   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);
30 
31 private:
32   // for tag-based dispatch of handle retrieval
33   template <typename H,std::size_t> struct HandleTag { using type = H; };
34   using stream_tag = HandleTag<cupmStream_t,0>;
35   using blas_tag   = HandleTag<cupmBlasHandle_t,1>;
36   using solver_tag = HandleTag<cupmSolverHandle_t,2>;
37 
38 public:
39   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
40   // header, but since we are using the power of templates it must be declared part of
41   // this class to have easy access the same typedefs. Technically one can make a
42   // templated struct outside the class but it's more code for the same result.
43   struct PetscDeviceContext_IMPLS
44   {
45     cupmStream_t       stream;
46     cupmEvent_t        event;
47     cupmEvent_t        begin; // timer-only
48     cupmEvent_t        end;   // timer-only
49 #if PetscDefined(USE_DEBUG)
50     PetscBool          timerInUse;
51 #endif
52     cupmBlasHandle_t   blas;
53     cupmSolverHandle_t solver;
54 
55     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
56     PETSC_NODISCARD auto get(blas_tag)   const -> decltype(this->blas)   { return this->blas;   }
57     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
58   };
59 
60 private:
61   static bool initialized_;
62   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   blashandles_;
63   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;
64 
65   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
66   {
67     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
68   }
69 
70   PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMBLAS_HANDLE_CREATE())
71   {
72     return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE;
73   }
74 
75   PETSC_CXX_COMPAT_DECL(constexpr PetscLogEvent CUPMSOLVER_HANDLE_CREATE())
76   {
77     return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE;
78   }
79 
80   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
81   // handles
82   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag,PetscDeviceContext)) { return 0; }
83 
84   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmBlasHandle_t &handle))
85   {
86     PetscLogEvent event;
87 
88     PetscFunctionBegin;
89     if (PetscLikely(handle)) PetscFunctionReturn(0);
90     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
91     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(),0,0,0,0));
92     for (auto i = 0; i < 3; ++i) {
93       auto cberr = cupmBlasCreate(&handle);
94       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
95       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
96       if (i != 2) {
97         PetscCall(PetscSleep(3));
98         continue;
99       }
100       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName());
101     }
102     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(),0,0,0,0));
103     PetscCall(PetscLogEventResume_Internal(event));
104     PetscFunctionReturn(0);
105   }
106 
107   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx))
108   {
109     const auto dci = impls_cast_(dctx);
110     auto& handle = blashandles_[dctx->device->deviceId];
111 
112     PetscFunctionBegin;
113     PetscCall(create_handle_(handle));
114     PetscCallCUPMBLAS(cupmBlasSetStream(handle,dci->stream));
115     dci->blas = handle;
116     PetscFunctionReturn(0);
117   }
118 
119   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(cupmSolverHandle_t &handle))
120   {
121     PetscLogEvent event;
122 
123     PetscFunctionBegin;
124     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
125     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(),0,0,0,0));
126     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
127     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(),0,0,0,0));
128     PetscCall(PetscLogEventResume_Internal(event));
129     PetscFunctionReturn(0);
130   }
131 
132   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx))
133   {
134     const auto dci    = impls_cast_(dctx);
135     auto&      handle = solverhandles_[dctx->device->deviceId];
136 
137     PetscFunctionBegin;
138     PetscCall(create_handle_(handle));
139     PetscCall(cupmBlasInterface_t::SetHandleStream(handle,dci->stream));
140     dci->solver = handle;
141     PetscFunctionReturn(0);
142   }
143 
144   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
145   {
146     PetscFunctionBegin;
147     for (auto&& handle : blashandles_) {
148       if (handle) {
149         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
150         handle     = nullptr;
151       }
152     }
153     for (auto&& handle : solverhandles_) {
154       if (handle) {
155         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
156         handle    = nullptr;
157       }
158     }
159     initialized_ = false;
160     PetscFunctionReturn(0);
161   }
162 
163 public:
164   const struct _DeviceContextOps ops = {
165     destroy,
166     changeStreamType,
167     setUp,
168     query,
169     waitForContext,
170     synchronize,
171     getHandle<blas_tag>,
172     getHandle<solver_tag>,
173     getHandle<stream_tag>,
174     beginTimer,
175     endTimer,
176   };
177 
178   // All of these functions MUST be static in order to be callable from C, otherwise they
179   // get the implicit 'this' pointer tacked on
180   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
181   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
182   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
183   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
184   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
185   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
186   template <typename Handle_t>
187   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
188   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
189   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
190 
191   // not a PetscDeviceContext method, this registers the class
192   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());
193 };
194 
195 template <DeviceType T>
196 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize())
197 {
198   PetscFunctionBegin;
199   if (PetscUnlikely(!initialized_)) {
200     initialized_ = true;
201     PetscCall(PetscRegisterFinalize(finalize_));
202   }
203   PetscFunctionReturn(0);
204 }
205 
206 template <DeviceType T>
207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
208 {
209   const auto dci = impls_cast_(dctx);
210 
211   PetscFunctionBegin;
212   if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream));
213   if (dci->event)  PetscCallCUPM(cupmEventDestroy(dci->event));
214   if (dci->begin)  PetscCallCUPM(cupmEventDestroy(dci->begin));
215   if (dci->end)    PetscCallCUPM(cupmEventDestroy(dci->end));
216   PetscCall(PetscFree(dctx->data));
217   PetscFunctionReturn(0);
218 }
219 
220 template <DeviceType T>
221 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
222 {
223   const auto dci = impls_cast_(dctx);
224 
225   PetscFunctionBegin;
226   if (auto& stream = dci->stream) {
227     PetscCallCUPM(cupmStreamDestroy(stream));
228     stream = nullptr;
229   }
230   // set these to null so they aren't usable until setup is called again
231   dci->blas   = nullptr;
232   dci->solver = nullptr;
233   PetscFunctionReturn(0);
234 }
235 
236 template <DeviceType T>
237 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
238 {
239   const auto dci    = impls_cast_(dctx);
240   auto&      stream = dci->stream;
241 
242   PetscFunctionBegin;
243   if (stream) {
244     PetscCallCUPM(cupmStreamDestroy(stream));
245     stream = nullptr;
246   }
247   switch (const auto stype = dctx->streamType) {
248   case PETSC_STREAM_GLOBAL_BLOCKING:
249     // don't create a stream for global blocking
250     break;
251   case PETSC_STREAM_DEFAULT_BLOCKING:
252     PetscCallCUPM(cupmStreamCreate(&stream));
253     break;
254   case PETSC_STREAM_GLOBAL_NONBLOCKING:
255     PetscCallCUPM(cupmStreamCreateWithFlags(&stream,cupmStreamNonBlocking));
256     break;
257   default:
258     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(stype)]);
259     break;
260   }
261   if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event));
262 #if PetscDefined(USE_DEBUG)
263   dci->timerInUse = PETSC_FALSE;
264 #endif
265   PetscFunctionReturn(0);
266 }
267 
268 template <DeviceType T>
269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
270 {
271   cupmError_t cerr;
272 
273   PetscFunctionBegin;
274   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
275   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
276   else {
277     // somethings gone wrong
278     if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr);
279     *idle = PETSC_FALSE;
280   }
281   PetscFunctionReturn(0);
282 }
283 
284 template <DeviceType T>
285 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
286 {
287   auto        dcib = impls_cast_(dctxb);
288 
289   PetscFunctionBegin;
290   PetscCallCUPM(cupmEventRecord(dcib->event,dcib->stream));
291   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0));
292   PetscFunctionReturn(0);
293 }
294 
295 template <DeviceType T>
296 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
297 {
298   auto        dci = impls_cast_(dctx);
299 
300   PetscFunctionBegin;
301   // in case anything was queued on the event
302   PetscCallCUPM(cupmStreamWaitEvent(dci->stream,dci->event,0));
303   PetscCallCUPM(cupmStreamSynchronize(dci->stream));
304   PetscFunctionReturn(0);
305 }
306 
307 template <DeviceType T>
308 template <typename handle_t>
309 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
310 {
311   PetscFunctionBegin;
312   PetscCall(initialize_handle_(handle_t{},dctx));
313   *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t{});
314   PetscFunctionReturn(0);
315 }
316 
317 template <DeviceType T>
318 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
319 {
320   auto        dci = impls_cast_(dctx);
321 
322   PetscFunctionBegin;
323 #if PetscDefined(USE_DEBUG)
324   PetscCheck(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?");
325   dci->timerInUse = PETSC_TRUE;
326 #endif
327   if (!dci->begin) {
328     PetscCallCUPM(cupmEventCreate(&dci->begin));
329     PetscCallCUPM(cupmEventCreate(&dci->end));
330   }
331   PetscCallCUPM(cupmEventRecord(dci->begin,dci->stream));
332   PetscFunctionReturn(0);
333 }
334 
335 template <DeviceType T>
336 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
337 {
338   float       gtime;
339   auto        dci = impls_cast_(dctx);
340 
341   PetscFunctionBegin;
342 #if PetscDefined(USE_DEBUG)
343   PetscCheck(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?");
344   dci->timerInUse = PETSC_FALSE;
345 #endif
346   PetscCallCUPM(cupmEventRecord(dci->end,dci->stream));
347   PetscCallCUPM(cupmEventSynchronize(dci->end));
348   PetscCallCUPM(cupmEventElapsedTime(&gtime,dci->begin,dci->end));
349   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
350   PetscFunctionReturn(0);
351 }
352 
// initialize the static member variables (one set per DeviceType instantiation)
// initialized_ is set by initialize() and cleared again by finalize_()
template <DeviceType T> bool DeviceContext<T>::initialized_ = false;

// per-device BLAS handles, indexed by device id; value-initialized to null so they are created
// lazily on first retrieval
template <DeviceType T>
std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   DeviceContext<T>::blashandles_ = {};

// per-device solver handles, indexed by device id; value-initialized to null as above
template <DeviceType T>
std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
361 
362 } // namespace Impl
363 
364 // shorten this one up a bit (and instantiate the templates)
365 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
366 using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;
367 
368 // shorthand for what is an EXTREMELY long name
369 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
370 
371 } // namespace CUPM
372 
373 } // namespace Device
374 
375 } // namespace Petsc
376 
#endif // PETSCDEVICECONTEXTCUPM_HPP
378