xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 030f984af8d8bb4c203755d35bded3c05b3d83ce)
1*030f984aSJacob Faibussowitsch #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
2*030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP
3*030f984aSJacob Faibussowitsch 
4*030f984aSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> /*I "petscdevice.h" I*/
5*030f984aSJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
6*030f984aSJacob Faibussowitsch 
7*030f984aSJacob Faibussowitsch #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8*030f984aSJacob Faibussowitsch #error PetscDeviceContext backends for CUDA and HIP requires C++11
9*030f984aSJacob Faibussowitsch #endif
10*030f984aSJacob Faibussowitsch 
11*030f984aSJacob Faibussowitsch namespace Petsc {
12*030f984aSJacob Faibussowitsch 
13*030f984aSJacob Faibussowitsch // Forward declare
14*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> class CUPMContext;
15*030f984aSJacob Faibussowitsch 
16*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
17*030f984aSJacob Faibussowitsch class CUPMContext : CUPMInterface<T>
18*030f984aSJacob Faibussowitsch {
19*030f984aSJacob Faibussowitsch public:
20*030f984aSJacob Faibussowitsch   PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T);
21*030f984aSJacob Faibussowitsch 
22*030f984aSJacob Faibussowitsch   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
23*030f984aSJacob Faibussowitsch   // header, but since we are using the power of templates it must be declared part of
24*030f984aSJacob Faibussowitsch   // this class to have easy access the same typedefs. Technically one can make a
25*030f984aSJacob Faibussowitsch   // templated struct outside the class but it's more code for the same result.
26*030f984aSJacob Faibussowitsch   struct PetscDeviceContext_IMPLS
27*030f984aSJacob Faibussowitsch   {
28*030f984aSJacob Faibussowitsch     cupmStream_t       stream;
29*030f984aSJacob Faibussowitsch     cupmEvent_t        event;
30*030f984aSJacob Faibussowitsch     cupmBlasHandle_t   blas;
31*030f984aSJacob Faibussowitsch     cupmSolverHandle_t solver;
32*030f984aSJacob Faibussowitsch   };
33*030f984aSJacob Faibussowitsch 
34*030f984aSJacob Faibussowitsch private:
35*030f984aSJacob Faibussowitsch   static cupmBlasHandle_t   _blashandle;
36*030f984aSJacob Faibussowitsch   static cupmSolverHandle_t _solverhandle;
37*030f984aSJacob Faibussowitsch 
38*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept
39*030f984aSJacob Faibussowitsch   {
40*030f984aSJacob Faibussowitsch     PetscErrorCode ierr;
41*030f984aSJacob Faibussowitsch 
42*030f984aSJacob Faibussowitsch     PetscFunctionBegin;
43*030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr);
44*030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
45*030f984aSJacob Faibussowitsch   }
46*030f984aSJacob Faibussowitsch 
47*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept
48*030f984aSJacob Faibussowitsch   {
49*030f984aSJacob Faibussowitsch     PetscErrorCode ierr;
50*030f984aSJacob Faibussowitsch 
51*030f984aSJacob Faibussowitsch     PetscFunctionBegin;
52*030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr);
53*030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
54*030f984aSJacob Faibussowitsch   }
55*030f984aSJacob Faibussowitsch 
56*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept
57*030f984aSJacob Faibussowitsch   {
58*030f984aSJacob Faibussowitsch     PetscErrorCode  ierr;
59*030f984aSJacob Faibussowitsch 
60*030f984aSJacob Faibussowitsch     PetscFunctionBegin;
61*030f984aSJacob Faibussowitsch     if (!_blashandle) {
62*030f984aSJacob Faibussowitsch       ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr);
63*030f984aSJacob Faibussowitsch       ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr);
64*030f984aSJacob Faibussowitsch     }
65*030f984aSJacob Faibussowitsch     if (!_solverhandle) {
66*030f984aSJacob Faibussowitsch       ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr);
67*030f984aSJacob Faibussowitsch       ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr);
68*030f984aSJacob Faibussowitsch     }
69*030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr);
70*030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr);
71*030f984aSJacob Faibussowitsch     dci->blas   = _blashandle;
72*030f984aSJacob Faibussowitsch     dci->solver = _solverhandle;
73*030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
74*030f984aSJacob Faibussowitsch   }
75*030f984aSJacob Faibussowitsch 
76*030f984aSJacob Faibussowitsch public:
77*030f984aSJacob Faibussowitsch   const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize};
78*030f984aSJacob Faibussowitsch 
79*030f984aSJacob Faibussowitsch   // default constructor
80*030f984aSJacob Faibussowitsch   constexpr CUPMContext() noexcept = default;
81*030f984aSJacob Faibussowitsch 
82*030f984aSJacob Faibussowitsch   // All of these functions MUST be static in order to be callable from C, otherwise they
83*030f984aSJacob Faibussowitsch   // get the implicit 'this' pointer tacked on
84*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
85*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
86*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
87*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
88*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
89*030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
90*030f984aSJacob Faibussowitsch };
91*030f984aSJacob Faibussowitsch 
92*030f984aSJacob Faibussowitsch #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data)
93*030f984aSJacob Faibussowitsch 
94*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
95*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
96*030f984aSJacob Faibussowitsch {
97*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
98*030f984aSJacob Faibussowitsch   cupmError_t              cerr;
99*030f984aSJacob Faibussowitsch   PetscErrorCode           ierr;
100*030f984aSJacob Faibussowitsch 
101*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
102*030f984aSJacob Faibussowitsch   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
103*030f984aSJacob Faibussowitsch   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);}
104*030f984aSJacob Faibussowitsch   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
105*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
106*030f984aSJacob Faibussowitsch }
107*030f984aSJacob Faibussowitsch 
108*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
109*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept
110*030f984aSJacob Faibussowitsch {
111*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
112*030f984aSJacob Faibussowitsch 
113*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
114*030f984aSJacob Faibussowitsch   if (dci->stream) {
115*030f984aSJacob Faibussowitsch     cupmError_t cerr;
116*030f984aSJacob Faibussowitsch 
117*030f984aSJacob Faibussowitsch     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
118*030f984aSJacob Faibussowitsch     dci->stream = nullptr;
119*030f984aSJacob Faibussowitsch   }
120*030f984aSJacob Faibussowitsch   // set these to null so they aren't usable until setup is called again
121*030f984aSJacob Faibussowitsch   dci->blas   = nullptr;
122*030f984aSJacob Faibussowitsch   dci->solver = nullptr;
123*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
124*030f984aSJacob Faibussowitsch }
125*030f984aSJacob Faibussowitsch 
126*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
127*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
128*030f984aSJacob Faibussowitsch {
129*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
130*030f984aSJacob Faibussowitsch   PetscErrorCode           ierr;
131*030f984aSJacob Faibussowitsch   cupmError_t              cerr;
132*030f984aSJacob Faibussowitsch 
133*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
134*030f984aSJacob Faibussowitsch   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
135*030f984aSJacob Faibussowitsch   switch (dctx->streamType) {
136*030f984aSJacob Faibussowitsch   case PETSC_STREAM_GLOBAL_BLOCKING:
137*030f984aSJacob Faibussowitsch     // don't create a stream for global blocking
138*030f984aSJacob Faibussowitsch     dci->stream = nullptr;
139*030f984aSJacob Faibussowitsch     break;
140*030f984aSJacob Faibussowitsch   case PETSC_STREAM_DEFAULT_BLOCKING:
141*030f984aSJacob Faibussowitsch     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
142*030f984aSJacob Faibussowitsch     break;
143*030f984aSJacob Faibussowitsch   case PETSC_STREAM_GLOBAL_NONBLOCKING:
144*030f984aSJacob Faibussowitsch     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
145*030f984aSJacob Faibussowitsch     break;
146*030f984aSJacob Faibussowitsch   default:
147*030f984aSJacob Faibussowitsch     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType);
148*030f984aSJacob Faibussowitsch     break;
149*030f984aSJacob Faibussowitsch   }
150*030f984aSJacob Faibussowitsch   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
151*030f984aSJacob Faibussowitsch   ierr = __setupHandles(dci);CHKERRQ(ierr);
152*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
153*030f984aSJacob Faibussowitsch }
154*030f984aSJacob Faibussowitsch 
155*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
156*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
157*030f984aSJacob Faibussowitsch {
158*030f984aSJacob Faibussowitsch   cupmError_t cerr;
159*030f984aSJacob Faibussowitsch 
160*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
161*030f984aSJacob Faibussowitsch   cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream);
162*030f984aSJacob Faibussowitsch   if (cerr == cupmSuccess)
163*030f984aSJacob Faibussowitsch     *idle = PETSC_TRUE;
164*030f984aSJacob Faibussowitsch   else if (cerr == cupmErrorNotReady) {
165*030f984aSJacob Faibussowitsch     *idle = PETSC_FALSE;
166*030f984aSJacob Faibussowitsch   } else {
167*030f984aSJacob Faibussowitsch     // somethings gone wrong
168*030f984aSJacob Faibussowitsch     CHKERRCUPM(cerr);
169*030f984aSJacob Faibussowitsch   }
170*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
171*030f984aSJacob Faibussowitsch }
172*030f984aSJacob Faibussowitsch 
173*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
174*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
175*030f984aSJacob Faibussowitsch {
176*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa);
177*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb);
178*030f984aSJacob Faibussowitsch   cupmError_t               cerr;
179*030f984aSJacob Faibussowitsch 
180*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
181*030f984aSJacob Faibussowitsch   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
182*030f984aSJacob Faibussowitsch   cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
183*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
184*030f984aSJacob Faibussowitsch }
185*030f984aSJacob Faibussowitsch 
186*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
187*030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
188*030f984aSJacob Faibussowitsch {
189*030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
190*030f984aSJacob Faibussowitsch   cupmError_t               cerr;
191*030f984aSJacob Faibussowitsch 
192*030f984aSJacob Faibussowitsch   PetscFunctionBegin;
193*030f984aSJacob Faibussowitsch   // in case anything was queued on the event
194*030f984aSJacob Faibussowitsch   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
195*030f984aSJacob Faibussowitsch   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
196*030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
197*030f984aSJacob Faibussowitsch }
198*030f984aSJacob Faibussowitsch 
199*030f984aSJacob Faibussowitsch // initialize the static member variables
200*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
201*030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmBlasHandle_t   CUPMContext<T>::_blashandle   = nullptr;
202*030f984aSJacob Faibussowitsch 
203*030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
204*030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr;
205*030f984aSJacob Faibussowitsch 
206*030f984aSJacob Faibussowitsch // shorten this one up a bit
207*030f984aSJacob Faibussowitsch using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>;
208*030f984aSJacob Faibussowitsch using CUPMContextHip  = CUPMContext<CUPMDeviceKind::HIP>;
209*030f984aSJacob Faibussowitsch 
210*030f984aSJacob Faibussowitsch // make sure these doesn't leak out
211*030f984aSJacob Faibussowitsch #undef CHKERRCUPM
212*030f984aSJacob Faibussowitsch #undef IMPLS_RCAST_
213*030f984aSJacob Faibussowitsch 
214*030f984aSJacob Faibussowitsch } // namespace Petsc
215*030f984aSJacob Faibussowitsch 
216*030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name
217*030f984aSJacob Faibussowitsch #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS
218*030f984aSJacob Faibussowitsch 
219*030f984aSJacob Faibussowitsch // shorthand for casting dctx->data to the appropriate object to access the handles
220*030f984aSJacob Faibussowitsch #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data)
221*030f984aSJacob Faibussowitsch 
222*030f984aSJacob Faibussowitsch #endif /* PETSCDEVICECONTEXTCUDA_HPP */
223