xref: /honee/src/mat-ceed.c (revision 58600ac3b26c12c428243f1945a2fbe2d8f3ffe2)
1*58600ac3SJames Wright /// @file
/// MatCeed and its related operators
3*58600ac3SJames Wright 
4*58600ac3SJames Wright #include <ceed.h>
5*58600ac3SJames Wright #include <ceed/backend.h>
6*58600ac3SJames Wright #include <mat-ceed-impl.h>
7*58600ac3SJames Wright #include <mat-ceed.h>
8*58600ac3SJames Wright #include <petscdmplex.h>
9*58600ac3SJames Wright #include <stdlib.h>
10*58600ac3SJames Wright #include <string.h>
11*58600ac3SJames Wright 
// Class ID under which the MATCEED log events below are registered (set by MatCeedRegisterLogEvents())
PetscClassId  MATCEED_CLASSID;
// Log events registered as "MATCEED Mult" and "MATCEED Mult Transpose"
PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
14*58600ac3SJames Wright 
15*58600ac3SJames Wright /**
16*58600ac3SJames Wright   @brief Register MATCEED log events.
17*58600ac3SJames Wright 
18*58600ac3SJames Wright   Not collective across MPI processes.
19*58600ac3SJames Wright 
20*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
21*58600ac3SJames Wright **/
22*58600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() {
23*58600ac3SJames Wright   static bool registered = false;
24*58600ac3SJames Wright 
25*58600ac3SJames Wright   PetscFunctionBeginUser;
26*58600ac3SJames Wright   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
27*58600ac3SJames Wright   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
28*58600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
29*58600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
30*58600ac3SJames Wright   registered = true;
31*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
32*58600ac3SJames Wright }
33*58600ac3SJames Wright 
34*58600ac3SJames Wright /**
35*58600ac3SJames Wright   @brief Translate PetscMemType to CeedMemType
36*58600ac3SJames Wright 
37*58600ac3SJames Wright   @param[in]  mem_type  PetscMemType
38*58600ac3SJames Wright 
39*58600ac3SJames Wright   @return Equivalent CeedMemType
40*58600ac3SJames Wright **/
41*58600ac3SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
42*58600ac3SJames Wright 
43*58600ac3SJames Wright /**
44*58600ac3SJames Wright   @brief Translate array of `CeedInt` to `PetscInt`.
45*58600ac3SJames Wright     If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`.
46*58600ac3SJames Wright     Caller is responsible for freeing `array_petsc` with `free()`.
47*58600ac3SJames Wright 
48*58600ac3SJames Wright   Not collective across MPI processes.
49*58600ac3SJames Wright 
50*58600ac3SJames Wright   @param[in]      num_entries  Number of array entries
51*58600ac3SJames Wright   @param[in,out]  array_ceed   Array of `CeedInt`
52*58600ac3SJames Wright   @param[out]     array_petsc  Array of `PetscInt`
53*58600ac3SJames Wright 
54*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
55*58600ac3SJames Wright **/
56*58600ac3SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) {
57*58600ac3SJames Wright   const CeedInt  int_c = 0;
58*58600ac3SJames Wright   const PetscInt int_p = 0;
59*58600ac3SJames Wright 
60*58600ac3SJames Wright   PetscFunctionBeginUser;
61*58600ac3SJames Wright   if (sizeof(int_c) == sizeof(int_p)) {
62*58600ac3SJames Wright     *array_petsc = (PetscInt *)*array_ceed;
63*58600ac3SJames Wright   } else {
64*58600ac3SJames Wright     *array_petsc = malloc(num_entries * sizeof(PetscInt));
65*58600ac3SJames Wright     for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i];
66*58600ac3SJames Wright     free(*array_ceed);
67*58600ac3SJames Wright   }
68*58600ac3SJames Wright   *array_ceed = NULL;
69*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
70*58600ac3SJames Wright }
71*58600ac3SJames Wright 
72*58600ac3SJames Wright /**
73*58600ac3SJames Wright   @brief Transfer array from PETSc `Vec` to `CeedVector`.
74*58600ac3SJames Wright 
75*58600ac3SJames Wright   Collective across MPI processes.
76*58600ac3SJames Wright 
77*58600ac3SJames Wright   @param[in]   ceed      libCEED context
78*58600ac3SJames Wright   @param[in]   X_petsc   PETSc `Vec`
79*58600ac3SJames Wright   @param[out]  mem_type  PETSc `MemType`
80*58600ac3SJames Wright   @param[out]  x_ceed    `CeedVector`
81*58600ac3SJames Wright 
82*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
83*58600ac3SJames Wright **/
84*58600ac3SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
85*58600ac3SJames Wright   PetscScalar *x;
86*58600ac3SJames Wright 
87*58600ac3SJames Wright   PetscFunctionBeginUser;
88*58600ac3SJames Wright   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
89*58600ac3SJames Wright   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
90*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
91*58600ac3SJames Wright }
92*58600ac3SJames Wright 
93*58600ac3SJames Wright /**
94*58600ac3SJames Wright   @brief Transfer array from `CeedVector` to PETSc `Vec`.
95*58600ac3SJames Wright 
96*58600ac3SJames Wright   Collective across MPI processes.
97*58600ac3SJames Wright 
98*58600ac3SJames Wright   @param[in]   ceed      libCEED context
99*58600ac3SJames Wright   @param[in]   x_ceed    `CeedVector`
100*58600ac3SJames Wright   @param[in]   mem_type  PETSc `MemType`
101*58600ac3SJames Wright   @param[out]  X_petsc   PETSc `Vec`
102*58600ac3SJames Wright 
103*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
104*58600ac3SJames Wright **/
105*58600ac3SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
106*58600ac3SJames Wright   PetscScalar *x;
107*58600ac3SJames Wright 
108*58600ac3SJames Wright   PetscFunctionBeginUser;
109*58600ac3SJames Wright   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
110*58600ac3SJames Wright   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
111*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
112*58600ac3SJames Wright }
113*58600ac3SJames Wright 
114*58600ac3SJames Wright /**
115*58600ac3SJames Wright   @brief Transfer read only array from PETSc `Vec` to `CeedVector`.
116*58600ac3SJames Wright 
117*58600ac3SJames Wright   Collective across MPI processes.
118*58600ac3SJames Wright 
119*58600ac3SJames Wright   @param[in]   ceed      libCEED context
120*58600ac3SJames Wright   @param[in]   X_petsc   PETSc `Vec`
121*58600ac3SJames Wright   @param[out]  mem_type  PETSc `MemType`
122*58600ac3SJames Wright   @param[out]  x_ceed    `CeedVector`
123*58600ac3SJames Wright 
124*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
125*58600ac3SJames Wright **/
126*58600ac3SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
127*58600ac3SJames Wright   PetscScalar *x;
128*58600ac3SJames Wright 
129*58600ac3SJames Wright   PetscFunctionBeginUser;
130*58600ac3SJames Wright   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
131*58600ac3SJames Wright   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
132*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
133*58600ac3SJames Wright }
134*58600ac3SJames Wright 
135*58600ac3SJames Wright /**
136*58600ac3SJames Wright   @brief Transfer read only array from `CeedVector` to PETSc `Vec`.
137*58600ac3SJames Wright 
138*58600ac3SJames Wright   Collective across MPI processes.
139*58600ac3SJames Wright 
140*58600ac3SJames Wright   @param[in]   ceed      libCEED context
141*58600ac3SJames Wright   @param[in]   x_ceed    `CeedVector`
142*58600ac3SJames Wright   @param[in]   mem_type  PETSc `MemType`
143*58600ac3SJames Wright   @param[out]  X_petsc   PETSc `Vec`
144*58600ac3SJames Wright 
145*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
146*58600ac3SJames Wright **/
147*58600ac3SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
148*58600ac3SJames Wright   PetscScalar *x;
149*58600ac3SJames Wright 
150*58600ac3SJames Wright   PetscFunctionBeginUser;
151*58600ac3SJames Wright   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
152*58600ac3SJames Wright   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
153*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
154*58600ac3SJames Wright }
155*58600ac3SJames Wright 
156*58600ac3SJames Wright /**
157*58600ac3SJames Wright   @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.
158*58600ac3SJames Wright 
159*58600ac3SJames Wright   Collective across MPI processes.
160*58600ac3SJames Wright 
161*58600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to setup
162*58600ac3SJames Wright   @param[out]  mat_inner  Inner `Mat`
163*58600ac3SJames Wright 
164*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
165*58600ac3SJames Wright **/
166*58600ac3SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
167*58600ac3SJames Wright   MatCeedContext ctx;
168*58600ac3SJames Wright 
169*58600ac3SJames Wright   PetscFunctionBeginUser;
170*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
171*58600ac3SJames Wright 
172*58600ac3SJames Wright   PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");
173*58600ac3SJames Wright 
174*58600ac3SJames Wright   // Check cl mat type
175*58600ac3SJames Wright   {
176*58600ac3SJames Wright     PetscBool is_internal_mat_type_cl = PETSC_FALSE;
177*58600ac3SJames Wright     char      internal_mat_type_cl[64];
178*58600ac3SJames Wright 
179*58600ac3SJames Wright     // Check for specific CL inner mat type for this Mat
180*58600ac3SJames Wright     {
181*58600ac3SJames Wright       const char *mat_ceed_prefix = NULL;
182*58600ac3SJames Wright 
183*58600ac3SJames Wright       PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
184*58600ac3SJames Wright       PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
185*58600ac3SJames Wright       PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
186*58600ac3SJames Wright                                   internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
187*58600ac3SJames Wright       PetscOptionsEnd();
188*58600ac3SJames Wright       if (is_internal_mat_type_cl) {
189*58600ac3SJames Wright         PetscCall(PetscFree(ctx->internal_mat_type));
190*58600ac3SJames Wright         PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
191*58600ac3SJames Wright       }
192*58600ac3SJames Wright     }
193*58600ac3SJames Wright   }
194*58600ac3SJames Wright 
195*58600ac3SJames Wright   // Create sparse matrix
196*58600ac3SJames Wright   {
197*58600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
198*58600ac3SJames Wright 
199*58600ac3SJames Wright     PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
200*58600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
201*58600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
202*58600ac3SJames Wright     PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
203*58600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
204*58600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
205*58600ac3SJames Wright   }
206*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
207*58600ac3SJames Wright }
208*58600ac3SJames Wright 
209*58600ac3SJames Wright /**
210*58600ac3SJames Wright   @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
211*58600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
212*58600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
213*58600ac3SJames Wright 
214*58600ac3SJames Wright   Collective across MPI processes.
215*58600ac3SJames Wright 
216*58600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
217*58600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
218*58600ac3SJames Wright 
219*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
220*58600ac3SJames Wright **/
221*58600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
222*58600ac3SJames Wright   MatCeedContext ctx;
223*58600ac3SJames Wright 
224*58600ac3SJames Wright   PetscFunctionBeginUser;
225*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
226*58600ac3SJames Wright 
227*58600ac3SJames Wright   // Check if COO pattern set
228*58600ac3SJames Wright   {
229*58600ac3SJames Wright     PetscInt index = -1;
230*58600ac3SJames Wright 
231*58600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
232*58600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
233*58600ac3SJames Wright     }
234*58600ac3SJames Wright     if (index == -1) {
235*58600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
236*58600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
237*58600ac3SJames Wright       PetscCount    num_entries;
238*58600ac3SJames Wright       PetscLogStage stage_amg_setup;
239*58600ac3SJames Wright 
240*58600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
241*58600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
242*58600ac3SJames Wright       if (stage_amg_setup == -1) {
243*58600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
244*58600ac3SJames Wright       }
245*58600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
246*58600ac3SJames Wright       PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
247*58600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
248*58600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
249*58600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
250*58600ac3SJames Wright       free(rows_petsc);
251*58600ac3SJames Wright       free(cols_petsc);
252*58600ac3SJames Wright       if (!ctx->coo_values_pbd) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
253*58600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
254*58600ac3SJames Wright       ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
255*58600ac3SJames Wright       PetscCall(PetscLogStagePop());
256*58600ac3SJames Wright     }
257*58600ac3SJames Wright   }
258*58600ac3SJames Wright 
259*58600ac3SJames Wright   // Assemble mat_ceed
260*58600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
261*58600ac3SJames Wright   {
262*58600ac3SJames Wright     const CeedScalar *values;
263*58600ac3SJames Wright     MatType           mat_type;
264*58600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
265*58600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
266*58600ac3SJames Wright 
267*58600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
268*58600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
269*58600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
270*58600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
271*58600ac3SJames Wright 
272*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
273*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
274*58600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
275*58600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
276*58600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
277*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
278*58600ac3SJames Wright   }
279*58600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
280*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
281*58600ac3SJames Wright }
282*58600ac3SJames Wright 
283*58600ac3SJames Wright /**
284*58600ac3SJames Wright   @brief Assemble inner `Mat` for diagonal `PC` operations
285*58600ac3SJames Wright 
286*58600ac3SJames Wright   Collective across MPI processes.
287*58600ac3SJames Wright 
288*58600ac3SJames Wright   @param[in]   mat_ceed      `MATCEED` to invert
289*58600ac3SJames Wright   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
290*58600ac3SJames Wright   @param[out]  mat_inner     Inner `Mat` for diagonal operations
291*58600ac3SJames Wright 
292*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
293*58600ac3SJames Wright **/
294*58600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
295*58600ac3SJames Wright   MatCeedContext ctx;
296*58600ac3SJames Wright 
297*58600ac3SJames Wright   PetscFunctionBeginUser;
298*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
299*58600ac3SJames Wright   if (use_ceed_pbd) {
300*58600ac3SJames Wright     // Check if COO pattern set
301*58600ac3SJames Wright     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
302*58600ac3SJames Wright 
303*58600ac3SJames Wright     // Assemble mat_assembled_full_internal
304*58600ac3SJames Wright     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
305*58600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
306*58600ac3SJames Wright   } else {
307*58600ac3SJames Wright     // Check if COO pattern set
308*58600ac3SJames Wright     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
309*58600ac3SJames Wright 
310*58600ac3SJames Wright     // Assemble mat_assembled_full_internal
311*58600ac3SJames Wright     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
312*58600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
313*58600ac3SJames Wright   }
314*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
315*58600ac3SJames Wright }
316*58600ac3SJames Wright 
317*58600ac3SJames Wright /**
318*58600ac3SJames Wright   @brief Get `MATCEED` diagonal block for Jacobi.
319*58600ac3SJames Wright 
320*58600ac3SJames Wright   Collective across MPI processes.
321*58600ac3SJames Wright 
322*58600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to invert
323*58600ac3SJames Wright   @param[out]  mat_block  The diagonal block matrix
324*58600ac3SJames Wright 
325*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
326*58600ac3SJames Wright **/
327*58600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
328*58600ac3SJames Wright   Mat            mat_inner = NULL;
329*58600ac3SJames Wright   MatCeedContext ctx;
330*58600ac3SJames Wright 
331*58600ac3SJames Wright   PetscFunctionBeginUser;
332*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
333*58600ac3SJames Wright 
334*58600ac3SJames Wright   // Assemble inner mat if needed
335*58600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
336*58600ac3SJames Wright 
337*58600ac3SJames Wright   // Get block diagonal
338*58600ac3SJames Wright   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
339*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
340*58600ac3SJames Wright }
341*58600ac3SJames Wright 
342*58600ac3SJames Wright /**
343*58600ac3SJames Wright   @brief Invert `MATCEED` diagonal block for Jacobi.
344*58600ac3SJames Wright 
345*58600ac3SJames Wright   Collective across MPI processes.
346*58600ac3SJames Wright 
347*58600ac3SJames Wright   @param[in]   mat_ceed  `MATCEED` to invert
348*58600ac3SJames Wright   @param[out]  values    The block inverses in column major order
349*58600ac3SJames Wright 
350*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
351*58600ac3SJames Wright **/
352*58600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
353*58600ac3SJames Wright   Mat            mat_inner = NULL;
354*58600ac3SJames Wright   MatCeedContext ctx;
355*58600ac3SJames Wright 
356*58600ac3SJames Wright   PetscFunctionBeginUser;
357*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
358*58600ac3SJames Wright 
359*58600ac3SJames Wright   // Assemble inner mat if needed
360*58600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
361*58600ac3SJames Wright 
362*58600ac3SJames Wright   // Invert PB diagonal
363*58600ac3SJames Wright   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
364*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
365*58600ac3SJames Wright }
366*58600ac3SJames Wright 
367*58600ac3SJames Wright /**
368*58600ac3SJames Wright   @brief Invert `MATCEED` variable diagonal block for Jacobi.
369*58600ac3SJames Wright 
370*58600ac3SJames Wright   Collective across MPI processes.
371*58600ac3SJames Wright 
372*58600ac3SJames Wright   @param[in]   mat_ceed     `MATCEED` to invert
373*58600ac3SJames Wright   @param[in]   num_blocks   The number of blocks on the process
374*58600ac3SJames Wright   @param[in]   block_sizes  The size of each block on the process
375*58600ac3SJames Wright   @param[out]  values       The block inverses in column major order
376*58600ac3SJames Wright 
377*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
378*58600ac3SJames Wright **/
379*58600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
380*58600ac3SJames Wright   Mat            mat_inner = NULL;
381*58600ac3SJames Wright   MatCeedContext ctx;
382*58600ac3SJames Wright 
383*58600ac3SJames Wright   PetscFunctionBeginUser;
384*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
385*58600ac3SJames Wright 
386*58600ac3SJames Wright   // Assemble inner mat if needed
387*58600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
388*58600ac3SJames Wright 
389*58600ac3SJames Wright   // Invert PB diagonal
390*58600ac3SJames Wright   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
391*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
392*58600ac3SJames Wright }
393*58600ac3SJames Wright 
394*58600ac3SJames Wright // -----------------------------------------------------------------------------
395*58600ac3SJames Wright // MatCeed
396*58600ac3SJames Wright // -----------------------------------------------------------------------------
397*58600ac3SJames Wright 
398*58600ac3SJames Wright /**
399*58600ac3SJames Wright   @brief Create PETSc `Mat` from libCEED operators.
400*58600ac3SJames Wright 
401*58600ac3SJames Wright   Collective across MPI processes.
402*58600ac3SJames Wright 
403*58600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
404*58600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
405*58600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
406*58600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
407*58600ac3SJames Wright   @param[out]  mat                        New MatCeed
408*58600ac3SJames Wright 
409*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
410*58600ac3SJames Wright **/
411*58600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
412*58600ac3SJames Wright   PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
413*58600ac3SJames Wright   VecType        vec_type;
414*58600ac3SJames Wright   MatCeedContext ctx;
415*58600ac3SJames Wright 
416*58600ac3SJames Wright   PetscFunctionBeginUser;
417*58600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
418*58600ac3SJames Wright 
419*58600ac3SJames Wright   // Collect context data
420*58600ac3SJames Wright   PetscCall(DMGetVecType(dm_x, &vec_type));
421*58600ac3SJames Wright   {
422*58600ac3SJames Wright     Vec X;
423*58600ac3SJames Wright 
424*58600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_x, &X));
425*58600ac3SJames Wright     PetscCall(VecGetSize(X, &X_g_size));
426*58600ac3SJames Wright     PetscCall(VecGetLocalSize(X, &X_l_size));
427*58600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_x, &X));
428*58600ac3SJames Wright   }
429*58600ac3SJames Wright   if (dm_y) {
430*58600ac3SJames Wright     Vec Y;
431*58600ac3SJames Wright 
432*58600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_y, &Y));
433*58600ac3SJames Wright     PetscCall(VecGetSize(Y, &Y_g_size));
434*58600ac3SJames Wright     PetscCall(VecGetLocalSize(Y, &Y_l_size));
435*58600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_y, &Y));
436*58600ac3SJames Wright   } else {
437*58600ac3SJames Wright     dm_y     = dm_x;
438*58600ac3SJames Wright     Y_g_size = X_g_size;
439*58600ac3SJames Wright     Y_l_size = X_l_size;
440*58600ac3SJames Wright   }
441*58600ac3SJames Wright   // Create context
442*58600ac3SJames Wright   {
443*58600ac3SJames Wright     Vec X_loc, Y_loc_transpose = NULL;
444*58600ac3SJames Wright 
445*58600ac3SJames Wright     PetscCall(DMCreateLocalVector(dm_x, &X_loc));
446*58600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
447*58600ac3SJames Wright     if (op_mult_transpose) {
448*58600ac3SJames Wright       PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
449*58600ac3SJames Wright       PetscCall(VecZeroEntries(Y_loc_transpose));
450*58600ac3SJames Wright     }
451*58600ac3SJames Wright     PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
452*58600ac3SJames Wright     PetscCall(VecDestroy(&X_loc));
453*58600ac3SJames Wright     PetscCall(VecDestroy(&Y_loc_transpose));
454*58600ac3SJames Wright   }
455*58600ac3SJames Wright 
456*58600ac3SJames Wright   // Create mat
457*58600ac3SJames Wright   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
458*58600ac3SJames Wright   PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
459*58600ac3SJames Wright   // -- Set block and variable block sizes
460*58600ac3SJames Wright   if (dm_x == dm_y) {
461*58600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
462*58600ac3SJames Wright     Mat     temp_mat;
463*58600ac3SJames Wright 
464*58600ac3SJames Wright     PetscCall(DMGetMatType(dm_x, &dm_mat_type));
465*58600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
466*58600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, MATAIJ));
467*58600ac3SJames Wright     PetscCall(DMCreateMatrix(dm_x, &temp_mat));
468*58600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
469*58600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
470*58600ac3SJames Wright 
471*58600ac3SJames Wright     {
472*58600ac3SJames Wright       PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
473*58600ac3SJames Wright       const PetscInt *vblock_sizes;
474*58600ac3SJames Wright 
475*58600ac3SJames Wright       // -- Get block sizes
476*58600ac3SJames Wright       PetscCall(MatGetBlockSize(temp_mat, &block_size));
477*58600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
478*58600ac3SJames Wright       {
479*58600ac3SJames Wright         PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};
480*58600ac3SJames Wright 
481*58600ac3SJames Wright         for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
482*58600ac3SJames Wright         PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
483*58600ac3SJames Wright         max_vblock_size = global_min_max[1];
484*58600ac3SJames Wright       }
485*58600ac3SJames Wright 
486*58600ac3SJames Wright       // -- Copy block sizes
487*58600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
488*58600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));
489*58600ac3SJames Wright 
490*58600ac3SJames Wright       // -- Check libCEED compatibility
491*58600ac3SJames Wright       {
492*58600ac3SJames Wright         bool is_composite;
493*58600ac3SJames Wright 
494*58600ac3SJames Wright         ctx->is_ceed_pbd_valid  = PETSC_TRUE;
495*58600ac3SJames Wright         ctx->is_ceed_vpbd_valid = PETSC_TRUE;
496*58600ac3SJames Wright         PetscCeedCall(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
497*58600ac3SJames Wright         if (is_composite) {
498*58600ac3SJames Wright           CeedInt       num_sub_operators;
499*58600ac3SJames Wright           CeedOperator *sub_operators;
500*58600ac3SJames Wright 
501*58600ac3SJames Wright           PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
502*58600ac3SJames Wright           PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
503*58600ac3SJames Wright           for (CeedInt i = 0; i < num_sub_operators; i++) {
504*58600ac3SJames Wright             CeedInt                  num_bases, num_comp;
505*58600ac3SJames Wright             CeedBasis               *active_bases;
506*58600ac3SJames Wright             CeedOperatorAssemblyData assembly_data;
507*58600ac3SJames Wright 
508*58600ac3SJames Wright             PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
509*58600ac3SJames Wright             PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
510*58600ac3SJames Wright             PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
511*58600ac3SJames Wright             if (num_bases > 1) {
512*58600ac3SJames Wright               ctx->is_ceed_pbd_valid  = PETSC_FALSE;
513*58600ac3SJames Wright               ctx->is_ceed_vpbd_valid = PETSC_FALSE;
514*58600ac3SJames Wright             }
515*58600ac3SJames Wright             if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
516*58600ac3SJames Wright             if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
517*58600ac3SJames Wright           }
518*58600ac3SJames Wright         } else {
519*58600ac3SJames Wright           // LCOV_EXCL_START
520*58600ac3SJames Wright           CeedInt                  num_bases, num_comp;
521*58600ac3SJames Wright           CeedBasis               *active_bases;
522*58600ac3SJames Wright           CeedOperatorAssemblyData assembly_data;
523*58600ac3SJames Wright 
524*58600ac3SJames Wright           PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
525*58600ac3SJames Wright           PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
526*58600ac3SJames Wright           PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
527*58600ac3SJames Wright           if (num_bases > 1) {
528*58600ac3SJames Wright             ctx->is_ceed_pbd_valid  = PETSC_FALSE;
529*58600ac3SJames Wright             ctx->is_ceed_vpbd_valid = PETSC_FALSE;
530*58600ac3SJames Wright           }
531*58600ac3SJames Wright           if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
532*58600ac3SJames Wright           if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
533*58600ac3SJames Wright           // LCOV_EXCL_STOP
534*58600ac3SJames Wright         }
535*58600ac3SJames Wright         {
536*58600ac3SJames Wright           PetscInt local_is_valid[2], global_is_valid[2];
537*58600ac3SJames Wright 
538*58600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
539*58600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
540*58600ac3SJames Wright           ctx->is_ceed_pbd_valid = global_is_valid[0];
541*58600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
542*58600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
543*58600ac3SJames Wright           ctx->is_ceed_vpbd_valid = global_is_valid[0];
544*58600ac3SJames Wright         }
545*58600ac3SJames Wright       }
546*58600ac3SJames Wright     }
547*58600ac3SJames Wright     PetscCall(MatDestroy(&temp_mat));
548*58600ac3SJames Wright   }
549*58600ac3SJames Wright   // -- Set internal mat type
550*58600ac3SJames Wright   {
551*58600ac3SJames Wright     VecType vec_type;
552*58600ac3SJames Wright     MatType internal_mat_type = MATAIJ;
553*58600ac3SJames Wright 
554*58600ac3SJames Wright     PetscCall(VecGetType(ctx->X_loc, &vec_type));
555*58600ac3SJames Wright     if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
556*58600ac3SJames Wright     else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
557*58600ac3SJames Wright     else internal_mat_type = MATAIJ;
558*58600ac3SJames Wright     PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
559*58600ac3SJames Wright   }
560*58600ac3SJames Wright   // -- Set mat operations
561*58600ac3SJames Wright   PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
562*58600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
563*58600ac3SJames Wright   if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
564*58600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
565*58600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
566*58600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
567*58600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
568*58600ac3SJames Wright   PetscCall(MatShellSetVecType(*mat, vec_type));
569*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
570*58600ac3SJames Wright }
571*58600ac3SJames Wright 
572*58600ac3SJames Wright /**
573*58600ac3SJames Wright   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
574*58600ac3SJames Wright 
575*58600ac3SJames Wright   Collective across MPI processes.
576*58600ac3SJames Wright 
577*58600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to copy from
578*58600ac3SJames Wright   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
579*58600ac3SJames Wright 
580*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
581*58600ac3SJames Wright **/
582*58600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
583*58600ac3SJames Wright   PetscFunctionBeginUser;
584*58600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
585*58600ac3SJames Wright 
586*58600ac3SJames Wright   // Check type compatibility
587*58600ac3SJames Wright   {
588*58600ac3SJames Wright     MatType mat_type_ceed, mat_type_other;
589*58600ac3SJames Wright 
590*58600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
591*58600ac3SJames Wright     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
592*58600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_other));
593*58600ac3SJames Wright     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
594*58600ac3SJames Wright                "mat_other must have type " MATCEED " or " MATSHELL);
595*58600ac3SJames Wright   }
596*58600ac3SJames Wright 
597*58600ac3SJames Wright   // Check dimension compatibility
598*58600ac3SJames Wright   {
599*58600ac3SJames Wright     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
600*58600ac3SJames Wright 
601*58600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
602*58600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
603*58600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
604*58600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
605*58600ac3SJames Wright     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
606*58600ac3SJames Wright                    (X_l_ceed_size == X_l_other_size),
607*58600ac3SJames Wright                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
608*58600ac3SJames Wright                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
609*58600ac3SJames Wright                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
610*58600ac3SJames Wright                ", %" PetscInt_FMT ")",
611*58600ac3SJames Wright                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
612*58600ac3SJames Wright   }
613*58600ac3SJames Wright 
614*58600ac3SJames Wright   // Convert
615*58600ac3SJames Wright   {
616*58600ac3SJames Wright     VecType        vec_type;
617*58600ac3SJames Wright     MatCeedContext ctx;
618*58600ac3SJames Wright 
619*58600ac3SJames Wright     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
620*58600ac3SJames Wright     PetscCall(MatShellGetContext(mat_ceed, &ctx));
621*58600ac3SJames Wright     PetscCall(MatCeedContextReference(ctx));
622*58600ac3SJames Wright     PetscCall(MatShellSetContext(mat_other, ctx));
623*58600ac3SJames Wright     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
624*58600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
625*58600ac3SJames Wright     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
626*58600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
627*58600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
628*58600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
629*58600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
630*58600ac3SJames Wright     {
631*58600ac3SJames Wright       PetscInt block_size;
632*58600ac3SJames Wright 
633*58600ac3SJames Wright       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
634*58600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
635*58600ac3SJames Wright     }
636*58600ac3SJames Wright     {
637*58600ac3SJames Wright       PetscInt        num_blocks;
638*58600ac3SJames Wright       const PetscInt *block_sizes;
639*58600ac3SJames Wright 
640*58600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
641*58600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
642*58600ac3SJames Wright     }
643*58600ac3SJames Wright     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
644*58600ac3SJames Wright     PetscCall(MatShellSetVecType(mat_other, vec_type));
645*58600ac3SJames Wright   }
646*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
647*58600ac3SJames Wright }
648*58600ac3SJames Wright 
649*58600ac3SJames Wright /**
650*58600ac3SJames Wright   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
651*58600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
652*58600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
653*58600ac3SJames Wright 
654*58600ac3SJames Wright   Collective across MPI processes.
655*58600ac3SJames Wright 
656*58600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
657*58600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
658*58600ac3SJames Wright 
659*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
660*58600ac3SJames Wright **/
661*58600ac3SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
662*58600ac3SJames Wright   MatCeedContext ctx;
663*58600ac3SJames Wright 
664*58600ac3SJames Wright   PetscFunctionBeginUser;
665*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
666*58600ac3SJames Wright 
667*58600ac3SJames Wright   // Check if COO pattern set
668*58600ac3SJames Wright   {
669*58600ac3SJames Wright     PetscInt index = -1;
670*58600ac3SJames Wright 
671*58600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
672*58600ac3SJames Wright       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
673*58600ac3SJames Wright     }
674*58600ac3SJames Wright     if (index == -1) {
675*58600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
676*58600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
677*58600ac3SJames Wright       PetscCount    num_entries;
678*58600ac3SJames Wright       PetscLogStage stage_amg_setup;
679*58600ac3SJames Wright 
680*58600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
681*58600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
682*58600ac3SJames Wright       if (stage_amg_setup == -1) {
683*58600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
684*58600ac3SJames Wright       }
685*58600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
686*58600ac3SJames Wright       PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
687*58600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
688*58600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
689*58600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
690*58600ac3SJames Wright       free(rows_petsc);
691*58600ac3SJames Wright       free(cols_petsc);
692*58600ac3SJames Wright       if (!ctx->coo_values_full) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
693*58600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
694*58600ac3SJames Wright       ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
695*58600ac3SJames Wright       PetscCall(PetscLogStagePop());
696*58600ac3SJames Wright     }
697*58600ac3SJames Wright   }
698*58600ac3SJames Wright 
699*58600ac3SJames Wright   // Assemble mat_ceed
700*58600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
701*58600ac3SJames Wright   {
702*58600ac3SJames Wright     const CeedScalar *values;
703*58600ac3SJames Wright     MatType           mat_type;
704*58600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
705*58600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
706*58600ac3SJames Wright 
707*58600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
708*58600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
709*58600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
710*58600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
711*58600ac3SJames Wright 
712*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
713*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
714*58600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
715*58600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
716*58600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
717*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
718*58600ac3SJames Wright   }
719*58600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
720*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
721*58600ac3SJames Wright }
722*58600ac3SJames Wright 
723*58600ac3SJames Wright /**
724*58600ac3SJames Wright   @brief Set user context for a `MATCEED`.
725*58600ac3SJames Wright 
726*58600ac3SJames Wright   Collective across MPI processes.
727*58600ac3SJames Wright 
728*58600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
729*58600ac3SJames Wright   @param[in]      f    The context destroy function, or NULL
730*58600ac3SJames Wright   @param[in]      ctx  User context, or NULL to unset
731*58600ac3SJames Wright 
732*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
733*58600ac3SJames Wright **/
734*58600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
735*58600ac3SJames Wright   PetscContainer user_ctx = NULL;
736*58600ac3SJames Wright 
737*58600ac3SJames Wright   PetscFunctionBeginUser;
738*58600ac3SJames Wright   if (ctx) {
739*58600ac3SJames Wright     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
740*58600ac3SJames Wright     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
741*58600ac3SJames Wright     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
742*58600ac3SJames Wright   }
743*58600ac3SJames Wright   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
744*58600ac3SJames Wright   PetscCall(PetscContainerDestroy(&user_ctx));
745*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
746*58600ac3SJames Wright }
747*58600ac3SJames Wright 
748*58600ac3SJames Wright /**
749*58600ac3SJames Wright   @brief Retrieve the user context for a `MATCEED`.
750*58600ac3SJames Wright 
751*58600ac3SJames Wright   Collective across MPI processes.
752*58600ac3SJames Wright 
753*58600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
754*58600ac3SJames Wright   @param[in]      ctx  User context
755*58600ac3SJames Wright 
756*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
757*58600ac3SJames Wright **/
758*58600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
759*58600ac3SJames Wright   PetscContainer user_ctx;
760*58600ac3SJames Wright 
761*58600ac3SJames Wright   PetscFunctionBeginUser;
762*58600ac3SJames Wright   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
763*58600ac3SJames Wright   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
764*58600ac3SJames Wright   else *(void **)ctx = NULL;
765*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
766*58600ac3SJames Wright }
767*58600ac3SJames Wright 
768*58600ac3SJames Wright /**
769*58600ac3SJames Wright   @brief Sets the inner matrix type as a string from the `MATCEED`.
770*58600ac3SJames Wright 
771*58600ac3SJames Wright   Collective across MPI processes.
772*58600ac3SJames Wright 
773*58600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
774*58600ac3SJames Wright   @param[in]      type  Inner `MatType` to set
775*58600ac3SJames Wright 
776*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
777*58600ac3SJames Wright **/
778*58600ac3SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
779*58600ac3SJames Wright   MatCeedContext ctx;
780*58600ac3SJames Wright 
781*58600ac3SJames Wright   PetscFunctionBeginUser;
782*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
783*58600ac3SJames Wright   // Check if same
784*58600ac3SJames Wright   {
785*58600ac3SJames Wright     size_t    len_old, len_new;
786*58600ac3SJames Wright     PetscBool is_same = PETSC_FALSE;
787*58600ac3SJames Wright 
788*58600ac3SJames Wright     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
789*58600ac3SJames Wright     PetscCall(PetscStrlen(type, &len_new));
790*58600ac3SJames Wright     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
791*58600ac3SJames Wright     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
792*58600ac3SJames Wright   }
793*58600ac3SJames Wright   // Clean up old mats in different format
794*58600ac3SJames Wright   // LCOV_EXCL_START
795*58600ac3SJames Wright   if (ctx->mat_assembled_full_internal) {
796*58600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
797*58600ac3SJames Wright       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
798*58600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
799*58600ac3SJames Wright           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
800*58600ac3SJames Wright         }
801*58600ac3SJames Wright         ctx->num_mats_assembled_full--;
802*58600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
803*58600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
804*58600ac3SJames Wright       }
805*58600ac3SJames Wright     }
806*58600ac3SJames Wright   }
807*58600ac3SJames Wright   if (ctx->mat_assembled_pbd_internal) {
808*58600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
809*58600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
810*58600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
811*58600ac3SJames Wright           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
812*58600ac3SJames Wright         }
813*58600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
814*58600ac3SJames Wright         ctx->num_mats_assembled_pbd--;
815*58600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
816*58600ac3SJames Wright       }
817*58600ac3SJames Wright     }
818*58600ac3SJames Wright   }
819*58600ac3SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
820*58600ac3SJames Wright   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
821*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
822*58600ac3SJames Wright   // LCOV_EXCL_STOP
823*58600ac3SJames Wright }
824*58600ac3SJames Wright 
825*58600ac3SJames Wright /**
826*58600ac3SJames Wright   @brief Gets the inner matrix type as a string from the `MATCEED`.
827*58600ac3SJames Wright 
828*58600ac3SJames Wright   Collective across MPI processes.
829*58600ac3SJames Wright 
830*58600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
831*58600ac3SJames Wright   @param[in]      type  Inner `MatType`
832*58600ac3SJames Wright 
833*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
834*58600ac3SJames Wright **/
835*58600ac3SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
836*58600ac3SJames Wright   MatCeedContext ctx;
837*58600ac3SJames Wright 
838*58600ac3SJames Wright   PetscFunctionBeginUser;
839*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
840*58600ac3SJames Wright   *type = ctx->internal_mat_type;
841*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
842*58600ac3SJames Wright }
843*58600ac3SJames Wright 
844*58600ac3SJames Wright /**
845*58600ac3SJames Wright   @brief Set a user defined matrix operation for a `MATCEED` matrix.
846*58600ac3SJames Wright 
847*58600ac3SJames Wright   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
848*58600ac3SJames Wright `MatCeedSetContext()`.
849*58600ac3SJames Wright 
850*58600ac3SJames Wright   Collective across MPI processes.
851*58600ac3SJames Wright 
852*58600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
853*58600ac3SJames Wright   @param[in]      op   Name of the `MatOperation`
854*58600ac3SJames Wright   @param[in]      g    Function that provides the operation
855*58600ac3SJames Wright 
856*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
857*58600ac3SJames Wright **/
858*58600ac3SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
859*58600ac3SJames Wright   PetscFunctionBeginUser;
860*58600ac3SJames Wright   PetscCall(MatShellSetOperation(mat, op, g));
861*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
862*58600ac3SJames Wright }
863*58600ac3SJames Wright 
864*58600ac3SJames Wright /**
865*58600ac3SJames Wright   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
866*58600ac3SJames Wright 
867*58600ac3SJames Wright   Not collective across MPI processes.
868*58600ac3SJames Wright 
869*58600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
870*58600ac3SJames Wright   @param[in]      X_loc            Input PETSc local vector, or NULL
871*58600ac3SJames Wright   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
872*58600ac3SJames Wright 
873*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
874*58600ac3SJames Wright **/
875*58600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
876*58600ac3SJames Wright   MatCeedContext ctx;
877*58600ac3SJames Wright 
878*58600ac3SJames Wright   PetscFunctionBeginUser;
879*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
880*58600ac3SJames Wright   if (X_loc) {
881*58600ac3SJames Wright     PetscInt len_old, len_new;
882*58600ac3SJames Wright 
883*58600ac3SJames Wright     PetscCall(VecGetSize(ctx->X_loc, &len_old));
884*58600ac3SJames Wright     PetscCall(VecGetSize(X_loc, &len_new));
885*58600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
886*58600ac3SJames Wright                len_new, len_old);
887*58600ac3SJames Wright     PetscCall(VecDestroy(&ctx->X_loc));
888*58600ac3SJames Wright     ctx->X_loc = X_loc;
889*58600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)X_loc));
890*58600ac3SJames Wright   }
891*58600ac3SJames Wright   if (Y_loc_transpose) {
892*58600ac3SJames Wright     PetscInt len_old, len_new;
893*58600ac3SJames Wright 
894*58600ac3SJames Wright     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
895*58600ac3SJames Wright     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
896*58600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
897*58600ac3SJames Wright                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
898*58600ac3SJames Wright     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
899*58600ac3SJames Wright     ctx->Y_loc_transpose = Y_loc_transpose;
900*58600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
901*58600ac3SJames Wright   }
902*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
903*58600ac3SJames Wright }
904*58600ac3SJames Wright 
905*58600ac3SJames Wright /**
906*58600ac3SJames Wright   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
907*58600ac3SJames Wright 
908*58600ac3SJames Wright   Not collective across MPI processes.
909*58600ac3SJames Wright 
910*58600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
911*58600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
912*58600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
913*58600ac3SJames Wright 
914*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
915*58600ac3SJames Wright **/
916*58600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
917*58600ac3SJames Wright   MatCeedContext ctx;
918*58600ac3SJames Wright 
919*58600ac3SJames Wright   PetscFunctionBeginUser;
920*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
921*58600ac3SJames Wright   if (X_loc) {
922*58600ac3SJames Wright     *X_loc = ctx->X_loc;
923*58600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*X_loc));
924*58600ac3SJames Wright   }
925*58600ac3SJames Wright   if (Y_loc_transpose) {
926*58600ac3SJames Wright     *Y_loc_transpose = ctx->Y_loc_transpose;
927*58600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
928*58600ac3SJames Wright   }
929*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
930*58600ac3SJames Wright }
931*58600ac3SJames Wright 
932*58600ac3SJames Wright /**
933*58600ac3SJames Wright   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
934*58600ac3SJames Wright 
935*58600ac3SJames Wright   Not collective across MPI processes.
936*58600ac3SJames Wright 
937*58600ac3SJames Wright   @param[in,out]  mat              MatCeed
938*58600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
939*58600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
940*58600ac3SJames Wright 
941*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
942*58600ac3SJames Wright **/
943*58600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
944*58600ac3SJames Wright   PetscFunctionBeginUser;
945*58600ac3SJames Wright   if (X_loc) PetscCall(VecDestroy(X_loc));
946*58600ac3SJames Wright   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
947*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
948*58600ac3SJames Wright }
949*58600ac3SJames Wright 
950*58600ac3SJames Wright /**
951*58600ac3SJames Wright   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
952*58600ac3SJames Wright 
953*58600ac3SJames Wright   Not collective across MPI processes.
954*58600ac3SJames Wright 
955*58600ac3SJames Wright   @param[in,out]  mat                MatCeed
956*58600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
957*58600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
958*58600ac3SJames Wright 
959*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
960*58600ac3SJames Wright **/
961*58600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
962*58600ac3SJames Wright   MatCeedContext ctx;
963*58600ac3SJames Wright 
964*58600ac3SJames Wright   PetscFunctionBeginUser;
965*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
966*58600ac3SJames Wright   if (op_mult) {
967*58600ac3SJames Wright     *op_mult = NULL;
968*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
969*58600ac3SJames Wright   }
970*58600ac3SJames Wright   if (op_mult_transpose) {
971*58600ac3SJames Wright     *op_mult_transpose = NULL;
972*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
973*58600ac3SJames Wright   }
974*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
975*58600ac3SJames Wright }
976*58600ac3SJames Wright 
977*58600ac3SJames Wright /**
978*58600ac3SJames Wright   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
979*58600ac3SJames Wright 
980*58600ac3SJames Wright   Not collective across MPI processes.
981*58600ac3SJames Wright 
982*58600ac3SJames Wright   @param[in,out]  mat                MatCeed
983*58600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
984*58600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
985*58600ac3SJames Wright 
986*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
987*58600ac3SJames Wright **/
988*58600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
989*58600ac3SJames Wright   MatCeedContext ctx;
990*58600ac3SJames Wright 
991*58600ac3SJames Wright   PetscFunctionBeginUser;
992*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
993*58600ac3SJames Wright   if (op_mult) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult));
994*58600ac3SJames Wright   if (op_mult_transpose) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
995*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
996*58600ac3SJames Wright }
997*58600ac3SJames Wright 
998*58600ac3SJames Wright /**
999*58600ac3SJames Wright   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1000*58600ac3SJames Wright 
1001*58600ac3SJames Wright   Not collective across MPI processes.
1002*58600ac3SJames Wright 
1003*58600ac3SJames Wright   @param[in,out]  mat                       MatCeed
1004*58600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1005*58600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1006*58600ac3SJames Wright 
1007*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1008*58600ac3SJames Wright **/
1009*58600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
1010*58600ac3SJames Wright   MatCeedContext ctx;
1011*58600ac3SJames Wright 
1012*58600ac3SJames Wright   PetscFunctionBeginUser;
1013*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1014*58600ac3SJames Wright   if (log_event_mult) ctx->log_event_mult = log_event_mult;
1015*58600ac3SJames Wright   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
1016*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1017*58600ac3SJames Wright }
1018*58600ac3SJames Wright 
1019*58600ac3SJames Wright /**
1020*58600ac3SJames Wright   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1021*58600ac3SJames Wright 
1022*58600ac3SJames Wright   Not collective across MPI processes.
1023*58600ac3SJames Wright 
1024*58600ac3SJames Wright   @param[in,out]  mat                       MatCeed
1025*58600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1026*58600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1027*58600ac3SJames Wright 
1028*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1029*58600ac3SJames Wright **/
1030*58600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
1031*58600ac3SJames Wright   MatCeedContext ctx;
1032*58600ac3SJames Wright 
1033*58600ac3SJames Wright   PetscFunctionBeginUser;
1034*58600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1035*58600ac3SJames Wright   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
1036*58600ac3SJames Wright   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
1037*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1038*58600ac3SJames Wright }
1039*58600ac3SJames Wright 
1040*58600ac3SJames Wright // -----------------------------------------------------------------------------
1041*58600ac3SJames Wright // Operator context data
1042*58600ac3SJames Wright // -----------------------------------------------------------------------------
1043*58600ac3SJames Wright 
1044*58600ac3SJames Wright /**
1045*58600ac3SJames Wright   @brief Setup context data for operator application.
1046*58600ac3SJames Wright 
1047*58600ac3SJames Wright   Collective across MPI processes.
1048*58600ac3SJames Wright 
1049*58600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
1050*58600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
1051*58600ac3SJames Wright   @param[in]   X_loc                     Input PETSc local vector, or NULL
1052*58600ac3SJames Wright   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
1053*58600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
1054*58600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
1055*58600ac3SJames Wright   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
1056*58600ac3SJames Wright   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
1057*58600ac3SJames Wright   @param[out]  ctx                       Context data for operator evaluation
1058*58600ac3SJames Wright 
1059*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1060*58600ac3SJames Wright **/
1061*58600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
1062*58600ac3SJames Wright                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
1063*58600ac3SJames Wright   CeedSize x_loc_len, y_loc_len;
1064*58600ac3SJames Wright 
1065*58600ac3SJames Wright   PetscFunctionBeginUser;
1066*58600ac3SJames Wright 
1067*58600ac3SJames Wright   // Allocate
1068*58600ac3SJames Wright   PetscCall(PetscNew(ctx));
1069*58600ac3SJames Wright   (*ctx)->ref_count = 1;
1070*58600ac3SJames Wright 
1071*58600ac3SJames Wright   // Logging
1072*58600ac3SJames Wright   (*ctx)->log_event_mult           = log_event_mult;
1073*58600ac3SJames Wright   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
1074*58600ac3SJames Wright 
1075*58600ac3SJames Wright   // PETSc objects
1076*58600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_x));
1077*58600ac3SJames Wright   (*ctx)->dm_x = dm_x;
1078*58600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_y));
1079*58600ac3SJames Wright   (*ctx)->dm_y = dm_y;
1080*58600ac3SJames Wright   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
1081*58600ac3SJames Wright   (*ctx)->X_loc = X_loc;
1082*58600ac3SJames Wright   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
1083*58600ac3SJames Wright   (*ctx)->Y_loc_transpose = Y_loc_transpose;
1084*58600ac3SJames Wright 
1085*58600ac3SJames Wright   // Memtype
1086*58600ac3SJames Wright   {
1087*58600ac3SJames Wright     const PetscScalar *x;
1088*58600ac3SJames Wright     Vec                X;
1089*58600ac3SJames Wright 
1090*58600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &X));
1091*58600ac3SJames Wright     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
1092*58600ac3SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
1093*58600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &X));
1094*58600ac3SJames Wright   }
1095*58600ac3SJames Wright 
1096*58600ac3SJames Wright   // libCEED objects
1097*58600ac3SJames Wright   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
1098*58600ac3SJames Wright              "retrieving Ceed context object failed");
1099*58600ac3SJames Wright   PetscCeedCall((*ctx)->ceed, CeedReference((*ctx)->ceed));
1100*58600ac3SJames Wright   PetscCeedCall((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
1101*58600ac3SJames Wright   PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
1102*58600ac3SJames Wright   if (op_mult_transpose) PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
1103*58600ac3SJames Wright   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
1104*58600ac3SJames Wright   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
1105*58600ac3SJames Wright 
1106*58600ac3SJames Wright   // Flop counting
1107*58600ac3SJames Wright   {
1108*58600ac3SJames Wright     CeedSize ceed_flops_estimate = 0;
1109*58600ac3SJames Wright 
1110*58600ac3SJames Wright     PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
1111*58600ac3SJames Wright     (*ctx)->flops_mult = ceed_flops_estimate;
1112*58600ac3SJames Wright     if (op_mult_transpose) {
1113*58600ac3SJames Wright       PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
1114*58600ac3SJames Wright       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
1115*58600ac3SJames Wright     }
1116*58600ac3SJames Wright   }
1117*58600ac3SJames Wright 
1118*58600ac3SJames Wright   // Check sizes
1119*58600ac3SJames Wright   if (x_loc_len > 0 || y_loc_len > 0) {
1120*58600ac3SJames Wright     CeedSize ctx_x_loc_len, ctx_y_loc_len;
1121*58600ac3SJames Wright     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
1122*58600ac3SJames Wright     Vec      dm_X_loc, dm_Y_loc;
1123*58600ac3SJames Wright 
1124*58600ac3SJames Wright     // -- Input
1125*58600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
1126*58600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
1127*58600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
1128*58600ac3SJames Wright     if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
1129*58600ac3SJames Wright     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
1130*58600ac3SJames Wright     if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions");
1131*58600ac3SJames Wright     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions");
1132*58600ac3SJames Wright     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions");
1133*58600ac3SJames Wright 
1134*58600ac3SJames Wright     // -- Output
1135*58600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
1136*58600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
1137*58600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
1138*58600ac3SJames Wright     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
1139*58600ac3SJames Wright     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions");
1140*58600ac3SJames Wright 
1141*58600ac3SJames Wright     // -- Transpose
1142*58600ac3SJames Wright     if (Y_loc_transpose) {
1143*58600ac3SJames Wright       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
1144*58600ac3SJames Wright       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions");
1145*58600ac3SJames Wright     }
1146*58600ac3SJames Wright   }
1147*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1148*58600ac3SJames Wright }
1149*58600ac3SJames Wright 
1150*58600ac3SJames Wright /**
1151*58600ac3SJames Wright   @brief Increment reference counter for `MATCEED` context.
1152*58600ac3SJames Wright 
1153*58600ac3SJames Wright   Not collective across MPI processes.
1154*58600ac3SJames Wright 
1155*58600ac3SJames Wright   @param[in,out]  ctx  Context data
1156*58600ac3SJames Wright 
1157*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1158*58600ac3SJames Wright **/
1159*58600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
1160*58600ac3SJames Wright   PetscFunctionBeginUser;
1161*58600ac3SJames Wright   ctx->ref_count++;
1162*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1163*58600ac3SJames Wright }
1164*58600ac3SJames Wright 
1165*58600ac3SJames Wright /**
1166*58600ac3SJames Wright   @brief Copy reference for `MATCEED`.
1167*58600ac3SJames Wright          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1168*58600ac3SJames Wright 
1169*58600ac3SJames Wright   Not collective across MPI processes.
1170*58600ac3SJames Wright 
1171*58600ac3SJames Wright   @param[in]   ctx       Context data
1172*58600ac3SJames Wright   @param[out]  ctx_copy  Copy of pointer to context data
1173*58600ac3SJames Wright 
1174*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1175*58600ac3SJames Wright **/
1176*58600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
1177*58600ac3SJames Wright   PetscFunctionBeginUser;
1178*58600ac3SJames Wright   PetscCall(MatCeedContextReference(ctx));
1179*58600ac3SJames Wright   PetscCall(MatCeedContextDestroy(*ctx_copy));
1180*58600ac3SJames Wright   *ctx_copy = ctx;
1181*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1182*58600ac3SJames Wright }
1183*58600ac3SJames Wright 
1184*58600ac3SJames Wright /**
1185*58600ac3SJames Wright   @brief Destroy context data for operator application.
1186*58600ac3SJames Wright 
1187*58600ac3SJames Wright   Collective across MPI processes.
1188*58600ac3SJames Wright 
1189*58600ac3SJames Wright   @param[in,out]  ctx  Context data for operator evaluation
1190*58600ac3SJames Wright 
1191*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1192*58600ac3SJames Wright **/
1193*58600ac3SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
1194*58600ac3SJames Wright   PetscFunctionBeginUser;
1195*58600ac3SJames Wright   if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);
1196*58600ac3SJames Wright 
1197*58600ac3SJames Wright   // PETSc objects
1198*58600ac3SJames Wright   PetscCall(DMDestroy(&ctx->dm_x));
1199*58600ac3SJames Wright   PetscCall(DMDestroy(&ctx->dm_y));
1200*58600ac3SJames Wright   PetscCall(VecDestroy(&ctx->X_loc));
1201*58600ac3SJames Wright   PetscCall(VecDestroy(&ctx->Y_loc_transpose));
1202*58600ac3SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
1203*58600ac3SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
1204*58600ac3SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
1205*58600ac3SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_full));
1206*58600ac3SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_pbd));
1207*58600ac3SJames Wright 
1208*58600ac3SJames Wright   // libCEED objects
1209*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
1210*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
1211*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
1212*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
1213*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
1214*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
1215*58600ac3SJames Wright   PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");
1216*58600ac3SJames Wright 
1217*58600ac3SJames Wright   // Deallocate
1218*58600ac3SJames Wright   ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
1219*58600ac3SJames Wright   PetscCall(PetscFree(ctx));
1220*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1221*58600ac3SJames Wright }
1222*58600ac3SJames Wright 
1223*58600ac3SJames Wright /**
1224*58600ac3SJames Wright   @brief Compute the diagonal of an operator via libCEED.
1225*58600ac3SJames Wright 
1226*58600ac3SJames Wright   Collective across MPI processes.
1227*58600ac3SJames Wright 
1228*58600ac3SJames Wright   @param[in]   A  `MATCEED`
1229*58600ac3SJames Wright   @param[out]  D  Vector holding operator diagonal
1230*58600ac3SJames Wright 
1231*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1232*58600ac3SJames Wright **/
1233*58600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
1234*58600ac3SJames Wright   PetscMemType   mem_type;
1235*58600ac3SJames Wright   Vec            D_loc;
1236*58600ac3SJames Wright   MatCeedContext ctx;
1237*58600ac3SJames Wright 
1238*58600ac3SJames Wright   PetscFunctionBeginUser;
1239*58600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1240*58600ac3SJames Wright 
1241*58600ac3SJames Wright   // Place PETSc vector in libCEED vector
1242*58600ac3SJames Wright   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
1243*58600ac3SJames Wright   PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc));
1244*58600ac3SJames Wright 
1245*58600ac3SJames Wright   // Compute Diagonal
1246*58600ac3SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1247*58600ac3SJames Wright 
1248*58600ac3SJames Wright   // Restore PETSc vector
1249*58600ac3SJames Wright   PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc));
1250*58600ac3SJames Wright 
1251*58600ac3SJames Wright   // Local-to-Global
1252*58600ac3SJames Wright   PetscCall(VecZeroEntries(D));
1253*58600ac3SJames Wright   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
1254*58600ac3SJames Wright   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
1255*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1256*58600ac3SJames Wright }
1257*58600ac3SJames Wright 
1258*58600ac3SJames Wright /**
1259*58600ac3SJames Wright   @brief Compute `A X = Y` for a `MATCEED`.
1260*58600ac3SJames Wright 
1261*58600ac3SJames Wright   Collective across MPI processes.
1262*58600ac3SJames Wright 
1263*58600ac3SJames Wright   @param[in]   A  `MATCEED`
1264*58600ac3SJames Wright   @param[in]   X  Input PETSc vector
1265*58600ac3SJames Wright   @param[out]  Y  Output PETSc vector
1266*58600ac3SJames Wright 
1267*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1268*58600ac3SJames Wright **/
1269*58600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
1270*58600ac3SJames Wright   MatCeedContext ctx;
1271*58600ac3SJames Wright 
1272*58600ac3SJames Wright   PetscFunctionBeginUser;
1273*58600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1274*58600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
1275*58600ac3SJames Wright 
1276*58600ac3SJames Wright   {
1277*58600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
1278*58600ac3SJames Wright     Vec          X_loc = ctx->X_loc, Y_loc;
1279*58600ac3SJames Wright 
1280*58600ac3SJames Wright     // Get local vectors
1281*58600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1282*58600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1283*58600ac3SJames Wright 
1284*58600ac3SJames Wright     // Global-to-local
1285*58600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
1286*58600ac3SJames Wright 
1287*58600ac3SJames Wright     // Setup libCEED vectors
1288*58600ac3SJames Wright     PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1289*58600ac3SJames Wright     PetscCall(VecZeroEntries(Y_loc));
1290*58600ac3SJames Wright     PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1291*58600ac3SJames Wright 
1292*58600ac3SJames Wright     // Apply libCEED operator
1293*58600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1294*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
1295*58600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1296*58600ac3SJames Wright 
1297*58600ac3SJames Wright     // Restore PETSc vectors
1298*58600ac3SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1299*58600ac3SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1300*58600ac3SJames Wright 
1301*58600ac3SJames Wright     // Local-to-global
1302*58600ac3SJames Wright     PetscCall(VecZeroEntries(Y));
1303*58600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
1304*58600ac3SJames Wright 
1305*58600ac3SJames Wright     // Restore local vectors, as needed
1306*58600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1307*58600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1308*58600ac3SJames Wright   }
1309*58600ac3SJames Wright 
1310*58600ac3SJames Wright   // Log flops
1311*58600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
1312*58600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult));
1313*58600ac3SJames Wright 
1314*58600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
1315*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1316*58600ac3SJames Wright }
1317*58600ac3SJames Wright 
1318*58600ac3SJames Wright /**
1319*58600ac3SJames Wright   @brief Compute `A^T Y = X` for a `MATCEED`.
1320*58600ac3SJames Wright 
1321*58600ac3SJames Wright   Collective across MPI processes.
1322*58600ac3SJames Wright 
1323*58600ac3SJames Wright   @param[in]   A  `MATCEED`
1324*58600ac3SJames Wright   @param[in]   Y  Input PETSc vector
1325*58600ac3SJames Wright   @param[out]  X  Output PETSc vector
1326*58600ac3SJames Wright 
1327*58600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
1328*58600ac3SJames Wright **/
1329*58600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
1330*58600ac3SJames Wright   MatCeedContext ctx;
1331*58600ac3SJames Wright 
1332*58600ac3SJames Wright   PetscFunctionBeginUser;
1333*58600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1334*58600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
1335*58600ac3SJames Wright 
1336*58600ac3SJames Wright   {
1337*58600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
1338*58600ac3SJames Wright     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
1339*58600ac3SJames Wright 
1340*58600ac3SJames Wright     // Get local vectors
1341*58600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1342*58600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1343*58600ac3SJames Wright 
1344*58600ac3SJames Wright     // Global-to-local
1345*58600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
1346*58600ac3SJames Wright 
1347*58600ac3SJames Wright     // Setup libCEED vectors
1348*58600ac3SJames Wright     PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1349*58600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
1350*58600ac3SJames Wright     PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1351*58600ac3SJames Wright 
1352*58600ac3SJames Wright     // Apply libCEED operator
1353*58600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1354*58600ac3SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1355*58600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1356*58600ac3SJames Wright 
1357*58600ac3SJames Wright     // Restore PETSc vectors
1358*58600ac3SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1359*58600ac3SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1360*58600ac3SJames Wright 
1361*58600ac3SJames Wright     // Local-to-global
1362*58600ac3SJames Wright     PetscCall(VecZeroEntries(X));
1363*58600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
1364*58600ac3SJames Wright 
1365*58600ac3SJames Wright     // Restore local vectors, as needed
1366*58600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1367*58600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1368*58600ac3SJames Wright   }
1369*58600ac3SJames Wright 
1370*58600ac3SJames Wright   // Log flops
1371*58600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
1372*58600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
1373*58600ac3SJames Wright 
1374*58600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
1375*58600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1376*58600ac3SJames Wright }
1377