xref: /honee/src/mat-ceed.c (revision 58600ac3b26c12c428243f1945a2fbe2d8f3ffe2)
1 /// @file
/// MatCeed and its related operators
3 
4 #include <ceed.h>
5 #include <ceed/backend.h>
6 #include <mat-ceed-impl.h>
7 #include <mat-ceed.h>
8 #include <petscdmplex.h>
9 #include <stdlib.h>
10 #include <string.h>
11 
12 PetscClassId  MATCEED_CLASSID;
13 PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
14 
15 /**
16   @brief Register MATCEED log events.
17 
18   Not collective across MPI processes.
19 
20   @return An error code: 0 - success, otherwise - failure
21 **/
22 static PetscErrorCode MatCeedRegisterLogEvents() {
23   static bool registered = false;
24 
25   PetscFunctionBeginUser;
26   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
27   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
28   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
29   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
30   registered = true;
31   PetscFunctionReturn(PETSC_SUCCESS);
32 }
33 
34 /**
35   @brief Translate PetscMemType to CeedMemType
36 
37   @param[in]  mem_type  PetscMemType
38 
39   @return Equivalent CeedMemType
40 **/
41 static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
42 
43 /**
44   @brief Translate array of `CeedInt` to `PetscInt`.
45     If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`.
46     Caller is responsible for freeing `array_petsc` with `free()`.
47 
48   Not collective across MPI processes.
49 
50   @param[in]      num_entries  Number of array entries
51   @param[in,out]  array_ceed   Array of `CeedInt`
52   @param[out]     array_petsc  Array of `PetscInt`
53 
54   @return An error code: 0 - success, otherwise - failure
55 **/
56 static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) {
57   const CeedInt  int_c = 0;
58   const PetscInt int_p = 0;
59 
60   PetscFunctionBeginUser;
61   if (sizeof(int_c) == sizeof(int_p)) {
62     *array_petsc = (PetscInt *)*array_ceed;
63   } else {
64     *array_petsc = malloc(num_entries * sizeof(PetscInt));
65     for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i];
66     free(*array_ceed);
67   }
68   *array_ceed = NULL;
69   PetscFunctionReturn(PETSC_SUCCESS);
70 }
71 
72 /**
73   @brief Transfer array from PETSc `Vec` to `CeedVector`.
74 
75   Collective across MPI processes.
76 
77   @param[in]   ceed      libCEED context
78   @param[in]   X_petsc   PETSc `Vec`
79   @param[out]  mem_type  PETSc `MemType`
80   @param[out]  x_ceed    `CeedVector`
81 
82   @return An error code: 0 - success, otherwise - failure
83 **/
84 static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
85   PetscScalar *x;
86 
87   PetscFunctionBeginUser;
88   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
89   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
90   PetscFunctionReturn(PETSC_SUCCESS);
91 }
92 
93 /**
94   @brief Transfer array from `CeedVector` to PETSc `Vec`.
95 
96   Collective across MPI processes.
97 
98   @param[in]   ceed      libCEED context
99   @param[in]   x_ceed    `CeedVector`
100   @param[in]   mem_type  PETSc `MemType`
101   @param[out]  X_petsc   PETSc `Vec`
102 
103   @return An error code: 0 - success, otherwise - failure
104 **/
105 static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
106   PetscScalar *x;
107 
108   PetscFunctionBeginUser;
109   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
110   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
111   PetscFunctionReturn(PETSC_SUCCESS);
112 }
113 
114 /**
115   @brief Transfer read only array from PETSc `Vec` to `CeedVector`.
116 
117   Collective across MPI processes.
118 
119   @param[in]   ceed      libCEED context
120   @param[in]   X_petsc   PETSc `Vec`
121   @param[out]  mem_type  PETSc `MemType`
122   @param[out]  x_ceed    `CeedVector`
123 
124   @return An error code: 0 - success, otherwise - failure
125 **/
126 static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
127   PetscScalar *x;
128 
129   PetscFunctionBeginUser;
130   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
131   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
132   PetscFunctionReturn(PETSC_SUCCESS);
133 }
134 
135 /**
136   @brief Transfer read only array from `CeedVector` to PETSc `Vec`.
137 
138   Collective across MPI processes.
139 
140   @param[in]   ceed      libCEED context
141   @param[in]   x_ceed    `CeedVector`
142   @param[in]   mem_type  PETSc `MemType`
143   @param[out]  X_petsc   PETSc `Vec`
144 
145   @return An error code: 0 - success, otherwise - failure
146 **/
147 static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
148   PetscScalar *x;
149 
150   PetscFunctionBeginUser;
151   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
152   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
153   PetscFunctionReturn(PETSC_SUCCESS);
154 }
155 
/**
  @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.

  Collective across MPI processes.

  The inner mat type is taken from `ctx->internal_mat_type`, which may be
  overridden on the command line via `-ceed_inner_mat_type` (with this Mat's
  options prefix).

  @param[in]   mat_ceed   `MATCEED` to setup
  @param[out]  mat_inner  Inner `Mat`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // A single DM guarantees matching row/column layouts for the assembled inner Mat
  PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");

  // Check cl mat type
  {
    PetscBool is_internal_mat_type_cl = PETSC_FALSE;
    char      internal_mat_type_cl[64];

    // Check for specific CL inner mat type for this Mat
    {
      const char *mat_ceed_prefix = NULL;

      // Use this Mat's options prefix so multiple MATCEEDs can be configured independently
      PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
      PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
      PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
                                  internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
      PetscOptionsEnd();
      if (is_internal_mat_type_cl) {
        // Command-line value replaces the stored inner mat type
        PetscCall(PetscFree(ctx->internal_mat_type));
        PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
      }
    }
  }

  // Create sparse matrix
  {
    MatType dm_mat_type, dm_mat_type_copy;

    // Temporarily switch the DM's MatType so DMCreateMatrix produces the inner type, then restore it.
    // A copy of the original type string is kept since DMSetMatType may invalidate the borrowed pointer.
    PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
    PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
    PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
208 
/**
  @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
         The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
         The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.

  Collective across MPI processes.

  @param[in]      mat_ceed  `MATCEED` to assemble
  @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // Check if COO pattern set
  {
    PetscInt index = -1;

    // Linear search through previously set-up target mats; the list is expected to be short
    for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
      if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
    }
    if (index == -1) {
      PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
      CeedInt      *rows_ceed, *cols_ceed;
      PetscCount    num_entries;
      PetscLogStage stage_amg_setup;

      // -- Assemble sparsity pattern if mat hasn't been assembled before
      PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
      if (stage_amg_setup == -1) {
        PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
      }
      PetscCall(PetscLogStagePush(stage_amg_setup));
      // libCEED returns malloc'd CeedInt index arrays; IntArrayC2P converts (or reinterprets) them as PetscInt
      PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
      PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
      PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
      PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
      // Arrays now owned here (see IntArrayC2P contract); release with free(), not PetscFree()
      free(rows_petsc);
      free(cols_petsc);
      // Cache a CeedVector sized for the COO values, shared by every target mat
      if (!ctx->coo_values_pbd) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
      // Record this mat so subsequent assemblies skip the symbolic setup
      PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
      ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
      PetscCall(PetscLogStagePop());
    }
  }

  // Assemble mat_ceed
  PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
  {
    const CeedScalar *values;
    MatType           mat_type;
    CeedMemType       mem_type = CEED_MEM_HOST;
    PetscBool         is_spd, is_spd_known;

    // Hand device memory to device mat types to avoid a host round-trip
    PetscCall(MatGetType(mat_coo, &mat_type));
    if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
    else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
    else mem_type = CEED_MEM_HOST;

    PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
    PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
    PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
    // Propagate SPD information from the shell mat when known
    PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
    if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
    PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
  }
  PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
282 
283 /**
284   @brief Assemble inner `Mat` for diagonal `PC` operations
285 
286   Collective across MPI processes.
287 
288   @param[in]   mat_ceed      `MATCEED` to invert
289   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
290   @param[out]  mat_inner     Inner `Mat` for diagonal operations
291 
292   @return An error code: 0 - success, otherwise - failure
293 **/
294 static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
295   MatCeedContext ctx;
296 
297   PetscFunctionBeginUser;
298   PetscCall(MatShellGetContext(mat_ceed, &ctx));
299   if (use_ceed_pbd) {
300     // Check if COO pattern set
301     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
302 
303     // Assemble mat_assembled_full_internal
304     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
305     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
306   } else {
307     // Check if COO pattern set
308     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
309 
310     // Assemble mat_assembled_full_internal
311     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
312     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
313   }
314   PetscFunctionReturn(PETSC_SUCCESS);
315 }
316 
317 /**
318   @brief Get `MATCEED` diagonal block for Jacobi.
319 
320   Collective across MPI processes.
321 
322   @param[in]   mat_ceed   `MATCEED` to invert
323   @param[out]  mat_block  The diagonal block matrix
324 
325   @return An error code: 0 - success, otherwise - failure
326 **/
327 static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
328   Mat            mat_inner = NULL;
329   MatCeedContext ctx;
330 
331   PetscFunctionBeginUser;
332   PetscCall(MatShellGetContext(mat_ceed, &ctx));
333 
334   // Assemble inner mat if needed
335   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
336 
337   // Get block diagonal
338   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
339   PetscFunctionReturn(PETSC_SUCCESS);
340 }
341 
342 /**
343   @brief Invert `MATCEED` diagonal block for Jacobi.
344 
345   Collective across MPI processes.
346 
347   @param[in]   mat_ceed  `MATCEED` to invert
348   @param[out]  values    The block inverses in column major order
349 
350   @return An error code: 0 - success, otherwise - failure
351 **/
352 static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
353   Mat            mat_inner = NULL;
354   MatCeedContext ctx;
355 
356   PetscFunctionBeginUser;
357   PetscCall(MatShellGetContext(mat_ceed, &ctx));
358 
359   // Assemble inner mat if needed
360   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
361 
362   // Invert PB diagonal
363   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
364   PetscFunctionReturn(PETSC_SUCCESS);
365 }
366 
367 /**
368   @brief Invert `MATCEED` variable diagonal block for Jacobi.
369 
370   Collective across MPI processes.
371 
372   @param[in]   mat_ceed     `MATCEED` to invert
373   @param[in]   num_blocks   The number of blocks on the process
374   @param[in]   block_sizes  The size of each block on the process
375   @param[out]  values       The block inverses in column major order
376 
377   @return An error code: 0 - success, otherwise - failure
378 **/
379 static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
380   Mat            mat_inner = NULL;
381   MatCeedContext ctx;
382 
383   PetscFunctionBeginUser;
384   PetscCall(MatShellGetContext(mat_ceed, &ctx));
385 
386   // Assemble inner mat if needed
387   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
388 
389   // Invert PB diagonal
390   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
391   PetscFunctionReturn(PETSC_SUCCESS);
392 }
393 
394 // -----------------------------------------------------------------------------
395 // MatCeed
396 // -----------------------------------------------------------------------------
397 
/**
  @brief Create PETSc `Mat` from libCEED operators.

  Collective across MPI processes.

  If `dm_y` is NULL, `dm_x` is used for both input and output.
  The new Mat is a MatShell with type name MATCEED; its context holds references
  to the DMs, local work vectors, and the libCEED operators.

  @param[in]   dm_x                      Input `DM`
  @param[in]   dm_y                      Output `DM`
  @param[in]   op_mult                   `CeedOperator` for forward evaluation
  @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
  @param[out]  mat                        New MatCeed

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
  PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
  VecType        vec_type;
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatCeedRegisterLogEvents());

  // Collect context data
  PetscCall(DMGetVecType(dm_x, &vec_type));
  {
    Vec X;

    // Borrow a global vector only to read the global/local sizes
    PetscCall(DMGetGlobalVector(dm_x, &X));
    PetscCall(VecGetSize(X, &X_g_size));
    PetscCall(VecGetLocalSize(X, &X_l_size));
    PetscCall(DMRestoreGlobalVector(dm_x, &X));
  }
  if (dm_y) {
    Vec Y;

    PetscCall(DMGetGlobalVector(dm_y, &Y));
    PetscCall(VecGetSize(Y, &Y_g_size));
    PetscCall(VecGetLocalSize(Y, &Y_l_size));
    PetscCall(DMRestoreGlobalVector(dm_y, &Y));
  } else {
    // No output DM given: operator maps dm_x to itself
    dm_y     = dm_x;
    Y_g_size = X_g_size;
    Y_l_size = X_l_size;
  }
  // Create context
  {
    Vec X_loc, Y_loc_transpose = NULL;

    PetscCall(DMCreateLocalVector(dm_x, &X_loc));
    PetscCall(VecZeroEntries(X_loc));
    if (op_mult_transpose) {
      // Transpose evaluation needs its own local work vector on the output DM
      PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
      PetscCall(VecZeroEntries(Y_loc_transpose));
    }
    // Context takes its own references; drop the local ones here
    PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
    PetscCall(VecDestroy(&X_loc));
    PetscCall(VecDestroy(&Y_loc_transpose));
  }

  // Create mat
  PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
  PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
  // -- Set block and variable block sizes
  if (dm_x == dm_y) {
    MatType dm_mat_type, dm_mat_type_copy;
    Mat     temp_mat;

    // Build a throwaway AIJ matrix just to query the DM's block structure,
    // restoring the DM's original MatType afterwards
    PetscCall(DMGetMatType(dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(dm_x, MATAIJ));
    PetscCall(DMCreateMatrix(dm_x, &temp_mat));
    PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));

    {
      PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
      const PetscInt *vblock_sizes;

      // -- Get block sizes
      PetscCall(MatGetBlockSize(temp_mat, &block_size));
      PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
      {
        PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};

        // Global max of the largest local variable block size
        for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
        PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
        max_vblock_size = global_min_max[1];
      }

      // -- Copy block sizes
      if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
      if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));

      // -- Check libCEED compatibility
      // PBD/VPBD assembly is only valid when each (sub-)operator has a single
      // active basis whose component count matches the (variable) block size
      {
        bool is_composite;

        ctx->is_ceed_pbd_valid  = PETSC_TRUE;
        ctx->is_ceed_vpbd_valid = PETSC_TRUE;
        PetscCeedCall(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
        if (is_composite) {
          CeedInt       num_sub_operators;
          CeedOperator *sub_operators;

          PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
          PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
          for (CeedInt i = 0; i < num_sub_operators; i++) {
            CeedInt                  num_bases, num_comp;
            CeedBasis               *active_bases;
            CeedOperatorAssemblyData assembly_data;

            PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
            PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
            PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
            if (num_bases > 1) {
              ctx->is_ceed_pbd_valid  = PETSC_FALSE;
              ctx->is_ceed_vpbd_valid = PETSC_FALSE;
            }
            if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
            if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
        } else {
          // LCOV_EXCL_START
          CeedInt                  num_bases, num_comp;
          CeedBasis               *active_bases;
          CeedOperatorAssemblyData assembly_data;

          PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
          PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
          PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
          if (num_bases > 1) {
            ctx->is_ceed_pbd_valid  = PETSC_FALSE;
            ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
          if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
          if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          // LCOV_EXCL_STOP
        }
        {
          PetscInt local_is_valid[2], global_is_valid[2];

          // All ranks must agree: validity is the global minimum of the local flags
          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_pbd_valid = global_is_valid[0];
          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_vpbd_valid = global_is_valid[0];
        }
      }
    }
    PetscCall(MatDestroy(&temp_mat));
  }
  // -- Set internal mat type
  {
    // NOTE(review): this vec_type shadows the function-scope vec_type used below
    // for MatShellSetVecType; intentional, but -Wshadow would flag it
    VecType vec_type;
    MatType internal_mat_type = MATAIJ;

    // Match the inner assembled mat type to the local vector's backend
    PetscCall(VecGetType(ctx->X_loc, &vec_type));
    if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
    else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
    else internal_mat_type = MATAIJ;
    PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
  }
  // -- Set mat operations
  PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
  PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
  if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
  PetscCall(MatShellSetVecType(*mat, vec_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
571 
572 /**
573   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
574 
575   Collective across MPI processes.
576 
577   @param[in]   mat_ceed   `MATCEED` to copy from
578   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
579 
580   @return An error code: 0 - success, otherwise - failure
581 **/
582 PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
583   PetscFunctionBeginUser;
584   PetscCall(MatCeedRegisterLogEvents());
585 
586   // Check type compatibility
587   {
588     MatType mat_type_ceed, mat_type_other;
589 
590     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
591     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
592     PetscCall(MatGetType(mat_ceed, &mat_type_other));
593     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
594                "mat_other must have type " MATCEED " or " MATSHELL);
595   }
596 
597   // Check dimension compatibility
598   {
599     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
600 
601     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
602     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
603     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
604     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
605     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
606                    (X_l_ceed_size == X_l_other_size),
607                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
608                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
609                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
610                ", %" PetscInt_FMT ")",
611                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
612   }
613 
614   // Convert
615   {
616     VecType        vec_type;
617     MatCeedContext ctx;
618 
619     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
620     PetscCall(MatShellGetContext(mat_ceed, &ctx));
621     PetscCall(MatCeedContextReference(ctx));
622     PetscCall(MatShellSetContext(mat_other, ctx));
623     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
624     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
625     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
626     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
627     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
628     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
629     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
630     {
631       PetscInt block_size;
632 
633       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
634       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
635     }
636     {
637       PetscInt        num_blocks;
638       const PetscInt *block_sizes;
639 
640       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
641       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
642     }
643     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
644     PetscCall(MatShellSetVecType(mat_other, vec_type));
645   }
646   PetscFunctionReturn(PETSC_SUCCESS);
647 }
648 
/**
  @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
         The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
         The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.

  Collective across MPI processes.

  @param[in]      mat_ceed  `MATCEED` to assemble
  @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // Check if COO pattern set
  {
    PetscInt index = -1;

    // Linear search through previously set-up target mats; the list is expected to be short
    for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
      if (ctx->mats_assembled_full[i] == mat_coo) index = i;
    }
    if (index == -1) {
      PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
      CeedInt      *rows_ceed, *cols_ceed;
      PetscCount    num_entries;
      PetscLogStage stage_amg_setup;

      // -- Assemble sparsity pattern if mat hasn't been assembled before
      PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
      if (stage_amg_setup == -1) {
        PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
      }
      PetscCall(PetscLogStagePush(stage_amg_setup));
      // libCEED returns malloc'd CeedInt index arrays; IntArrayC2P converts (or reinterprets) them as PetscInt
      PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
      PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
      PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
      PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
      // Arrays now owned here (see IntArrayC2P contract); release with free(), not PetscFree()
      free(rows_petsc);
      free(cols_petsc);
      // Cache a CeedVector sized for the COO values, shared by every target mat
      if (!ctx->coo_values_full) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
      // Record this mat so subsequent assemblies skip the symbolic setup
      PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
      ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
      PetscCall(PetscLogStagePop());
    }
  }

  // Assemble mat_ceed
  PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
  {
    const CeedScalar *values;
    MatType           mat_type;
    CeedMemType       mem_type = CEED_MEM_HOST;
    PetscBool         is_spd, is_spd_known;

    // Hand device memory to device mat types to avoid a host round-trip
    PetscCall(MatGetType(mat_coo, &mat_type));
    if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
    else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
    else mem_type = CEED_MEM_HOST;

    PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
    PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
    PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
    // Propagate SPD information from the shell mat when known
    PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
    if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
    PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
  }
  PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
722 
723 /**
724   @brief Set user context for a `MATCEED`.
725 
726   Collective across MPI processes.
727 
728   @param[in,out]  mat  `MATCEED`
729   @param[in]      f    The context destroy function, or NULL
730   @param[in]      ctx  User context, or NULL to unset
731 
732   @return An error code: 0 - success, otherwise - failure
733 **/
734 PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
735   PetscContainer user_ctx = NULL;
736 
737   PetscFunctionBeginUser;
738   if (ctx) {
739     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
740     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
741     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
742   }
743   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
744   PetscCall(PetscContainerDestroy(&user_ctx));
745   PetscFunctionReturn(PETSC_SUCCESS);
746 }
747 
748 /**
749   @brief Retrieve the user context for a `MATCEED`.
750 
751   Collective across MPI processes.
752 
753   @param[in,out]  mat  `MATCEED`
754   @param[in]      ctx  User context
755 
756   @return An error code: 0 - success, otherwise - failure
757 **/
758 PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
759   PetscContainer user_ctx;
760 
761   PetscFunctionBeginUser;
762   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
763   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
764   else *(void **)ctx = NULL;
765   PetscFunctionReturn(PETSC_SUCCESS);
766 }
767 
768 /**
769   @brief Sets the inner matrix type as a string from the `MATCEED`.
770 
771   Collective across MPI processes.
772 
773   @param[in,out]  mat   `MATCEED`
774   @param[in]      type  Inner `MatType` to set
775 
776   @return An error code: 0 - success, otherwise - failure
777 **/
778 PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
779   MatCeedContext ctx;
780 
781   PetscFunctionBeginUser;
782   PetscCall(MatShellGetContext(mat, &ctx));
783   // Check if same
784   {
785     size_t    len_old, len_new;
786     PetscBool is_same = PETSC_FALSE;
787 
788     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
789     PetscCall(PetscStrlen(type, &len_new));
790     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
791     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
792   }
793   // Clean up old mats in different format
794   // LCOV_EXCL_START
795   if (ctx->mat_assembled_full_internal) {
796     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
797       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
798         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
799           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
800         }
801         ctx->num_mats_assembled_full--;
802         // Note: we'll realloc this array again, so no need to shrink the allocation
803         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
804       }
805     }
806   }
807   if (ctx->mat_assembled_pbd_internal) {
808     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
809       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
810         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
811           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
812         }
813         // Note: we'll realloc this array again, so no need to shrink the allocation
814         ctx->num_mats_assembled_pbd--;
815         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
816       }
817     }
818   }
819   PetscCall(PetscFree(ctx->internal_mat_type));
820   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
821   PetscFunctionReturn(PETSC_SUCCESS);
822   // LCOV_EXCL_STOP
823 }
824 
825 /**
826   @brief Gets the inner matrix type as a string from the `MATCEED`.
827 
828   Collective across MPI processes.
829 
830   @param[in,out]  mat   `MATCEED`
831   @param[in]      type  Inner `MatType`
832 
833   @return An error code: 0 - success, otherwise - failure
834 **/
835 PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
836   MatCeedContext ctx;
837 
838   PetscFunctionBeginUser;
839   PetscCall(MatShellGetContext(mat, &ctx));
840   *type = ctx->internal_mat_type;
841   PetscFunctionReturn(PETSC_SUCCESS);
842 }
843 
844 /**
845   @brief Set a user defined matrix operation for a `MATCEED` matrix.
846 
847   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
848 `MatCeedSetContext()`.
849 
850   Collective across MPI processes.
851 
852   @param[in,out]  mat  `MATCEED`
853   @param[in]      op   Name of the `MatOperation`
854   @param[in]      g    Function that provides the operation
855 
856   @return An error code: 0 - success, otherwise - failure
857 **/
858 PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
859   PetscFunctionBeginUser;
860   PetscCall(MatShellSetOperation(mat, op, g));
861   PetscFunctionReturn(PETSC_SUCCESS);
862 }
863 
864 /**
865   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
866 
867   Not collective across MPI processes.
868 
869   @param[in,out]  mat              `MATCEED`
870   @param[in]      X_loc            Input PETSc local vector, or NULL
871   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
872 
873   @return An error code: 0 - success, otherwise - failure
874 **/
875 PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
876   MatCeedContext ctx;
877 
878   PetscFunctionBeginUser;
879   PetscCall(MatShellGetContext(mat, &ctx));
880   if (X_loc) {
881     PetscInt len_old, len_new;
882 
883     PetscCall(VecGetSize(ctx->X_loc, &len_old));
884     PetscCall(VecGetSize(X_loc, &len_new));
885     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
886                len_new, len_old);
887     PetscCall(VecDestroy(&ctx->X_loc));
888     ctx->X_loc = X_loc;
889     PetscCall(PetscObjectReference((PetscObject)X_loc));
890   }
891   if (Y_loc_transpose) {
892     PetscInt len_old, len_new;
893 
894     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
895     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
896     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
897                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
898     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
899     ctx->Y_loc_transpose = Y_loc_transpose;
900     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
901   }
902   PetscFunctionReturn(PETSC_SUCCESS);
903 }
904 
905 /**
906   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
907 
908   Not collective across MPI processes.
909 
910   @param[in,out]  mat              `MATCEED`
911   @param[out]     X_loc            Input PETSc local vector, or NULL
912   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
913 
914   @return An error code: 0 - success, otherwise - failure
915 **/
916 PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
917   MatCeedContext ctx;
918 
919   PetscFunctionBeginUser;
920   PetscCall(MatShellGetContext(mat, &ctx));
921   if (X_loc) {
922     *X_loc = ctx->X_loc;
923     PetscCall(PetscObjectReference((PetscObject)*X_loc));
924   }
925   if (Y_loc_transpose) {
926     *Y_loc_transpose = ctx->Y_loc_transpose;
927     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
928   }
929   PetscFunctionReturn(PETSC_SUCCESS);
930 }
931 
932 /**
933   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
934 
935   Not collective across MPI processes.
936 
937   @param[in,out]  mat              MatCeed
938   @param[out]     X_loc            Input PETSc local vector, or NULL
939   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
940 
941   @return An error code: 0 - success, otherwise - failure
942 **/
943 PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
944   PetscFunctionBeginUser;
945   if (X_loc) PetscCall(VecDestroy(X_loc));
946   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
947   PetscFunctionReturn(PETSC_SUCCESS);
948 }
949 
950 /**
951   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
952 
953   Not collective across MPI processes.
954 
955   @param[in,out]  mat                MatCeed
956   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
957   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
958 
959   @return An error code: 0 - success, otherwise - failure
960 **/
961 PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
962   MatCeedContext ctx;
963 
964   PetscFunctionBeginUser;
965   PetscCall(MatShellGetContext(mat, &ctx));
966   if (op_mult) {
967     *op_mult = NULL;
968     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
969   }
970   if (op_mult_transpose) {
971     *op_mult_transpose = NULL;
972     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
973   }
974   PetscFunctionReturn(PETSC_SUCCESS);
975 }
976 
977 /**
978   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
979 
980   Not collective across MPI processes.
981 
982   @param[in,out]  mat                MatCeed
983   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
984   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
985 
986   @return An error code: 0 - success, otherwise - failure
987 **/
988 PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
989   MatCeedContext ctx;
990 
991   PetscFunctionBeginUser;
992   PetscCall(MatShellGetContext(mat, &ctx));
993   if (op_mult) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult));
994   if (op_mult_transpose) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
995   PetscFunctionReturn(PETSC_SUCCESS);
996 }
997 
998 /**
999   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1000 
1001   Not collective across MPI processes.
1002 
1003   @param[in,out]  mat                       MatCeed
1004   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1005   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1006 
1007   @return An error code: 0 - success, otherwise - failure
1008 **/
1009 PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
1010   MatCeedContext ctx;
1011 
1012   PetscFunctionBeginUser;
1013   PetscCall(MatShellGetContext(mat, &ctx));
1014   if (log_event_mult) ctx->log_event_mult = log_event_mult;
1015   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
1016   PetscFunctionReturn(PETSC_SUCCESS);
1017 }
1018 
1019 /**
1020   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1021 
1022   Not collective across MPI processes.
1023 
1024   @param[in,out]  mat                       MatCeed
1025   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1026   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1027 
1028   @return An error code: 0 - success, otherwise - failure
1029 **/
1030 PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
1031   MatCeedContext ctx;
1032 
1033   PetscFunctionBeginUser;
1034   PetscCall(MatShellGetContext(mat, &ctx));
1035   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
1036   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
1037   PetscFunctionReturn(PETSC_SUCCESS);
1038 }
1039 
1040 // -----------------------------------------------------------------------------
1041 // Operator context data
1042 // -----------------------------------------------------------------------------
1043 
1044 /**
1045   @brief Setup context data for operator application.
1046 
1047   Collective across MPI processes.
1048 
1049   @param[in]   dm_x                      Input `DM`
1050   @param[in]   dm_y                      Output `DM`
1051   @param[in]   X_loc                     Input PETSc local vector, or NULL
1052   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
1053   @param[in]   op_mult                   `CeedOperator` for forward evaluation
1054   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
1055   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
1056   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
1057   @param[out]  ctx                       Context data for operator evaluation
1058 
1059   @return An error code: 0 - success, otherwise - failure
1060 **/
1061 PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
1062                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
1063   CeedSize x_loc_len, y_loc_len;
1064 
1065   PetscFunctionBeginUser;
1066 
1067   // Allocate
1068   PetscCall(PetscNew(ctx));
1069   (*ctx)->ref_count = 1;
1070 
1071   // Logging
1072   (*ctx)->log_event_mult           = log_event_mult;
1073   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
1074 
1075   // PETSc objects
1076   PetscCall(PetscObjectReference((PetscObject)dm_x));
1077   (*ctx)->dm_x = dm_x;
1078   PetscCall(PetscObjectReference((PetscObject)dm_y));
1079   (*ctx)->dm_y = dm_y;
1080   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
1081   (*ctx)->X_loc = X_loc;
1082   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
1083   (*ctx)->Y_loc_transpose = Y_loc_transpose;
1084 
1085   // Memtype
1086   {
1087     const PetscScalar *x;
1088     Vec                X;
1089 
1090     PetscCall(DMGetLocalVector(dm_x, &X));
1091     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
1092     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
1093     PetscCall(DMRestoreLocalVector(dm_x, &X));
1094   }
1095 
1096   // libCEED objects
1097   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
1098              "retrieving Ceed context object failed");
1099   PetscCeedCall((*ctx)->ceed, CeedReference((*ctx)->ceed));
1100   PetscCeedCall((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
1101   PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
1102   if (op_mult_transpose) PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
1103   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
1104   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
1105 
1106   // Flop counting
1107   {
1108     CeedSize ceed_flops_estimate = 0;
1109 
1110     PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
1111     (*ctx)->flops_mult = ceed_flops_estimate;
1112     if (op_mult_transpose) {
1113       PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
1114       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
1115     }
1116   }
1117 
1118   // Check sizes
1119   if (x_loc_len > 0 || y_loc_len > 0) {
1120     CeedSize ctx_x_loc_len, ctx_y_loc_len;
1121     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
1122     Vec      dm_X_loc, dm_Y_loc;
1123 
1124     // -- Input
1125     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
1126     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
1127     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
1128     if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
1129     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
1130     if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions");
1131     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions");
1132     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions");
1133 
1134     // -- Output
1135     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
1136     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
1137     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
1138     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
1139     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions");
1140 
1141     // -- Transpose
1142     if (Y_loc_transpose) {
1143       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
1144       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions");
1145     }
1146   }
1147   PetscFunctionReturn(PETSC_SUCCESS);
1148 }
1149 
1150 /**
1151   @brief Increment reference counter for `MATCEED` context.
1152 
1153   Not collective across MPI processes.
1154 
1155   @param[in,out]  ctx  Context data
1156 
1157   @return An error code: 0 - success, otherwise - failure
1158 **/
1159 PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
1160   PetscFunctionBeginUser;
1161   ctx->ref_count++;
1162   PetscFunctionReturn(PETSC_SUCCESS);
1163 }
1164 
1165 /**
1166   @brief Copy reference for `MATCEED`.
1167          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1168 
1169   Not collective across MPI processes.
1170 
1171   @param[in]   ctx       Context data
1172   @param[out]  ctx_copy  Copy of pointer to context data
1173 
1174   @return An error code: 0 - success, otherwise - failure
1175 **/
1176 PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
1177   PetscFunctionBeginUser;
1178   PetscCall(MatCeedContextReference(ctx));
1179   PetscCall(MatCeedContextDestroy(*ctx_copy));
1180   *ctx_copy = ctx;
1181   PetscFunctionReturn(PETSC_SUCCESS);
1182 }
1183 
1184 /**
1185   @brief Destroy context data for operator application.
1186 
1187   Collective across MPI processes.
1188 
1189   @param[in,out]  ctx  Context data for operator evaluation
1190 
1191   @return An error code: 0 - success, otherwise - failure
1192 **/
1193 PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
1194   PetscFunctionBeginUser;
1195   if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);
1196 
1197   // PETSc objects
1198   PetscCall(DMDestroy(&ctx->dm_x));
1199   PetscCall(DMDestroy(&ctx->dm_y));
1200   PetscCall(VecDestroy(&ctx->X_loc));
1201   PetscCall(VecDestroy(&ctx->Y_loc_transpose));
1202   PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
1203   PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
1204   PetscCall(PetscFree(ctx->internal_mat_type));
1205   PetscCall(PetscFree(ctx->mats_assembled_full));
1206   PetscCall(PetscFree(ctx->mats_assembled_pbd));
1207 
1208   // libCEED objects
1209   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
1210   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
1211   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
1212   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
1213   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
1214   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
1215   PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");
1216 
1217   // Deallocate
1218   ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
1219   PetscCall(PetscFree(ctx));
1220   PetscFunctionReturn(PETSC_SUCCESS);
1221 }
1222 
1223 /**
1224   @brief Compute the diagonal of an operator via libCEED.
1225 
1226   Collective across MPI processes.
1227 
1228   @param[in]   A  `MATCEED`
1229   @param[out]  D  Vector holding operator diagonal
1230 
1231   @return An error code: 0 - success, otherwise - failure
1232 **/
1233 PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
1234   PetscMemType   mem_type;
1235   Vec            D_loc;
1236   MatCeedContext ctx;
1237 
1238   PetscFunctionBeginUser;
1239   PetscCall(MatShellGetContext(A, &ctx));
1240 
1241   // Place PETSc vector in libCEED vector
1242   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
1243   PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc));
1244 
1245   // Compute Diagonal
1246   PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1247 
1248   // Restore PETSc vector
1249   PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc));
1250 
1251   // Local-to-Global
1252   PetscCall(VecZeroEntries(D));
1253   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
1254   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
1255   PetscFunctionReturn(PETSC_SUCCESS);
1256 }
1257 
1258 /**
1259   @brief Compute `A X = Y` for a `MATCEED`.
1260 
1261   Collective across MPI processes.
1262 
1263   @param[in]   A  `MATCEED`
1264   @param[in]   X  Input PETSc vector
1265   @param[out]  Y  Output PETSc vector
1266 
1267   @return An error code: 0 - success, otherwise - failure
1268 **/
1269 PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
1270   MatCeedContext ctx;
1271 
1272   PetscFunctionBeginUser;
1273   PetscCall(MatShellGetContext(A, &ctx));
1274   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
1275 
1276   {
1277     PetscMemType x_mem_type, y_mem_type;
1278     Vec          X_loc = ctx->X_loc, Y_loc;
1279 
1280     // Get local vectors
1281     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1282     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1283 
1284     // Global-to-local
1285     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
1286 
1287     // Setup libCEED vectors
1288     PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1289     PetscCall(VecZeroEntries(Y_loc));
1290     PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1291 
1292     // Apply libCEED operator
1293     PetscCall(PetscLogGpuTimeBegin());
1294     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
1295     PetscCall(PetscLogGpuTimeEnd());
1296 
1297     // Restore PETSc vectors
1298     PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1299     PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1300 
1301     // Local-to-global
1302     PetscCall(VecZeroEntries(Y));
1303     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
1304 
1305     // Restore local vectors, as needed
1306     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1307     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1308   }
1309 
1310   // Log flops
1311   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
1312   else PetscCall(PetscLogFlops(ctx->flops_mult));
1313 
1314   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
1315   PetscFunctionReturn(PETSC_SUCCESS);
1316 }
1317 
1318 /**
1319   @brief Compute `A^T Y = X` for a `MATCEED`.
1320 
1321   Collective across MPI processes.
1322 
1323   @param[in]   A  `MATCEED`
1324   @param[in]   Y  Input PETSc vector
1325   @param[out]  X  Output PETSc vector
1326 
1327   @return An error code: 0 - success, otherwise - failure
1328 **/
1329 PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
1330   MatCeedContext ctx;
1331 
1332   PetscFunctionBeginUser;
1333   PetscCall(MatShellGetContext(A, &ctx));
1334   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
1335 
1336   {
1337     PetscMemType x_mem_type, y_mem_type;
1338     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
1339 
1340     // Get local vectors
1341     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1342     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1343 
1344     // Global-to-local
1345     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
1346 
1347     // Setup libCEED vectors
1348     PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1349     PetscCall(VecZeroEntries(X_loc));
1350     PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1351 
1352     // Apply libCEED operator
1353     PetscCall(PetscLogGpuTimeBegin());
1354     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1355     PetscCall(PetscLogGpuTimeEnd());
1356 
1357     // Restore PETSc vectors
1358     PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1359     PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1360 
1361     // Local-to-global
1362     PetscCall(VecZeroEntries(X));
1363     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
1364 
1365     // Restore local vectors, as needed
1366     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1367     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1368   }
1369 
1370   // Log flops
1371   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
1372   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
1373 
1374   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
1375   PetscFunctionReturn(PETSC_SUCCESS);
1376 }
1377