xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 3ca90d2d9fe4d5ec7086bd4aee14f89370d16392)
1 #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h>
5 #include <petsc/private/cupminterface.hpp>
6 
7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8 #error PetscDeviceContext backends for CUDA and HIP requires C++11
9 #endif
10 
11 #include <array>
12 
13 namespace Petsc
14 {
15 
namespace detail
{

// Empty marker type used for tag-based dispatch: passing HandleTag<H>{} selects the overload of
// PetscDeviceContext_IMPLS::handle() that returns a handle of type H.
template <typename HandleType>
struct HandleTag
{ };

} // namespace detail
23 
24 // Forward declare
25 template <CUPMDeviceType T> class CUPMContext;
26 
27 template <CUPMDeviceType T>
28 class CUPMContext : CUPMInterface<T>
29 {
30   template <typename H>
31   using HandleTag = typename detail::HandleTag<H>;
32 
33 public:
34   PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T)
35 
36   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
37   // header, but since we are using the power of templates it must be declared part of
38   // this class to have easy access the same typedefs. Technically one can make a
39   // templated struct outside the class but it's more code for the same result.
40   struct PetscDeviceContext_IMPLS
41   {
42     cupmStream_t       stream;
43     cupmEvent_t        event;
44     cupmEvent_t        begin; // timer-only
45     cupmEvent_t        end;   // timer-only
46 #if PetscDefined(USE_DEBUG)
47     PetscBool          timerInUse;
48 #endif
49     cupmBlasHandle_t   blas;
50     cupmSolverHandle_t solver;
51 
52     PETSC_NODISCARD cupmBlasHandle_t   handle(HandleTag<cupmBlasHandle_t>)   { return blas;   }
53     PETSC_NODISCARD cupmSolverHandle_t handle(HandleTag<cupmSolverHandle_t>) { return solver; }
54   };
55 
56 private:
57   static bool _initialized;
58   static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   _blashandles;
59   static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> _solverhandles;
60 
61   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS* __impls_cast(PetscDeviceContext ptr) noexcept
62   {
63     return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
64   }
65 
66   PETSC_NODISCARD static PetscErrorCode __finalize() noexcept
67   {
68     PetscErrorCode ierr;
69 
70     PetscFunctionBegin;
71     for (auto&& handle : _blashandles) {
72       if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);}
73     }
74     for (auto&& handle : _solverhandles) {
75       if (handle) {ierr = cupmInterface_t::DestroyHandle(handle);CHKERRQ(ierr);}
76     }
77     _initialized = false;
78     PetscFunctionReturn(0);
79   }
80 
81   PETSC_NODISCARD static PetscErrorCode __initialize(PetscInt id, PetscDeviceContext_IMPLS *dci) noexcept
82   {
83     PetscErrorCode ierr;
84 
85     PetscFunctionBegin;
86     ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr);
87     if (!_initialized) {
88       _initialized = true;
89       ierr = PetscRegisterFinalize(__finalize);CHKERRQ(ierr);
90     }
91     // use the blashandle as a canary
92     if (!_blashandles[id]) {
93       ierr = cupmInterface_t::InitializeHandle(_blashandles[id]);CHKERRQ(ierr);
94       ierr = cupmInterface_t::InitializeHandle(_solverhandles[id]);CHKERRQ(ierr);
95     }
96     ierr = cupmInterface_t::SetHandleStream(_blashandles[id],dci->stream);CHKERRQ(ierr);
97     ierr = cupmInterface_t::SetHandleStream(_solverhandles[id],dci->stream);CHKERRQ(ierr);
98     dci->blas   = _blashandles[id];
99     dci->solver = _solverhandles[id];
100     PetscFunctionReturn(0);
101   }
102 
103 public:
104   const struct _DeviceContextOps ops = {
105     destroy,
106     changeStreamType,
107     setUp,
108     query,
109     waitForContext,
110     synchronize,
111     getHandle<cupmBlasHandle_t>,
112     getHandle<cupmSolverHandle_t>,
113     beginTimer,
114     endTimer
115   };
116 
117   // default constructor
118   constexpr CUPMContext() noexcept = default;
119 
120   // All of these functions MUST be static in order to be callable from C, otherwise they
121   // get the implicit 'this' pointer tacked on
122   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
123   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
124   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
125   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
126   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
127   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
128   template <typename Handle_t>
129   PETSC_NODISCARD static PetscErrorCode getHandle(PetscDeviceContext,void*) noexcept;
130   PETSC_NODISCARD static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
131   PETSC_NODISCARD static PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*) noexcept;
132 };
133 
134 template <CUPMDeviceType T>
135 inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
136 {
137   cupmError_t    cerr;
138   PetscErrorCode ierr;
139   auto           dci = __impls_cast(dctx);
140 
141   PetscFunctionBegin;
142   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
143   if (dci->event)  {
144     cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);
145     cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr);
146     cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr);
147   }
148   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
149   PetscFunctionReturn(0);
150 }
151 
152 template <CUPMDeviceType T>
153 inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
154 {
155   auto dci = __impls_cast(dctx);
156 
157   PetscFunctionBegin;
158   if (dci->stream) {
159     cupmError_t cerr;
160 
161     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
162     dci->stream = nullptr;
163   }
164   // set these to null so they aren't usable until setup is called again
165   dci->blas   = nullptr;
166   dci->solver = nullptr;
167   PetscFunctionReturn(0);
168 }
169 
170 template <CUPMDeviceType T>
171 inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
172 {
173   PetscErrorCode ierr;
174   cupmError_t    cerr;
175   auto           dci = __impls_cast(dctx);
176 
177   PetscFunctionBegin;
178   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
179   switch (dctx->streamType) {
180   case PETSC_STREAM_GLOBAL_BLOCKING:
181     // don't create a stream for global blocking
182     dci->stream = nullptr;
183     break;
184   case PETSC_STREAM_DEFAULT_BLOCKING:
185     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
186     break;
187   case PETSC_STREAM_GLOBAL_NONBLOCKING:
188     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
189     break;
190   default:
191     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[static_cast<int>(dctx->streamType)]);
192     break;
193   }
194   if (!dci->event) {
195     cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);
196     cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr);
197     cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr);
198   }
199 #if PetscDefined(USE_DEBUG)
200   dci->timerInUse = PETSC_FALSE;
201 #endif
202   ierr = __initialize(dctx->device->deviceId,dci);CHKERRQ(ierr);
203   PetscFunctionReturn(0);
204 }
205 
206 template <CUPMDeviceType T>
207 inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
208 {
209   cupmError_t cerr;
210 
211   PetscFunctionBegin;
212   cerr = cupmStreamQuery(__impls_cast(dctx)->stream);
213   if (cerr == cupmSuccess) *idle = PETSC_TRUE;
214   else {
215     // somethings gone wrong
216     if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr);
217     *idle = PETSC_FALSE;
218   }
219   PetscFunctionReturn(0);
220 }
221 
222 template <CUPMDeviceType T>
223 inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
224 {
225   cupmError_t cerr;
226   auto        dcia = __impls_cast(dctxa),dcib = __impls_cast(dctxb);
227 
228   PetscFunctionBegin;
229   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
230   cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
231   PetscFunctionReturn(0);
232 }
233 
234 template <CUPMDeviceType T>
235 inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
236 {
237   cupmError_t cerr;
238   auto        dci = __impls_cast(dctx);
239 
240   PetscFunctionBegin;
241   // in case anything was queued on the event
242   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
243   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
244   PetscFunctionReturn(0);
245 }
246 
247 template <CUPMDeviceType T>
248 template <typename Handle_T>
249 inline PetscErrorCode CUPMContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
250 {
251   PetscFunctionBegin;
252   *static_cast<Handle_T*>(handle) = __impls_cast(dctx)->handle(HandleTag<Handle_T>());
253   PetscFunctionReturn(0);
254 }
255 
256 template <CUPMDeviceType T>
257 inline PetscErrorCode CUPMContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
258 {
259   auto        dci = __impls_cast(dctx);
260   cupmError_t cerr;
261 
262   PetscFunctionBegin;
263 #if PetscDefined(USE_DEBUG)
264   if (PetscUnlikely(dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?");
265   dci->timerInUse = PETSC_TRUE;
266 #endif
267   cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr);
268   PetscFunctionReturn(0);
269 }
270 
271 template <CUPMDeviceType T>
272 inline PetscErrorCode CUPMContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
273 {
274   cupmError_t cerr;
275   float       gtime;
276   auto        dci = __impls_cast(dctx);
277 
278   PetscFunctionBegin;
279 #if PetscDefined(USE_DEBUG)
280   if (PetscUnlikely(!dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?");
281   dci->timerInUse = PETSC_FALSE;
282 #endif
283   cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr);
284   cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr);
285   cerr = cupmEventElapsedTime(&gtime,dci->begin,dci->end);CHKERRCUPM(cerr);
286   *elapsed = static_cast<PetscLogDouble>(gtime);
287   PetscFunctionReturn(0);
288 }
289 
290 // initialize the static member variables
291 template <CUPMDeviceType T> bool CUPMContext<T>::_initialized = false;
292 
293 template <CUPMDeviceType T>
294 std::array<typename CUPMContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES>   CUPMContext<T>::_blashandles = {};
295 
296 template <CUPMDeviceType T>
297 std::array<typename CUPMContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> CUPMContext<T>::_solverhandles = {};
298 
299 // shorten this one up a bit (and instantiate the templates)
300 using CUPMContextCuda = CUPMContext<CUPMDeviceType::CUDA>;
301 using CUPMContextHip  = CUPMContext<CUPMDeviceType::HIP>;
302 
303 } // namespace Petsc
304 
// Shorthand for the (very long) nested impls type of a backend, e.g. PetscDeviceContext_(CUDA)
#define PetscDeviceContext_(BACKEND) Petsc::CUPMContext<Petsc::CUPMDeviceType::BACKEND>::PetscDeviceContext_IMPLS

// Cast the type-erased 'data' member of a device context to the impls struct of BACKEND,
// giving access to its stream, events and library handles
#define PDC_IMPLS_STATIC_CAST(BACKEND,dctx) static_cast<PetscDeviceContext_(BACKEND) *>((dctx)->data)
310 
#endif // PETSCDEVICECONTEXTCUPM_HPP
312