xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
1 #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
2 #define PETSCDEVICECONTEXTCUPM_HPP
3 
4 #include <petsc/private/deviceimpl.h> /*I "petscdevice.h" I*/
5 #include <petsc/private/cupminterface.hpp>
6 
7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8 #error PetscDeviceContext backends for CUDA and HIP requires C++11
9 #endif
10 
11 namespace Petsc {
12 
13 // Forward declare
14 template <CUPMDeviceKind T> class CUPMContext;
15 
16 template <CUPMDeviceKind T>
17 class CUPMContext : CUPMInterface<T>
18 {
19 public:
20   PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T);
21 
22   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
23   // header, but since we are using the power of templates it must be declared part of
24   // this class to have easy access the same typedefs. Technically one can make a
25   // templated struct outside the class but it's more code for the same result.
26   struct PetscDeviceContext_IMPLS
27   {
28     cupmStream_t       stream;
29     cupmEvent_t        event;
30     cupmBlasHandle_t   blas;
31     cupmSolverHandle_t solver;
32   };
33 
34 private:
35   static cupmBlasHandle_t   _blashandle;
36   static cupmSolverHandle_t _solverhandle;
37 
38   PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept
39   {
40     PetscErrorCode ierr;
41 
42     PetscFunctionBegin;
43     ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr);
44     PetscFunctionReturn(0);
45   }
46 
47   PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept
48   {
49     PetscErrorCode ierr;
50 
51     PetscFunctionBegin;
52     ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr);
53     PetscFunctionReturn(0);
54   }
55 
56   PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept
57   {
58     PetscErrorCode  ierr;
59 
60     PetscFunctionBegin;
61     if (!_blashandle) {
62       ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr);
63       ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr);
64     }
65     if (!_solverhandle) {
66       ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr);
67       ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr);
68     }
69     ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr);
70     ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr);
71     dci->blas   = _blashandle;
72     dci->solver = _solverhandle;
73     PetscFunctionReturn(0);
74   }
75 
76 public:
77   const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize};
78 
79   // default constructor
80   constexpr CUPMContext() noexcept = default;
81 
82   // All of these functions MUST be static in order to be callable from C, otherwise they
83   // get the implicit 'this' pointer tacked on
84   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
85   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
86   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
87   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
88   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
89   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
90 };
91 
92 #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data)
93 
94 template <CUPMDeviceKind T>
95 inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
96 {
97   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
98   cupmError_t              cerr;
99   PetscErrorCode           ierr;
100 
101   PetscFunctionBegin;
102   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
103   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);}
104   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
105   PetscFunctionReturn(0);
106 }
107 
108 template <CUPMDeviceKind T>
109 inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept
110 {
111   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
112 
113   PetscFunctionBegin;
114   if (dci->stream) {
115     cupmError_t cerr;
116 
117     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
118     dci->stream = nullptr;
119   }
120   // set these to null so they aren't usable until setup is called again
121   dci->blas   = nullptr;
122   dci->solver = nullptr;
123   PetscFunctionReturn(0);
124 }
125 
126 template <CUPMDeviceKind T>
127 inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
128 {
129   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
130   PetscErrorCode           ierr;
131   cupmError_t              cerr;
132 
133   PetscFunctionBegin;
134   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
135   switch (dctx->streamType) {
136   case PETSC_STREAM_GLOBAL_BLOCKING:
137     // don't create a stream for global blocking
138     dci->stream = nullptr;
139     break;
140   case PETSC_STREAM_DEFAULT_BLOCKING:
141     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
142     break;
143   case PETSC_STREAM_GLOBAL_NONBLOCKING:
144     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
145     break;
146   default:
147     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType);
148     break;
149   }
150   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
151   ierr = __setupHandles(dci);CHKERRQ(ierr);
152   PetscFunctionReturn(0);
153 }
154 
155 template <CUPMDeviceKind T>
156 inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
157 {
158   cupmError_t cerr;
159 
160   PetscFunctionBegin;
161   cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream);
162   if (cerr == cupmSuccess)
163     *idle = PETSC_TRUE;
164   else if (cerr == cupmErrorNotReady) {
165     *idle = PETSC_FALSE;
166   } else {
167     // somethings gone wrong
168     CHKERRCUPM(cerr);
169   }
170   PetscFunctionReturn(0);
171 }
172 
173 template <CUPMDeviceKind T>
174 inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
175 {
176   PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa);
177   PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb);
178   cupmError_t               cerr;
179 
180   PetscFunctionBegin;
181   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
182   cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
183   PetscFunctionReturn(0);
184 }
185 
186 template <CUPMDeviceKind T>
187 inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
188 {
189   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
190   cupmError_t               cerr;
191 
192   PetscFunctionBegin;
193   // in case anything was queued on the event
194   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
195   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
196   PetscFunctionReturn(0);
197 }
198 
199 // initialize the static member variables
200 template <CUPMDeviceKind T>
201 typename CUPMContext<T>::cupmBlasHandle_t   CUPMContext<T>::_blashandle   = nullptr;
202 
203 template <CUPMDeviceKind T>
204 typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr;
205 
206 // shorten this one up a bit
207 using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>;
208 using CUPMContextHip  = CUPMContext<CUPMDeviceKind::HIP>;
209 
210 // make sure these doesn't leak out
211 #undef CHKERRCUPM
212 #undef IMPLS_RCAST_
213 
214 } // namespace Petsc
215 
216 // shorthand for what is an EXTREMELY long name
217 #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS
218 
219 // shorthand for casting dctx->data to the appropriate object to access the handles
220 #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data)
221 
222 #endif /* PETSCDEVICECONTEXTCUDA_HPP */
223