xref: /petsc/src/sys/objects/device/impls/cupm/cupmcontext.hpp (revision 1b6e0089d1293b22d11fdb367ef334ac8b88849e)
1030f984aSJacob Faibussowitsch #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
2030f984aSJacob Faibussowitsch #define PETSCDEVICECONTEXTCUPM_HPP
3030f984aSJacob Faibussowitsch 
4030f984aSJacob Faibussowitsch #include <petsc/private/deviceimpl.h> /*I "petscdevice.h" I*/
5030f984aSJacob Faibussowitsch #include <petsc/private/cupminterface.hpp>
6030f984aSJacob Faibussowitsch 
7030f984aSJacob Faibussowitsch #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8030f984aSJacob Faibussowitsch #error PetscDeviceContext backends for CUDA and HIP requires C++11
9030f984aSJacob Faibussowitsch #endif
10030f984aSJacob Faibussowitsch 
11030f984aSJacob Faibussowitsch namespace Petsc {
12030f984aSJacob Faibussowitsch 
13030f984aSJacob Faibussowitsch // Forward declare
14030f984aSJacob Faibussowitsch template <CUPMDeviceKind T> class CUPMContext;
15030f984aSJacob Faibussowitsch 
16030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
17030f984aSJacob Faibussowitsch class CUPMContext : CUPMInterface<T>
18030f984aSJacob Faibussowitsch {
19030f984aSJacob Faibussowitsch public:
20*1b6e0089SPierre Jolivet   PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T)
21030f984aSJacob Faibussowitsch 
22030f984aSJacob Faibussowitsch   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
23030f984aSJacob Faibussowitsch   // header, but since we are using the power of templates it must be declared part of
24030f984aSJacob Faibussowitsch   // this class to have easy access the same typedefs. Technically one can make a
25030f984aSJacob Faibussowitsch   // templated struct outside the class but it's more code for the same result.
26030f984aSJacob Faibussowitsch   struct PetscDeviceContext_IMPLS
27030f984aSJacob Faibussowitsch   {
28030f984aSJacob Faibussowitsch     cupmStream_t       stream;
29030f984aSJacob Faibussowitsch     cupmEvent_t        event;
30030f984aSJacob Faibussowitsch     cupmBlasHandle_t   blas;
31030f984aSJacob Faibussowitsch     cupmSolverHandle_t solver;
32030f984aSJacob Faibussowitsch   };
33030f984aSJacob Faibussowitsch 
34030f984aSJacob Faibussowitsch private:
35030f984aSJacob Faibussowitsch   static cupmBlasHandle_t   _blashandle;
36030f984aSJacob Faibussowitsch   static cupmSolverHandle_t _solverhandle;
37030f984aSJacob Faibussowitsch 
38030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept
39030f984aSJacob Faibussowitsch   {
40030f984aSJacob Faibussowitsch     PetscErrorCode ierr;
41030f984aSJacob Faibussowitsch 
42030f984aSJacob Faibussowitsch     PetscFunctionBegin;
43030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::DestroyHandle(_blashandle);CHKERRQ(ierr);
44030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
45030f984aSJacob Faibussowitsch   }
46030f984aSJacob Faibussowitsch 
47030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept
48030f984aSJacob Faibussowitsch   {
49030f984aSJacob Faibussowitsch     PetscErrorCode ierr;
50030f984aSJacob Faibussowitsch 
51030f984aSJacob Faibussowitsch     PetscFunctionBegin;
52030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::DestroyHandle(_solverhandle);CHKERRQ(ierr);
53030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
54030f984aSJacob Faibussowitsch   }
55030f984aSJacob Faibussowitsch 
56030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept
57030f984aSJacob Faibussowitsch   {
58030f984aSJacob Faibussowitsch     PetscErrorCode  ierr;
59030f984aSJacob Faibussowitsch 
60030f984aSJacob Faibussowitsch     PetscFunctionBegin;
61030f984aSJacob Faibussowitsch     if (!_blashandle) {
62030f984aSJacob Faibussowitsch       ierr = cupmInterface_t::InitializeHandle(_blashandle);CHKERRQ(ierr);
63030f984aSJacob Faibussowitsch       ierr = PetscRegisterFinalize(__finalizeBLASHandle);CHKERRQ(ierr);
64030f984aSJacob Faibussowitsch     }
65030f984aSJacob Faibussowitsch     if (!_solverhandle) {
66030f984aSJacob Faibussowitsch       ierr = cupmInterface_t::InitializeHandle(_solverhandle);CHKERRQ(ierr);
67030f984aSJacob Faibussowitsch       ierr = PetscRegisterFinalize(__finalizeSOLVERHandle);CHKERRQ(ierr);
68030f984aSJacob Faibussowitsch     }
69030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::SetHandleStream(_blashandle,dci->stream);CHKERRQ(ierr);
70030f984aSJacob Faibussowitsch     ierr = cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);CHKERRQ(ierr);
71030f984aSJacob Faibussowitsch     dci->blas   = _blashandle;
72030f984aSJacob Faibussowitsch     dci->solver = _solverhandle;
73030f984aSJacob Faibussowitsch     PetscFunctionReturn(0);
74030f984aSJacob Faibussowitsch   }
75030f984aSJacob Faibussowitsch 
76030f984aSJacob Faibussowitsch public:
77030f984aSJacob Faibussowitsch   const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize};
78030f984aSJacob Faibussowitsch 
79030f984aSJacob Faibussowitsch   // default constructor
80030f984aSJacob Faibussowitsch   constexpr CUPMContext() noexcept = default;
81030f984aSJacob Faibussowitsch 
82030f984aSJacob Faibussowitsch   // All of these functions MUST be static in order to be callable from C, otherwise they
83030f984aSJacob Faibussowitsch   // get the implicit 'this' pointer tacked on
84030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
85030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
86030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
87030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
88030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
89030f984aSJacob Faibussowitsch   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
90030f984aSJacob Faibussowitsch };
91030f984aSJacob Faibussowitsch 
92030f984aSJacob Faibussowitsch #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data)
93030f984aSJacob Faibussowitsch 
94030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
95030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
96030f984aSJacob Faibussowitsch {
97030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
98030f984aSJacob Faibussowitsch   cupmError_t              cerr;
99030f984aSJacob Faibussowitsch   PetscErrorCode           ierr;
100030f984aSJacob Faibussowitsch 
101030f984aSJacob Faibussowitsch   PetscFunctionBegin;
102030f984aSJacob Faibussowitsch   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
103030f984aSJacob Faibussowitsch   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);}
104030f984aSJacob Faibussowitsch   ierr = PetscFree(dctx->data);CHKERRQ(ierr);
105030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
106030f984aSJacob Faibussowitsch }
107030f984aSJacob Faibussowitsch 
108030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
109030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept
110030f984aSJacob Faibussowitsch {
111030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
112030f984aSJacob Faibussowitsch 
113030f984aSJacob Faibussowitsch   PetscFunctionBegin;
114030f984aSJacob Faibussowitsch   if (dci->stream) {
115030f984aSJacob Faibussowitsch     cupmError_t cerr;
116030f984aSJacob Faibussowitsch 
117030f984aSJacob Faibussowitsch     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
118030f984aSJacob Faibussowitsch     dci->stream = nullptr;
119030f984aSJacob Faibussowitsch   }
120030f984aSJacob Faibussowitsch   // set these to null so they aren't usable until setup is called again
121030f984aSJacob Faibussowitsch   dci->blas   = nullptr;
122030f984aSJacob Faibussowitsch   dci->solver = nullptr;
123030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
124030f984aSJacob Faibussowitsch }
125030f984aSJacob Faibussowitsch 
126030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
127030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
128030f984aSJacob Faibussowitsch {
129030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
130030f984aSJacob Faibussowitsch   PetscErrorCode           ierr;
131030f984aSJacob Faibussowitsch   cupmError_t              cerr;
132030f984aSJacob Faibussowitsch 
133030f984aSJacob Faibussowitsch   PetscFunctionBegin;
134030f984aSJacob Faibussowitsch   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
135030f984aSJacob Faibussowitsch   switch (dctx->streamType) {
136030f984aSJacob Faibussowitsch   case PETSC_STREAM_GLOBAL_BLOCKING:
137030f984aSJacob Faibussowitsch     // don't create a stream for global blocking
138030f984aSJacob Faibussowitsch     dci->stream = nullptr;
139030f984aSJacob Faibussowitsch     break;
140030f984aSJacob Faibussowitsch   case PETSC_STREAM_DEFAULT_BLOCKING:
141030f984aSJacob Faibussowitsch     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
142030f984aSJacob Faibussowitsch     break;
143030f984aSJacob Faibussowitsch   case PETSC_STREAM_GLOBAL_NONBLOCKING:
144030f984aSJacob Faibussowitsch     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
145030f984aSJacob Faibussowitsch     break;
146030f984aSJacob Faibussowitsch   default:
147030f984aSJacob Faibussowitsch     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType);
148030f984aSJacob Faibussowitsch     break;
149030f984aSJacob Faibussowitsch   }
150030f984aSJacob Faibussowitsch   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
151030f984aSJacob Faibussowitsch   ierr = __setupHandles(dci);CHKERRQ(ierr);
152030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
153030f984aSJacob Faibussowitsch }
154030f984aSJacob Faibussowitsch 
155030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
156030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
157030f984aSJacob Faibussowitsch {
158030f984aSJacob Faibussowitsch   cupmError_t cerr;
159030f984aSJacob Faibussowitsch 
160030f984aSJacob Faibussowitsch   PetscFunctionBegin;
161030f984aSJacob Faibussowitsch   cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream);
162030f984aSJacob Faibussowitsch   if (cerr == cupmSuccess)
163030f984aSJacob Faibussowitsch     *idle = PETSC_TRUE;
164030f984aSJacob Faibussowitsch   else if (cerr == cupmErrorNotReady) {
165030f984aSJacob Faibussowitsch     *idle = PETSC_FALSE;
166030f984aSJacob Faibussowitsch   } else {
167030f984aSJacob Faibussowitsch     // somethings gone wrong
168030f984aSJacob Faibussowitsch     CHKERRCUPM(cerr);
169030f984aSJacob Faibussowitsch   }
170030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
171030f984aSJacob Faibussowitsch }
172030f984aSJacob Faibussowitsch 
173030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
174030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
175030f984aSJacob Faibussowitsch {
176030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa);
177030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb);
178030f984aSJacob Faibussowitsch   cupmError_t               cerr;
179030f984aSJacob Faibussowitsch 
180030f984aSJacob Faibussowitsch   PetscFunctionBegin;
181030f984aSJacob Faibussowitsch   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
182030f984aSJacob Faibussowitsch   cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
183030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
184030f984aSJacob Faibussowitsch }
185030f984aSJacob Faibussowitsch 
186030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
187030f984aSJacob Faibussowitsch inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
188030f984aSJacob Faibussowitsch {
189030f984aSJacob Faibussowitsch   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
190030f984aSJacob Faibussowitsch   cupmError_t               cerr;
191030f984aSJacob Faibussowitsch 
192030f984aSJacob Faibussowitsch   PetscFunctionBegin;
193030f984aSJacob Faibussowitsch   // in case anything was queued on the event
194030f984aSJacob Faibussowitsch   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
195030f984aSJacob Faibussowitsch   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
196030f984aSJacob Faibussowitsch   PetscFunctionReturn(0);
197030f984aSJacob Faibussowitsch }
198030f984aSJacob Faibussowitsch 
199030f984aSJacob Faibussowitsch // initialize the static member variables
200030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
201030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmBlasHandle_t   CUPMContext<T>::_blashandle   = nullptr;
202030f984aSJacob Faibussowitsch 
203030f984aSJacob Faibussowitsch template <CUPMDeviceKind T>
204030f984aSJacob Faibussowitsch typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr;
205030f984aSJacob Faibussowitsch 
206030f984aSJacob Faibussowitsch // shorten this one up a bit
207030f984aSJacob Faibussowitsch using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>;
208030f984aSJacob Faibussowitsch using CUPMContextHip  = CUPMContext<CUPMDeviceKind::HIP>;
209030f984aSJacob Faibussowitsch 
210030f984aSJacob Faibussowitsch // make sure these doesn't leak out
211030f984aSJacob Faibussowitsch #undef CHKERRCUPM
212030f984aSJacob Faibussowitsch #undef IMPLS_RCAST_
213030f984aSJacob Faibussowitsch 
214030f984aSJacob Faibussowitsch } // namespace Petsc
215030f984aSJacob Faibussowitsch 
216030f984aSJacob Faibussowitsch // shorthand for what is an EXTREMELY long name
217030f984aSJacob Faibussowitsch #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS
218030f984aSJacob Faibussowitsch 
219030f984aSJacob Faibussowitsch // shorthand for casting dctx->data to the appropriate object to access the handles
220030f984aSJacob Faibussowitsch #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data)
221030f984aSJacob Faibussowitsch 
222030f984aSJacob Faibussowitsch #endif /* PETSCDEVICECONTEXTCUDA_HPP */
223