xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision a2aba86c77ac869ca1007cc1e6f5ae9e8649f479)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 
7 #include <array>
8 
9 namespace Petsc
10 {
11 
12 namespace Device
13 {
14 
15 namespace CUPM
16 {
17 
18 namespace Impl
19 {
20 
21 namespace detail
22 {
23 
// Empty tag type used for tag-based dispatch when retrieving a handle: overloads are selected
// by HandleTag<X>, and the nested 'type' names the handle type X that the overload yields.
template <typename T>
struct HandleTag
{
  using type = T;
};
26 
27 } // namespace detail
28 
29 // Forward declare
30 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;
31 
32 template <DeviceType T>
33 class DeviceContext : Impl::BlasInterface<T>
34 {
35 public:
36   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);
37 
38 private:
39   template <typename H> using HandleTag = typename detail::HandleTag<H>;
40   using stream_tag = HandleTag<cupmStream_t>;
41   using blas_tag   = HandleTag<cupmBlasHandle_t>;
42   using solver_tag = HandleTag<cupmSolverHandle_t>;
43 
44 public:
45   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46   // header, but since we are using the power of templates it must be declared part of
47   // this class to have easy access the same typedefs. Technically one can make a
48   // templated struct outside the class but it's more code for the same result.
49   struct PetscDeviceContext_IMPLS
50   {
51     cupmStream_t       stream;
52     cupmEvent_t        event;
53     cupmEvent_t        begin; // timer-only
54     cupmEvent_t        end;   // timer-only
55 #if PetscDefined(USE_DEBUG)
56     PetscBool          timerInUse;
57 #endif
58     cupmBlasHandle_t   blas;
59     cupmSolverHandle_t solver;
60 
61     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
62     PETSC_NODISCARD auto get(blas_tag)   const -> decltype(this->blas)   { return this->blas;   }
63     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
64   };
65 
66 private:
67   static bool initialized_;
68   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   blashandles_;
69   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;
70 
71   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
72   {
73     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
74   }
75 
76   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle))
77   {
78     PetscFunctionBegin;
79     if (handle) PetscFunctionReturn(0);
80     for (auto i = 0; i < 3; ++i) {
81       auto cberr = cupmBlasCreate(&handle);
82       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
83       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) CHKERRCUPMBLAS(cberr);
84       if (i != 2) {
85         auto ierr = PetscSleep(3);CHKERRQ(ierr);
86         continue;
87       }
88       PetscCheckFalse(cberr != CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName());
89     }
90     PetscFunctionReturn(0);
91   }
92 
93   PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream))
94   {
95     cupmStream_t    cupmStream;
96     cupmBlasError_t cberr;
97 
98     PetscFunctionBegin;
99     cberr = cupmBlasGetStream(handle,&cupmStream);CHKERRCUPMBLAS(cberr);
100     if (cupmStream != stream) {cberr = cupmBlasSetStream(handle,stream);CHKERRCUPMBLAS(cberr);}
101     PetscFunctionReturn(0);
102   }
103 
104   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
105   {
106     PetscFunctionBegin;
107     for (auto&& handle : blashandles_) {
108       if (handle) {
109         auto cberr = cupmBlasDestroy(handle);CHKERRCUPMBLAS(cberr);
110         handle     = nullptr;
111       }
112     }
113     for (auto&& handle : solverhandles_) {
114       if (handle) {
115         auto ierr = cupmBlasInterface_t::DestroyHandle(handle);CHKERRQ(ierr);
116         handle    = nullptr;
117       }
118     }
119     initialized_ = false;
120     PetscFunctionReturn(0);
121   }
122 
123   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci))
124   {
125     PetscErrorCode ierr;
126 
127     PetscFunctionBegin;
128     ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr);
129     if (!initialized_) {
130       initialized_ = true;
131       ierr = PetscRegisterFinalize(finalize_);CHKERRQ(ierr);
132     }
133     // use the blashandle as a canary
134     if (!blashandles_[id]) {
135       ierr = initialize_handle_(blashandles_[id]);CHKERRQ(ierr);
136       ierr = cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);CHKERRQ(ierr);
137     }
138     ierr = set_handle_stream_(blashandles_[id],dci->stream);CHKERRQ(ierr);
139     ierr = cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);CHKERRQ(ierr);
140     dci->blas   = blashandles_[id];
141     dci->solver = solverhandles_[id];
142     PetscFunctionReturn(0);
143   }
144 
145 public:
146   const struct _DeviceContextOps ops = {
147     destroy,
148     changeStreamType,
149     setUp,
150     query,
151     waitForContext,
152     synchronize,
153     getHandle<blas_tag>,
154     getHandle<solver_tag>,
155     getHandle<stream_tag>,
156     beginTimer,
157     endTimer,
158   };
159 
160   // All of these functions MUST be static in order to be callable from C, otherwise they
161   // get the implicit 'this' pointer tacked on
162   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
163   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
164   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
165   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
166   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
167   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
168   template <typename Handle_t>
169   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
170   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
171   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
172 };
173 
174 template <DeviceType T>
175 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
176 {
177   cupmError_t    cerr;
178   PetscErrorCode ierr;
179   auto           dci = impls_cast_(dctx);
180 
181   PetscFunctionBegin;
182   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
183   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);  }
184   if (dci->begin)  {cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr);  }
185   if (dci->end)    {cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr);    }
186   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
187   PetscFunctionReturn(0);
188 }
189 
190 template <DeviceType T>
191 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
192 {
193   auto dci = impls_cast_(dctx);
194 
195   PetscFunctionBegin;
196   if (dci->stream) {
197     auto cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
198     dci->stream = nullptr;
199   }
200   // set these to null so they aren't usable until setup is called again
201   dci->blas   = nullptr;
202   dci->solver = nullptr;
203   PetscFunctionReturn(0);
204 }
205 
206 template <DeviceType T>
207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
208 {
209   PetscErrorCode ierr;
210   cupmError_t    cerr;
211   auto           dci = impls_cast_(dctx);
212 
213   PetscFunctionBegin;
214   if (dci->stream) {
215     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
216     dci->stream = nullptr;
217   }
218   switch (dctx->streamType) {
219   case PETSC_STREAM_GLOBAL_BLOCKING:
220     // don't create a stream for global blocking
221     break;
222   case PETSC_STREAM_DEFAULT_BLOCKING:
223     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
224     break;
225   case PETSC_STREAM_GLOBAL_NONBLOCKING:
226     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
227     break;
228   default:
229     SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]);
230     break;
231   }
232   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
233 #if PetscDefined(USE_DEBUG)
234   dci->timerInUse = PETSC_FALSE;
235 #endif
236   ierr = initialize_(dctx->device->deviceId,dci);CHKERRQ(ierr);
237   PetscFunctionReturn(0);
238 }
239 
240 template <DeviceType T>
241 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
242 {
243   cupmError_t cerr;
244 
245   PetscFunctionBegin;
246   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
247   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
248   else {
249     // somethings gone wrong
250     if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr);
251     *idle = PETSC_FALSE;
252   }
253   PetscFunctionReturn(0);
254 }
255 
256 template <DeviceType T>
257 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
258 {
259   cupmError_t cerr;
260   auto        dcib = impls_cast_(dctxb);
261 
262   PetscFunctionBegin;
263   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
264   cerr = cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);CHKERRCUPM(cerr);
265   PetscFunctionReturn(0);
266 }
267 
268 template <DeviceType T>
269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
270 {
271   cupmError_t cerr;
272   auto        dci = impls_cast_(dctx);
273 
274   PetscFunctionBegin;
275   // in case anything was queued on the event
276   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
277   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
278   PetscFunctionReturn(0);
279 }
280 
281 template <DeviceType T>
282 template <typename handle_t>
283 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
284 {
285   PetscFunctionBegin;
286   *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t());
287   PetscFunctionReturn(0);
288 }
289 
290 template <DeviceType T>
291 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
292 {
293   auto        dci = impls_cast_(dctx);
294   cupmError_t cerr;
295 
296   PetscFunctionBegin;
297 #if PetscDefined(USE_DEBUG)
298   PetscCheckFalse(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?");
299   dci->timerInUse = PETSC_TRUE;
300 #endif
301   if (!dci->begin) {
302     cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr);
303     cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr);
304   }
305   cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr);
306   PetscFunctionReturn(0);
307 }
308 
309 template <DeviceType T>
310 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
311 {
312   cupmError_t cerr;
313   float       gtime;
314   auto        dci = impls_cast_(dctx);
315 
316   PetscFunctionBegin;
317 #if PetscDefined(USE_DEBUG)
318   PetscCheckFalse(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?");
319   dci->timerInUse = PETSC_FALSE;
320 #endif
321   cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr);
322   cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr);
323   cerr = cupmEventElapsedTime(&gtime,dci->begin,dci->end);CHKERRCUPM(cerr);
324   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
325   PetscFunctionReturn(0);
326 }
327 
328 // initialize the static member variables
329 template <DeviceType T> bool DeviceContext<T>::initialized_ = false;
330 
331 template <DeviceType T>
332 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   DeviceContext<T>::blashandles_ = {};
333 
334 template <DeviceType T>
335 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
336 
337 } // namespace Impl
338 
339 // shorten this one up a bit (and instantiate the templates)
340 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
341 using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;
342 
343 // shorthand for what is an EXTREMELY long name
344 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
345 
346 } // namespace CUPM
347 
348 } // namespace Device
349 
350 } // namespace Petsc
351 
#endif // PETSCDEVICECONTEXTCUPM_HPP
353