xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision f97672e55eacc8688507b9471cd7ec2664d7f203)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 
7 #include <array>
8 
9 namespace Petsc
10 {
11 
12 namespace Device
13 {
14 
15 namespace CUPM
16 {
17 
18 namespace Impl
19 {
20 
21 namespace detail
22 {
23 
// Empty tag type used for tag-based dispatch of handle retrieval. The handle type that a
// tag stands for is recoverable through the nested 'type' alias.
template <typename T>
struct HandleTag
{
  using type = T;
};
26 
27 } // namespace detail
28 
29 // Forward declare
30 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;
31 
32 template <DeviceType T>
33 class DeviceContext : Impl::BlasInterface<T>
34 {
35 public:
36   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);
37 
38 private:
39   template <typename H> using HandleTag = typename detail::HandleTag<H>;
40   using stream_tag = HandleTag<cupmStream_t>;
41   using blas_tag   = HandleTag<cupmBlasHandle_t>;
42   using solver_tag = HandleTag<cupmSolverHandle_t>;
43 
44 public:
45   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46   // header, but since we are using the power of templates it must be declared part of
47   // this class to have easy access the same typedefs. Technically one can make a
48   // templated struct outside the class but it's more code for the same result.
49   struct PetscDeviceContext_IMPLS
50   {
51     cupmStream_t       stream;
52     cupmEvent_t        event;
53     cupmEvent_t        begin; // timer-only
54     cupmEvent_t        end;   // timer-only
55 #if PetscDefined(USE_DEBUG)
56     PetscBool          timerInUse;
57 #endif
58     cupmBlasHandle_t   blas;
59     cupmSolverHandle_t solver;
60 
61     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
62     PETSC_NODISCARD auto get(blas_tag)   const -> decltype(this->blas)   { return this->blas;   }
63     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
64   };
65 
66 private:
67   static bool initialized_;
68   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   blashandles_;
69   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;
70 
71   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
72   {
73     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
74   }
75 
76   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle))
77   {
78     PetscFunctionBegin;
79     if (handle) PetscFunctionReturn(0);
80     for (auto i = 0; i < 3; ++i) {
81       auto cberr = cupmBlasCreate(&handle);
82       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
83       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
84       if (i != 2) {
85         PetscCall(PetscSleep(3));
86         continue;
87       }
88       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName());
89     }
90     PetscFunctionReturn(0);
91   }
92 
93   PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream))
94   {
95     cupmStream_t    cupmStream;
96 
97     PetscFunctionBegin;
98     PetscCallCUPMBLAS(cupmBlasGetStream(handle,&cupmStream));
99     if (cupmStream != stream) PetscCallCUPMBLAS(cupmBlasSetStream(handle,stream));
100     PetscFunctionReturn(0);
101   }
102 
103   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
104   {
105     PetscFunctionBegin;
106     for (auto&& handle : blashandles_) {
107       if (handle) {
108         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
109         handle     = nullptr;
110       }
111     }
112     for (auto&& handle : solverhandles_) {
113       if (handle) {
114         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
115         handle    = nullptr;
116       }
117     }
118     initialized_ = false;
119     PetscFunctionReturn(0);
120   }
121 
122   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci))
123   {
124     PetscFunctionBegin;
125     PetscCall(PetscDeviceCheckDeviceCount_Internal(id));
126     if (!initialized_) {
127       initialized_ = true;
128       PetscCall(PetscRegisterFinalize(finalize_));
129     }
130     // use the blashandle as a canary
131     if (!blashandles_[id]) {
132       PetscCall(initialize_handle_(blashandles_[id]));
133       PetscCall(cupmBlasInterface_t::InitializeHandle(solverhandles_[id]));
134     }
135     PetscCall(set_handle_stream_(blashandles_[id],dci->stream));
136     PetscCall(cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream));
137     dci->blas   = blashandles_[id];
138     dci->solver = solverhandles_[id];
139     PetscFunctionReturn(0);
140   }
141 
142 public:
143   const struct _DeviceContextOps ops = {
144     destroy,
145     changeStreamType,
146     setUp,
147     query,
148     waitForContext,
149     synchronize,
150     getHandle<blas_tag>,
151     getHandle<solver_tag>,
152     getHandle<stream_tag>,
153     beginTimer,
154     endTimer,
155   };
156 
157   // All of these functions MUST be static in order to be callable from C, otherwise they
158   // get the implicit 'this' pointer tacked on
159   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
160   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
161   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
162   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
163   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
164   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
165   template <typename Handle_t>
166   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
167   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
168   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
169 };
170 
171 template <DeviceType T>
172 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
173 {
174   auto           dci = impls_cast_(dctx);
175 
176   PetscFunctionBegin;
177   if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream));
178   if (dci->event)  PetscCallCUPM(cupmEventDestroy(dci->event));
179   if (dci->begin)  PetscCallCUPM(cupmEventDestroy(dci->begin));
180   if (dci->end)    PetscCallCUPM(cupmEventDestroy(dci->end));
181   PetscCall(PetscFree(dctx->data));
182   PetscFunctionReturn(0);
183 }
184 
185 template <DeviceType T>
186 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
187 {
188   auto dci = impls_cast_(dctx);
189 
190   PetscFunctionBegin;
191   if (dci->stream) {
192     PetscCallCUPM(cupmStreamDestroy(dci->stream));
193     dci->stream = nullptr;
194   }
195   // set these to null so they aren't usable until setup is called again
196   dci->blas   = nullptr;
197   dci->solver = nullptr;
198   PetscFunctionReturn(0);
199 }
200 
201 template <DeviceType T>
202 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
203 {
204   auto           dci = impls_cast_(dctx);
205 
206   PetscFunctionBegin;
207   if (dci->stream) {
208     PetscCallCUPM(cupmStreamDestroy(dci->stream));
209     dci->stream = nullptr;
210   }
211   switch (dctx->streamType) {
212   case PETSC_STREAM_GLOBAL_BLOCKING:
213     // don't create a stream for global blocking
214     break;
215   case PETSC_STREAM_DEFAULT_BLOCKING:
216     PetscCallCUPM(cupmStreamCreate(&dci->stream));
217     break;
218   case PETSC_STREAM_GLOBAL_NONBLOCKING:
219     PetscCallCUPM(cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking));
220     break;
221   default:
222     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]);
223     break;
224   }
225   if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event));
226 #if PetscDefined(USE_DEBUG)
227   dci->timerInUse = PETSC_FALSE;
228 #endif
229   PetscCall(initialize_(dctx->device->deviceId,dci));
230   PetscFunctionReturn(0);
231 }
232 
233 template <DeviceType T>
234 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
235 {
236   cupmError_t cerr;
237 
238   PetscFunctionBegin;
239   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
240   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
241   else {
242     // somethings gone wrong
243     if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr);
244     *idle = PETSC_FALSE;
245   }
246   PetscFunctionReturn(0);
247 }
248 
249 template <DeviceType T>
250 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
251 {
252   auto        dcib = impls_cast_(dctxb);
253 
254   PetscFunctionBegin;
255   PetscCallCUPM(cupmEventRecord(dcib->event,dcib->stream));
256   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0));
257   PetscFunctionReturn(0);
258 }
259 
260 template <DeviceType T>
261 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
262 {
263   auto        dci = impls_cast_(dctx);
264 
265   PetscFunctionBegin;
266   // in case anything was queued on the event
267   PetscCallCUPM(cupmStreamWaitEvent(dci->stream,dci->event,0));
268   PetscCallCUPM(cupmStreamSynchronize(dci->stream));
269   PetscFunctionReturn(0);
270 }
271 
272 template <DeviceType T>
273 template <typename handle_t>
274 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
275 {
276   PetscFunctionBegin;
277   *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t());
278   PetscFunctionReturn(0);
279 }
280 
281 template <DeviceType T>
282 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
283 {
284   auto        dci = impls_cast_(dctx);
285 
286   PetscFunctionBegin;
287 #if PetscDefined(USE_DEBUG)
288   PetscCheck(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?");
289   dci->timerInUse = PETSC_TRUE;
290 #endif
291   if (!dci->begin) {
292     PetscCallCUPM(cupmEventCreate(&dci->begin));
293     PetscCallCUPM(cupmEventCreate(&dci->end));
294   }
295   PetscCallCUPM(cupmEventRecord(dci->begin,dci->stream));
296   PetscFunctionReturn(0);
297 }
298 
299 template <DeviceType T>
300 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
301 {
302   float       gtime;
303   auto        dci = impls_cast_(dctx);
304 
305   PetscFunctionBegin;
306 #if PetscDefined(USE_DEBUG)
307   PetscCheck(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?");
308   dci->timerInUse = PETSC_FALSE;
309 #endif
310   PetscCallCUPM(cupmEventRecord(dci->end,dci->stream));
311   PetscCallCUPM(cupmEventSynchronize(dci->end));
312   PetscCallCUPM(cupmEventElapsedTime(&gtime,dci->begin,dci->end));
313   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
314   PetscFunctionReturn(0);
315 }
316 
317 // initialize the static member variables
318 template <DeviceType T> bool DeviceContext<T>::initialized_ = false;
319 
320 template <DeviceType T>
321 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   DeviceContext<T>::blashandles_ = {};
322 
323 template <DeviceType T>
324 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
325 
326 } // namespace Impl
327 
328 // short aliases for the CUDA and HIP specializations (NOTE: an alias by itself does not instantiate the class template)
329 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
330 using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;
331 
332 // shorthand for the fully qualified (and EXTREMELY long) name of the impls struct, IMPLS is the DeviceType enumerator, e.g. PetscDeviceContext_(CUDA)
333 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
334 
335 } // namespace CUPM
336 
337 } // namespace Device
338 
339 } // namespace Petsc
340 
341 #endif // PETSCDEVICECONTEXTCUPM_HPP
342