xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 0f1503a698da0451f8f2e3e02a165d4e1f526457)
1 #ifndef PETSCDEVICECONTEXTCUPM_HPP
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupmblasinterface.hpp>
6 
7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8 #  error PetscDeviceContext backends for CUDA and HIP requires C++11
9 #endif
10 
11 #include <array>
12 
13 namespace Petsc
14 {
15 
16 namespace Device
17 {
18 
19 namespace CUPM
20 {
21 
22 namespace Impl
23 {
24 
25 namespace detail
26 {
27 
// Empty tag type used to select a handle getter through overload resolution. The handle
// type a tag stands for is published as the nested alias `type`.
template <typename T>
struct HandleTag
{
  using type = T;
};
30 
31 } // namespace detail
32 
33 // Forward declare
34 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;
35 
36 template <DeviceType T>
37 class DeviceContext : Impl::BlasInterface<T>
38 {
39 public:
40   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);
41 
42 private:
43   template <typename H> using HandleTag = typename detail::HandleTag<H>;
44   using stream_tag = HandleTag<cupmStream_t>;
45   using blas_tag   = HandleTag<cupmBlasHandle_t>;
46   using solver_tag = HandleTag<cupmSolverHandle_t>;
47 
48 public:
49   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
50   // header, but since we are using the power of templates it must be declared part of
51   // this class to have easy access the same typedefs. Technically one can make a
52   // templated struct outside the class but it's more code for the same result.
53   struct PetscDeviceContext_IMPLS
54   {
55     cupmStream_t       stream;
56     cupmEvent_t        event;
57     cupmEvent_t        begin; // timer-only
58     cupmEvent_t        end;   // timer-only
59 #if PetscDefined(USE_DEBUG)
60     PetscBool          timerInUse;
61 #endif
62     cupmBlasHandle_t   blas;
63     cupmSolverHandle_t solver;
64 
65     PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
66     PETSC_NODISCARD auto get(blas_tag)   const -> decltype(this->blas)   { return this->blas;   }
67     PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
68   };
69 
70 private:
71   static bool initialized_;
72   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   blashandles_;
73   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;
74 
75   PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
76   {
77     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
78   }
79 
80   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle))
81   {
82     PetscFunctionBegin;
83     if (handle) PetscFunctionReturn(0);
84     for (auto i = 0; i < 3; ++i) {
85       auto cberr = cupmBlasCreate(&handle);
86       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
87       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) CHKERRCUPMBLAS(cberr);
88       if (i != 2) {
89         auto ierr = PetscSleep(3);CHKERRQ(ierr);
90         continue;
91       }
92       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_SUCCESS)) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName());
93     }
94     PetscFunctionReturn(0);
95   }
96 
97   PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream))
98   {
99     cupmStream_t    cupmStream;
100     cupmBlasError_t cberr;
101 
102     PetscFunctionBegin;
103     cberr = cupmBlasGetStream(handle,&cupmStream);CHKERRCUPMBLAS(cberr);
104     if (cupmStream != stream) {cberr = cupmBlasSetStream(handle,stream);CHKERRCUPMBLAS(cberr);}
105     PetscFunctionReturn(0);
106   }
107 
108   PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
109   {
110     PetscFunctionBegin;
111     for (auto&& handle : blashandles_) {
112       if (handle) {
113         auto cberr = cupmBlasDestroy(handle);CHKERRCUPMBLAS(cberr);
114         handle     = nullptr;
115       }
116     }
117     for (auto&& handle : solverhandles_) {
118       if (handle) {
119         auto ierr = cupmBlasInterface_t::DestroyHandle(handle);CHKERRQ(ierr);
120         handle    = nullptr;
121       }
122     }
123     initialized_ = false;
124     PetscFunctionReturn(0);
125   }
126 
127   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci))
128   {
129     PetscErrorCode ierr;
130 
131     PetscFunctionBegin;
132     ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr);
133     if (!initialized_) {
134       initialized_ = true;
135       ierr = PetscRegisterFinalize(finalize_);CHKERRQ(ierr);
136     }
137     // use the blashandle as a canary
138     if (!blashandles_[id]) {
139       ierr = initialize_handle_(blashandles_[id]);CHKERRQ(ierr);
140       ierr = cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);CHKERRQ(ierr);
141     }
142     ierr = set_handle_stream_(blashandles_[id],dci->stream);CHKERRQ(ierr);
143     ierr = cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);CHKERRQ(ierr);
144     dci->blas   = blashandles_[id];
145     dci->solver = solverhandles_[id];
146     PetscFunctionReturn(0);
147   }
148 
149 public:
150   const struct _DeviceContextOps ops = {
151     destroy,
152     changeStreamType,
153     setUp,
154     query,
155     waitForContext,
156     synchronize,
157     getHandle<blas_tag>,
158     getHandle<solver_tag>,
159     getHandle<stream_tag>,
160     beginTimer,
161     endTimer,
162   };
163 
164   // All of these functions MUST be static in order to be callable from C, otherwise they
165   // get the implicit 'this' pointer tacked on
166   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
167   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
168   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
169   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
170   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
171   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
172   template <typename Handle_t>
173   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
174   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
175   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
176 };
177 
178 template <DeviceType T>
179 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
180 {
181   cupmError_t    cerr;
182   PetscErrorCode ierr;
183   auto           dci = impls_cast_(dctx);
184 
185   PetscFunctionBegin;
186   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
187   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);  }
188   if (dci->begin)  {cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr);  }
189   if (dci->end)    {cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr);    }
190   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
191   PetscFunctionReturn(0);
192 }
193 
194 template <DeviceType T>
195 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
196 {
197   auto dci = impls_cast_(dctx);
198 
199   PetscFunctionBegin;
200   if (dci->stream) {
201     auto cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
202     dci->stream = nullptr;
203   }
204   // set these to null so they aren't usable until setup is called again
205   dci->blas   = nullptr;
206   dci->solver = nullptr;
207   PetscFunctionReturn(0);
208 }
209 
210 template <DeviceType T>
211 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
212 {
213   PetscErrorCode ierr;
214   cupmError_t    cerr;
215   auto           dci = impls_cast_(dctx);
216 
217   PetscFunctionBegin;
218   if (dci->stream) {
219     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
220     dci->stream = nullptr;
221   }
222   switch (dctx->streamType) {
223   case PETSC_STREAM_GLOBAL_BLOCKING:
224     // don't create a stream for global blocking
225     break;
226   case PETSC_STREAM_DEFAULT_BLOCKING:
227     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
228     break;
229   case PETSC_STREAM_GLOBAL_NONBLOCKING:
230     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
231     break;
232   default:
233     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]);
234     break;
235   }
236   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
237 #if PetscDefined(USE_DEBUG)
238   dci->timerInUse = PETSC_FALSE;
239 #endif
240   ierr = initialize_(dctx->device->deviceId,dci);CHKERRQ(ierr);
241   PetscFunctionReturn(0);
242 }
243 
244 template <DeviceType T>
245 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
246 {
247   cupmError_t cerr;
248 
249   PetscFunctionBegin;
250   cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
251   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
252   else {
253     // somethings gone wrong
254     if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr);
255     *idle = PETSC_FALSE;
256   }
257   PetscFunctionReturn(0);
258 }
259 
260 template <DeviceType T>
261 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
262 {
263   cupmError_t cerr;
264   auto        dcib = impls_cast_(dctxb);
265 
266   PetscFunctionBegin;
267   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
268   cerr = cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);CHKERRCUPM(cerr);
269   PetscFunctionReturn(0);
270 }
271 
272 template <DeviceType T>
273 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
274 {
275   cupmError_t cerr;
276   auto        dci = impls_cast_(dctx);
277 
278   PetscFunctionBegin;
279   // in case anything was queued on the event
280   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
281   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
282   PetscFunctionReturn(0);
283 }
284 
285 template <DeviceType T>
286 template <typename handle_t>
287 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
288 {
289   PetscFunctionBegin;
290   *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t());
291   PetscFunctionReturn(0);
292 }
293 
294 template <DeviceType T>
295 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
296 {
297   auto        dci = impls_cast_(dctx);
298   cupmError_t cerr;
299 
300   PetscFunctionBegin;
301 #if PetscDefined(USE_DEBUG)
302   if (PetscUnlikely(dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?");
303   dci->timerInUse = PETSC_TRUE;
304 #endif
305   if (!dci->begin) {
306     cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr);
307     cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr);
308   }
309   cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr);
310   PetscFunctionReturn(0);
311 }
312 
313 template <DeviceType T>
314 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
315 {
316   cupmError_t cerr;
317   float       gtime;
318   auto        dci = impls_cast_(dctx);
319 
320   PetscFunctionBegin;
321 #if PetscDefined(USE_DEBUG)
322   if (PetscUnlikely(!dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?");
323   dci->timerInUse = PETSC_FALSE;
324 #endif
325   cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr);
326   cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr);
327   cerr = cupmEventElapsedTime(&gtime,dci->begin,dci->end);CHKERRCUPM(cerr);
328   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
329   PetscFunctionReturn(0);
330 }
331 
332 // initialize the static member variables
333 template <DeviceType T> bool DeviceContext<T>::initialized_ = false;
334 
335 template <DeviceType T>
336 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   DeviceContext<T>::blashandles_ = {};
337 
338 template <DeviceType T>
339 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
340 
341 } // namespace Impl
342 
343 // shorten this one up a bit (and instantiate the templates)
344 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
345 using CUPMContextHip  = Impl::DeviceContext<DeviceType::HIP>;
346 
// Shorthand for the fully qualified name of a backend's impls struct; IMPLS is the
// DeviceType enumerator, e.g. PetscDeviceContext_(CUDA)
#define PetscDeviceContext_(IMPLS) \
  Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
349 
350 } // namespace CUPM
351 
352 } // namespace Device
353 
354 } // namespace Petsc
355 
#endif // PETSCDEVICECONTEXTCUPM_HPP
357