xref: /libCEED/examples/fluids/src/mat-ceed.c (revision 1ac265fb5043aa4a880d7861b19d62b13bc26bc5)
1 /// @file
2 /// MatCEED implementation
3 
4 #include <ceed.h>
5 #include <ceed/backend.h>
6 #include <mat-ceed-impl.h>
7 #include <mat-ceed.h>
8 #include <petsc-ceed-utils.h>
9 #include <petsc-ceed.h>
10 #include <petscdmplex.h>
11 #include <stdlib.h>
12 #include <string.h>
13 
14 PetscClassId  MATCEED_CLASSID;
15 PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
16 
17 /**
18   @brief Register MATCEED log events.
19 
20   Not collective across MPI processes.
21 
22   @return An error code: 0 - success, otherwise - failure
23 **/
24 static PetscErrorCode MatCeedRegisterLogEvents() {
25   static PetscBool registered = PETSC_FALSE;
26 
27   PetscFunctionBeginUser;
28   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
29   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
30   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
31   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
32   registered = PETSC_TRUE;
33   PetscFunctionReturn(PETSC_SUCCESS);
34 }
35 
/**
  @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
         The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
         The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.

  Collective across MPI processes.

  @param[in]      mat_ceed  `MATCEED` to assemble
  @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // Check if COO pattern set; ctx->mats_assembled_pbd records every Mat already preallocated
  {
    PetscInt index = -1;

    for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
      if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
    }
    if (index == -1) {
      PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
      CeedInt      *rows_ceed, *cols_ceed;
      PetscCount    num_entries;
      PetscLogStage stage_amg_setup;

      // -- Assemble sparsity pattern if mat hasn't been assembled before
      // Register the log stage lazily; PetscLogStageGetId returns -1 when not yet registered
      PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
      if (stage_amg_setup == -1) {
        PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
      }
      PetscCall(PetscLogStagePush(stage_amg_setup));
      // Symbolic assembly yields the COO (row, col) pattern of the point block diagonal
      PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
      // Convert CeedInt arrays to PetscInt; converted arrays are heap-allocated and freed below
      PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
      PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
      PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
      free(rows_petsc);
      free(cols_petsc);
      // Lazily create the reusable values vector sized to the COO pattern
      if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
      // Grow the tracking array and remember this Mat so the pattern is set only once
      PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
      ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
      PetscCall(PetscLogStagePop());
    }
  }

  // Assemble mat_ceed
  PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
  {
    const CeedScalar *values;
    MatType           mat_type;
    CeedMemType       mem_type = CEED_MEM_HOST;
    PetscBool         is_spd, is_spd_known;

    // Read assembled values directly on the device for device-backed Mat types
    PetscCall(MatGetType(mat_coo, &mat_type));
    if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
    else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
    else mem_type = CEED_MEM_HOST;

    PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
    PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
    PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
    // Propagate SPD information from the MATCEED, if known
    PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
    if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
    PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
  }
  PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
109 
110 /**
111   @brief Assemble inner `Mat` for diagonal `PC` operations
112 
113   Collective across MPI processes.
114 
115   @param[in]   mat_ceed      `MATCEED` to invert
116   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
117   @param[out]  mat_inner     Inner `Mat` for diagonal operations
118 
119   @return An error code: 0 - success, otherwise - failure
120 **/
121 static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
122   MatCeedContext ctx;
123 
124   PetscFunctionBeginUser;
125   PetscCall(MatShellGetContext(mat_ceed, &ctx));
126   if (use_ceed_pbd) {
127     // Check if COO pattern set
128     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedCreateMatCOO(mat_ceed, &ctx->mat_assembled_pbd_internal));
129 
130     // Assemble mat_assembled_full_internal
131     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
132     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
133   } else {
134     // Check if COO pattern set
135     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedCreateMatCOO(mat_ceed, &ctx->mat_assembled_full_internal));
136 
137     // Assemble mat_assembled_full_internal
138     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
139     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
140   }
141   PetscFunctionReturn(PETSC_SUCCESS);
142 }
143 
144 /**
145   @brief Get `MATCEED` diagonal block for Jacobi.
146 
147   Collective across MPI processes.
148 
149   @param[in]   mat_ceed   `MATCEED` to invert
150   @param[out]  mat_block  The diagonal block matrix
151 
152   @return An error code: 0 - success, otherwise - failure
153 **/
154 static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
155   Mat            mat_inner = NULL;
156   MatCeedContext ctx;
157 
158   PetscFunctionBeginUser;
159   PetscCall(MatShellGetContext(mat_ceed, &ctx));
160 
161   // Assemble inner mat if needed
162   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
163 
164   // Get block diagonal
165   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
166   PetscFunctionReturn(PETSC_SUCCESS);
167 }
168 
169 /**
170   @brief Invert `MATCEED` diagonal block for Jacobi.
171 
172   Collective across MPI processes.
173 
174   @param[in]   mat_ceed  `MATCEED` to invert
175   @param[out]  values    The block inverses in column major order
176 
177   @return An error code: 0 - success, otherwise - failure
178 **/
179 static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
180   Mat            mat_inner = NULL;
181   MatCeedContext ctx;
182 
183   PetscFunctionBeginUser;
184   PetscCall(MatShellGetContext(mat_ceed, &ctx));
185 
186   // Assemble inner mat if needed
187   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
188 
189   // Invert PB diagonal
190   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
191   PetscFunctionReturn(PETSC_SUCCESS);
192 }
193 
194 /**
195   @brief Invert `MATCEED` variable diagonal block for Jacobi.
196 
197   Collective across MPI processes.
198 
199   @param[in]   mat_ceed     `MATCEED` to invert
200   @param[in]   num_blocks   The number of blocks on the process
201   @param[in]   block_sizes  The size of each block on the process
202   @param[out]  values       The block inverses in column major order
203 
204   @return An error code: 0 - success, otherwise - failure
205 **/
206 static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
207   Mat            mat_inner = NULL;
208   MatCeedContext ctx;
209 
210   PetscFunctionBeginUser;
211   PetscCall(MatShellGetContext(mat_ceed, &ctx));
212 
213   // Assemble inner mat if needed
214   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
215 
216   // Invert PB diagonal
217   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
218   PetscFunctionReturn(PETSC_SUCCESS);
219 }
220 
221 /**
222   @brief View `MATCEED`.
223 
224   Collective across MPI processes.
225 
226   @param[in]   mat_ceed  `MATCEED` to view
227   @param[in]   viewer    The visualization context
228 
229   @return An error code: 0 - success, otherwise - failure
230 **/
231 static PetscErrorCode MatView_Ceed(Mat mat_ceed, PetscViewer viewer) {
232   PetscBool         is_ascii;
233   PetscViewerFormat format;
234   PetscMPIInt       size;
235   MatCeedContext    ctx;
236 
237   PetscFunctionBeginUser;
238   PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 2);
239   PetscCall(MatShellGetContext(mat_ceed, &ctx));
240   if (!viewer) PetscCall(PetscViewerASCIIGetStdout(PetscObjectComm((PetscObject)mat_ceed), &viewer));
241 
242   PetscCall(PetscViewerGetFormat(viewer, &format));
243   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat_ceed), &size));
244   if (size == 1 && format == PETSC_VIEWER_LOAD_BALANCE) PetscFunctionReturn(PETSC_SUCCESS);
245 
246   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &is_ascii));
247   {
248     FILE *file;
249 
250     PetscCall(PetscViewerASCIIPrintf(viewer, "MatCEED:\n  Default COO MatType:%s\n", ctx->coo_mat_type));
251     PetscCall(PetscViewerASCIIGetPointer(viewer, &file));
252     PetscCall(PetscViewerASCIIPrintf(viewer, " libCEED Operator:\n"));
253     PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult, file));
254     if (ctx->op_mult_transpose) {
255       PetscCall(PetscViewerASCIIPrintf(viewer, "  libCEED Transpose Operator:\n"));
256       PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult_transpose, file));
257     }
258   }
259   PetscFunctionReturn(PETSC_SUCCESS);
260 }
261 
262 // -----------------------------------------------------------------------------
263 // MatCeed
264 // -----------------------------------------------------------------------------
265 
/**
  @brief Create PETSc `Mat` from libCEED operators.

  Collective across MPI processes.

  @param[in]   dm_x               Input `DM`
  @param[in]   dm_y               Output `DM`, or NULL to reuse `dm_x` (square operator)
  @param[in]   op_mult            `CeedOperator` for forward evaluation
  @param[in]   op_mult_transpose  `CeedOperator` for transpose evaluation, or NULL
  @param[out]  mat                New MatCeed

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
  PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
  VecType        vec_type;
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatCeedRegisterLogEvents());

  // Collect context data: global/local sizes from temporary global vectors
  PetscCall(DMGetVecType(dm_x, &vec_type));
  {
    Vec X;

    PetscCall(DMGetGlobalVector(dm_x, &X));
    PetscCall(VecGetSize(X, &X_g_size));
    PetscCall(VecGetLocalSize(X, &X_l_size));
    PetscCall(DMRestoreGlobalVector(dm_x, &X));
  }
  if (dm_y) {
    Vec Y;

    PetscCall(DMGetGlobalVector(dm_y, &Y));
    PetscCall(VecGetSize(Y, &Y_g_size));
    PetscCall(VecGetLocalSize(Y, &Y_l_size));
    PetscCall(DMRestoreGlobalVector(dm_y, &Y));
  } else {
    // Square operator: output DM and sizes mirror the input
    dm_y     = dm_x;
    Y_g_size = X_g_size;
    Y_l_size = X_l_size;
  }

  // Create context; the context keeps its own references to the local work vectors,
  // so our references are dropped immediately after creation
  {
    Vec X_loc, Y_loc_transpose = NULL;

    PetscCall(DMCreateLocalVector(dm_x, &X_loc));
    PetscCall(VecZeroEntries(X_loc));
    if (op_mult_transpose) {
      PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
      PetscCall(VecZeroEntries(Y_loc_transpose));
    }
    PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
    PetscCall(VecDestroy(&X_loc));
    PetscCall(VecDestroy(&Y_loc_transpose));
  }

  // Create mat
  PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
  PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
  // -- Set block and variable block sizes
  if (dm_x == dm_y) {
    MatType dm_mat_type, dm_mat_type_copy;
    Mat     temp_mat;

    // Temporarily switch the DM MatType to MATAIJ to query block size info, then restore it
    PetscCall(DMGetMatType(dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(dm_x, MATAIJ));
    PetscCall(DMCreateMatrix(dm_x, &temp_mat));
    PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));

    {
      PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
      const PetscInt *vblock_sizes;

      // -- Get block sizes
      PetscCall(MatGetBlockSize(temp_mat, &block_size));
      PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
      {
        // Global max of the largest local variable block size
        PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};

        for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
        PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
        max_vblock_size = global_min_max[1];
      }

      // -- Copy block sizes
      if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
      if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));

      // -- Check libCEED compatibility
      // PBD assembly is marked valid only when every (sub-)operator has a single active
      // basis whose component count equals the Mat block size; variable PBD additionally
      // requires num_comp >= the largest variable block size
      {
        bool is_composite;

        ctx->is_ceed_pbd_valid  = PETSC_TRUE;
        ctx->is_ceed_vpbd_valid = PETSC_TRUE;
        PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
        if (is_composite) {
          CeedInt       num_sub_operators;
          CeedOperator *sub_operators;

          PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
          PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
          for (CeedInt i = 0; i < num_sub_operators; i++) {
            CeedInt                  num_bases, num_comp;
            CeedBasis               *active_bases;
            CeedOperatorAssemblyData assembly_data;

            PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
            PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
            PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
            if (num_bases > 1) {
              ctx->is_ceed_pbd_valid  = PETSC_FALSE;
              ctx->is_ceed_vpbd_valid = PETSC_FALSE;
            }
            if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
            if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
        } else {
          // LCOV_EXCL_START
          CeedInt                  num_bases, num_comp;
          CeedBasis               *active_bases;
          CeedOperatorAssemblyData assembly_data;

          PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
          PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
          PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
          if (num_bases > 1) {
            ctx->is_ceed_pbd_valid  = PETSC_FALSE;
            ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
          if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
          if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          // LCOV_EXCL_STOP
        }
        {
          // Reduce the validity flags across ranks; taking the global min means
          // every rank must agree before the libCEED path is used
          PetscInt local_is_valid[2], global_is_valid[2];

          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_pbd_valid = global_is_valid[0];
          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_vpbd_valid = global_is_valid[0];
        }
      }
    }
    PetscCall(MatDestroy(&temp_mat));
  }
  // -- Set internal mat type
  {
    // NOTE(review): this inner vec_type shadows the outer vec_type passed to MatShellSetVecType below
    VecType vec_type;
    MatType coo_mat_type;

    // Choose a device-capable COO MatType matching the backing Vec type
    PetscCall(VecGetType(ctx->X_loc, &vec_type));
    if (strstr(vec_type, VECCUDA)) coo_mat_type = MATAIJCUSPARSE;
    else if (strstr(vec_type, VECKOKKOS)) coo_mat_type = MATAIJKOKKOS;
    else coo_mat_type = MATAIJ;
    PetscCall(PetscStrallocpy(coo_mat_type, &ctx->coo_mat_type));
  }
  // -- Set mat operations
  PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
  PetscCall(MatShellSetOperation(*mat, MATOP_VIEW, (void (*)(void))MatView_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
  if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
  PetscCall(MatShellSetVecType(*mat, vec_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
441 
442 /**
443   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
444 
445   Collective across MPI processes.
446 
447   @param[in]   mat_ceed   `MATCEED` to copy from
448   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
449 
450   @return An error code: 0 - success, otherwise - failure
451 **/
452 PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
453   PetscFunctionBeginUser;
454   PetscCall(MatCeedRegisterLogEvents());
455 
456   // Check type compatibility
457   {
458     PetscBool is_matceed = PETSC_FALSE, is_matshell = PETSC_FALSE;
459     MatType   mat_type_ceed, mat_type_other;
460 
461     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
462     PetscCall(PetscStrcmp(mat_type_ceed, MATCEED, &is_matceed));
463     PetscCheck(is_matceed, PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
464     PetscCall(MatGetType(mat_other, &mat_type_other));
465     PetscCall(PetscStrcmp(mat_type_other, MATCEED, &is_matceed));
466     PetscCall(PetscStrcmp(mat_type_other, MATSHELL, &is_matceed));
467     PetscCheck(is_matceed || is_matshell, PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_other must have type " MATCEED " or " MATSHELL);
468   }
469 
470   // Check dimension compatibility
471   {
472     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
473 
474     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
475     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
476     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
477     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
478     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
479                    (X_l_ceed_size == X_l_other_size),
480                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
481                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
482                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
483                ", %" PetscInt_FMT ")",
484                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
485   }
486 
487   // Convert
488   {
489     VecType        vec_type;
490     MatCeedContext ctx;
491 
492     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
493     PetscCall(MatShellGetContext(mat_ceed, &ctx));
494     PetscCall(MatCeedContextReference(ctx));
495     PetscCall(MatShellSetContext(mat_other, ctx));
496     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
497     PetscCall(MatShellSetOperation(mat_other, MATOP_VIEW, (void (*)(void))MatView_Ceed));
498     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
499     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
500     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
501     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
502     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
503     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
504     {
505       PetscInt block_size;
506 
507       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
508       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
509     }
510     {
511       PetscInt        num_blocks;
512       const PetscInt *block_sizes;
513 
514       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
515       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
516     }
517     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
518     PetscCall(MatShellSetVecType(mat_other, vec_type));
519   }
520   PetscFunctionReturn(PETSC_SUCCESS);
521 }
522 
/**
  @brief Setup a `Mat` with the same COO pattern as a `MATCEED`.

  Collective across MPI processes.

  @param[in]   mat_ceed  `MATCEED`
  @param[out]  mat_coo   Sparse `Mat` with same COO pattern

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedCreateMatCOO(Mat mat_ceed, Mat *mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "COO assembly only supported for MATCEED on a single DM");

  // Check cl mat type
  {
    PetscBool is_coo_mat_type_cl = PETSC_FALSE;
    char      coo_mat_type_cl[64];

    // Check for specific CL coo mat type for this Mat; when -ceed_coo_mat_type
    // (with this Mat's options prefix) is given, it replaces the context default
    // COO MatType for this and all subsequent COO matrices
    {
      const char *mat_ceed_prefix = NULL;

      PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
      PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
      PetscCall(PetscOptionsFList("-ceed_coo_mat_type", "Default MATCEED COO assembly MatType", NULL, MatList, coo_mat_type_cl, coo_mat_type_cl,
                                  sizeof(coo_mat_type_cl), &is_coo_mat_type_cl));
      PetscOptionsEnd();
      if (is_coo_mat_type_cl) {
        PetscCall(PetscFree(ctx->coo_mat_type));
        PetscCall(PetscStrallocpy(coo_mat_type_cl, &ctx->coo_mat_type));
      }
    }
  }

  // Create sparse matrix
  {
    MatType dm_mat_type, dm_mat_type_copy;

    // Temporarily set the DM MatType to the COO type, create the Mat, then restore the original
    PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(ctx->dm_x, ctx->coo_mat_type));
    PetscCall(DMCreateMatrix(ctx->dm_x, mat_coo));
    PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
575 
576 /**
577   @brief Setup the COO preallocation `MATCEED` into a `MATAIJ` or similar.
578          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
579 
580   Collective across MPI processes.
581 
582   @param[in]      mat_ceed  `MATCEED` to assemble
583   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
584 
585   @return An error code: 0 - success, otherwise - failure
586 **/
587 PetscErrorCode MatCeedSetPreallocationCOO(Mat mat_ceed, Mat mat_coo) {
588   MatCeedContext ctx;
589 
590   PetscFunctionBeginUser;
591   PetscCall(MatShellGetContext(mat_ceed, &ctx));
592 
593   {
594     PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
595     CeedInt      *rows_ceed, *cols_ceed;
596     PetscCount    num_entries;
597     PetscLogStage stage_amg_setup;
598 
599     // -- Assemble sparsity pattern if mat hasn't been assembled before
600     PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
601     if (stage_amg_setup == -1) {
602       PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
603     }
604     PetscCall(PetscLogStagePush(stage_amg_setup));
605     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
606     PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
607     PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
608     PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
609     free(rows_petsc);
610     free(cols_petsc);
611     if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
612     PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
613     ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
614     PetscCall(PetscLogStagePop());
615   }
616   PetscFunctionReturn(PETSC_SUCCESS);
617 }
618 
619 /**
620   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
621          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
622          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
623 
624   Collective across MPI processes.
625 
626   @param[in]      mat_ceed  `MATCEED` to assemble
627   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
628 
629   @return An error code: 0 - success, otherwise - failure
630 **/
631 PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
632   MatCeedContext ctx;
633 
634   PetscFunctionBeginUser;
635   PetscCall(MatShellGetContext(mat_ceed, &ctx));
636 
637   // Set COO pattern if needed
638   {
639     CeedInt index = -1;
640 
641     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
642       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
643     }
644     if (index == -1) PetscCall(MatCeedSetPreallocationCOO(mat_ceed, mat_coo));
645   }
646 
647   // Assemble mat_ceed
648   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
649   {
650     const CeedScalar *values;
651     MatType           mat_type;
652     CeedMemType       mem_type = CEED_MEM_HOST;
653     PetscBool         is_spd, is_spd_known;
654 
655     PetscCall(MatGetType(mat_coo, &mat_type));
656     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
657     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
658     else mem_type = CEED_MEM_HOST;
659 
660     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
661     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
662     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
663     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
664     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
665     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
666   }
667   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
668   PetscFunctionReturn(PETSC_SUCCESS);
669 }
670 
671 /**
672   @brief Set the current value of a context field for a `MatCEED`.
673 
674   Not collective across MPI processes.
675 
676   @param[in,out]  mat    `MatCEED`
677   @param[in]      name   Name of the context field
678   @param[in]      value  New context field value
679 
680   @return An error code: 0 - success, otherwise - failure
681 **/
682 PetscErrorCode MatCeedSetContextDouble(Mat mat, const char *name, double value) {
683   PetscBool      was_updated = PETSC_FALSE;
684   MatCeedContext ctx;
685 
686   PetscFunctionBeginUser;
687   PetscCall(MatShellGetContext(mat, &ctx));
688   {
689     CeedContextFieldLabel label = NULL;
690 
691     PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(ctx->op_mult, name, &label));
692     if (label) {
693       PetscCallCeed(ctx->ceed, CeedOperatorSetContextDouble(ctx->op_mult, label, &value));
694       was_updated = PETSC_TRUE;
695     }
696     if (ctx->op_mult_transpose) {
697       label = NULL;
698       PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(ctx->op_mult_transpose, name, &label));
699       if (label) {
700         PetscCallCeed(ctx->ceed, CeedOperatorSetContextDouble(ctx->op_mult_transpose, label, &value));
701         was_updated = PETSC_TRUE;
702       }
703     }
704   }
705   if (was_updated) {
706     PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
707     PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
708   }
709   PetscFunctionReturn(PETSC_SUCCESS);
710 }
711 
712 /**
713   @brief Get the current value of a context field for a `MatCEED`.
714 
715   Not collective across MPI processes.
716 
717   @param[in]   mat    `MatCEED`
718   @param[in]   name   Name of the context field
719   @param[out]  value  Current context field value
720 
721   @return An error code: 0 - success, otherwise - failure
722 **/
PetscErrorCode MatCeedGetContextDouble(Mat mat, const char *name, double *value) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  {
    CeedContextFieldLabel label = NULL;
    CeedOperator          op    = ctx->op_mult;

    // Look up the field on the forward operator first, falling back to the transpose operator
    PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(op, name, &label));
    if (!label && ctx->op_mult_transpose) {
      op = ctx->op_mult_transpose;
      PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(op, name, &label));
    }
    if (label) {
      PetscSizeT    num_values;
      const double *values_ceed;

      // Only the first entry of a (possibly multi-valued) field is returned
      PetscCallCeed(ctx->ceed, CeedOperatorGetContextDoubleRead(op, label, &num_values, &values_ceed));
      *value = values_ceed[0];
      PetscCallCeed(ctx->ceed, CeedOperatorRestoreContextDoubleRead(op, label, &values_ceed));
    }
    // NOTE(review): if neither operator has a field with this name, *value is left unmodified;
    // callers should initialize it before calling
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
748 
749 /**
750   @brief Set user context for a `MATCEED`.
751 
752   Collective across MPI processes.
753 
754   @param[in,out]  mat  `MATCEED`
755   @param[in]      f    The context destroy function, or NULL
756   @param[in]      ctx  User context, or NULL to unset
757 
758   @return An error code: 0 - success, otherwise - failure
759 **/
PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
  PetscContainer user_ctx = NULL;

  PetscFunctionBeginUser;
  if (ctx) {
    // Wrap the user pointer in a PetscContainer so its lifetime is tied to the Mat
    PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
    PetscCall(PetscContainerSetPointer(user_ctx, ctx));
    PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
  }
  // Composing a NULL container removes any previously set context
  PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
  // Drop the local reference; the Mat now holds the container (if any)
  PetscCall(PetscContainerDestroy(&user_ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
773 
774 /**
775   @brief Retrieve the user context for a `MATCEED`.
776 
777   Collective across MPI processes.
778 
779   @param[in,out]  mat  `MATCEED`
  @param[out]     ctx  User context
781 
782   @return An error code: 0 - success, otherwise - failure
783 **/
784 PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
785   PetscContainer user_ctx;
786 
787   PetscFunctionBeginUser;
788   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
789   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
790   else *(void **)ctx = NULL;
791   PetscFunctionReturn(PETSC_SUCCESS);
792 }
793 /**
794   @brief Set a user defined matrix operation for a `MATCEED` matrix.
795 
796   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
797 `MatCeedSetContext()`.
798 
799   Collective across MPI processes.
800 
801   @param[in,out]  mat  `MATCEED`
802   @param[in]      op   Name of the `MatOperation`
803   @param[in]      g    Function that provides the operation
804 
805   @return An error code: 0 - success, otherwise - failure
806 **/
PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
  PetscFunctionBeginUser;
  // MATCEED is implemented as a MATSHELL, so delegate directly
  PetscCall(MatShellSetOperation(mat, op, g));
  PetscFunctionReturn(PETSC_SUCCESS);
}
812 
813 /**
814   @brief Sets the default COO matrix type as a string from the `MATCEED`.
815 
816   Collective across MPI processes.
817 
818   @param[in,out]  mat   `MATCEED`
819   @param[in]      type  COO `MatType` to set
820 
821   @return An error code: 0 - success, otherwise - failure
822 **/
823 PetscErrorCode MatCeedSetCOOMatType(Mat mat, MatType type) {
824   MatCeedContext ctx;
825 
826   PetscFunctionBeginUser;
827   PetscCall(MatShellGetContext(mat, &ctx));
828   // Check if same
829   {
830     size_t    len_old, len_new;
831     PetscBool is_same = PETSC_FALSE;
832 
833     PetscCall(PetscStrlen(ctx->coo_mat_type, &len_old));
834     PetscCall(PetscStrlen(type, &len_new));
835     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->coo_mat_type, type, len_old, &is_same));
836     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
837   }
838   // Clean up old mats in different format
839   // LCOV_EXCL_START
840   if (ctx->mat_assembled_full_internal) {
841     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
842       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
843         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
844           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
845         }
846         ctx->num_mats_assembled_full--;
847         // Note: we'll realloc this array again, so no need to shrink the allocation
848         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
849       }
850     }
851   }
852   if (ctx->mat_assembled_pbd_internal) {
853     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
854       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
855         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
856           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
857         }
858         // Note: we'll realloc this array again, so no need to shrink the allocation
859         ctx->num_mats_assembled_pbd--;
860         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
861       }
862     }
863   }
864   PetscCall(PetscFree(ctx->coo_mat_type));
865   PetscCall(PetscStrallocpy(type, &ctx->coo_mat_type));
866   PetscFunctionReturn(PETSC_SUCCESS);
867   // LCOV_EXCL_STOP
868 }
869 
870 /**
871   @brief Gets the default COO matrix type as a string from the `MATCEED`.
872 
873   Collective across MPI processes.
874 
875   @param[in,out]  mat   `MATCEED`
  @param[out]     type  COO `MatType`
877 
878   @return An error code: 0 - success, otherwise - failure
879 **/
PetscErrorCode MatCeedGetCOOMatType(Mat mat, MatType *type) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  // Returned string is owned by the MatCeed context; the caller must not free it
  *type = ctx->coo_mat_type;
  PetscFunctionReturn(PETSC_SUCCESS);
}
888 
889 /**
890   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
891 
892   Not collective across MPI processes.
893 
894   @param[in,out]  mat              `MATCEED`
895   @param[in]      X_loc            Input PETSc local vector, or NULL
896   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
897 
898   @return An error code: 0 - success, otherwise - failure
899 **/
900 PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
901   MatCeedContext ctx;
902 
903   PetscFunctionBeginUser;
904   PetscCall(MatShellGetContext(mat, &ctx));
905   if (X_loc) {
906     PetscInt len_old, len_new;
907 
908     PetscCall(VecGetSize(ctx->X_loc, &len_old));
909     PetscCall(VecGetSize(X_loc, &len_new));
910     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
911                len_new, len_old);
912     PetscCall(VecReferenceCopy(X_loc, &ctx->X_loc));
913   }
914   if (Y_loc_transpose) {
915     PetscInt len_old, len_new;
916 
917     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
918     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
919     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
920                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
921     PetscCall(VecReferenceCopy(Y_loc_transpose, &ctx->Y_loc_transpose));
922   }
923   PetscFunctionReturn(PETSC_SUCCESS);
924 }
925 
926 /**
927   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
928 
929   Not collective across MPI processes.
930 
931   @param[in,out]  mat              `MATCEED`
932   @param[out]     X_loc            Input PETSc local vector, or NULL
933   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
934 
935   @return An error code: 0 - success, otherwise - failure
936 **/
937 PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
938   MatCeedContext ctx;
939 
940   PetscFunctionBeginUser;
941   PetscCall(MatShellGetContext(mat, &ctx));
942   if (X_loc) {
943     *X_loc = NULL;
944     PetscCall(VecReferenceCopy(ctx->X_loc, X_loc));
945   }
946   if (Y_loc_transpose) {
947     *Y_loc_transpose = NULL;
948     PetscCall(VecReferenceCopy(ctx->Y_loc_transpose, Y_loc_transpose));
949   }
950   PetscFunctionReturn(PETSC_SUCCESS);
951 }
952 
953 /**
954   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
955 
956   Not collective across MPI processes.
957 
958   @param[in,out]  mat              MatCeed
959   @param[out]     X_loc            Input PETSc local vector, or NULL
960   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
961 
962   @return An error code: 0 - success, otherwise - failure
963 **/
964 PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
965   PetscFunctionBeginUser;
966   if (X_loc) PetscCall(VecDestroy(X_loc));
967   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
968   PetscFunctionReturn(PETSC_SUCCESS);
969 }
970 
971 /**
972   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
973 
974   Not collective across MPI processes.
975 
976   @param[in,out]  mat                MatCeed
977   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
978   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
979 
980   @return An error code: 0 - success, otherwise - failure
981 **/
982 PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
983   MatCeedContext ctx;
984 
985   PetscFunctionBeginUser;
986   PetscCall(MatShellGetContext(mat, &ctx));
987   if (op_mult) {
988     *op_mult = NULL;
989     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
990   }
991   if (op_mult_transpose) {
992     *op_mult_transpose = NULL;
993     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
994   }
995   PetscFunctionReturn(PETSC_SUCCESS);
996 }
997 
998 /**
999   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
1000 
1001   Not collective across MPI processes.
1002 
1003   @param[in,out]  mat                MatCeed
1004   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
1005   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
1006 
1007   @return An error code: 0 - success, otherwise - failure
1008 **/
PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  // Drop the references taken in MatCeedGetCeedOperators(); the pointers are set to NULL
  if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult));
  if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1018 
1019 /**
1020   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1021 
1022   Not collective across MPI processes.
1023 
1024   @param[in,out]  mat                       MatCeed
  @param[in]      log_event_mult            `PetscLogEvent` for forward evaluation, or 0 to keep the current event
  @param[in]      log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or 0 to keep the current event
1027 
1028   @return An error code: 0 - success, otherwise - failure
1029 **/
PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  // A zero (falsy) event leaves the corresponding current event unchanged
  if (log_event_mult) ctx->log_event_mult = log_event_mult;
  if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1039 
1040 /**
1041   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1042 
1043   Not collective across MPI processes.
1044 
1045   @param[in,out]  mat                       MatCeed
1046   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1047   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1048 
1049   @return An error code: 0 - success, otherwise - failure
1050 **/
PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  // Each output is optional; only fill the pointers the caller provided
  if (log_event_mult) *log_event_mult = ctx->log_event_mult;
  if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1060 
1061 // -----------------------------------------------------------------------------
1062 // Operator context data
1063 // -----------------------------------------------------------------------------
1064 
1065 /**
1066   @brief Setup context data for operator application.
1067 
1068   Collective across MPI processes.
1069 
1070   @param[in]   dm_x                      Input `DM`
1071   @param[in]   dm_y                      Output `DM`
1072   @param[in]   X_loc                     Input PETSc local vector, or NULL
1073   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
1074   @param[in]   op_mult                   `CeedOperator` for forward evaluation
1075   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
1076   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
1077   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
1078   @param[out]  ctx                       Context data for operator evaluation
1079 
1080   @return An error code: 0 - success, otherwise - failure
1081 **/
PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
                                    PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
  CeedSize x_loc_len, y_loc_len;

  PetscFunctionBeginUser;

  // Allocate
  PetscCall(PetscNew(ctx));
  (*ctx)->ref_count = 1;  // caller owns the first reference; released via MatCeedContextDestroy()

  // Logging
  (*ctx)->log_event_mult           = log_event_mult;
  (*ctx)->log_event_mult_transpose = log_event_mult_transpose;

  // PETSc objects (referenced, not copied; local vectors are optional)
  PetscCall(DMReferenceCopy(dm_x, &(*ctx)->dm_x));
  PetscCall(DMReferenceCopy(dm_y, &(*ctx)->dm_y));
  if (X_loc) PetscCall(VecReferenceCopy(X_loc, &(*ctx)->X_loc));
  if (Y_loc_transpose) PetscCall(VecReferenceCopy(Y_loc_transpose, &(*ctx)->Y_loc_transpose));

  // Memtype - probe a DM work vector to learn where local data lives (host or device)
  {
    const PetscScalar *x;
    Vec                X;

    PetscCall(DMGetLocalVector(dm_x, &X));
    PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
    PetscCall(VecRestoreArrayReadAndMemType(X, &x));
    PetscCall(DMRestoreLocalVector(dm_x, &X));
  }

  // libCEED objects - take a reference on the Ceed (released in MatCeedContextDestroy())
  PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
             "retrieving Ceed context object failed");
  PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed));
  PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
  PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
  if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
  // Work vectors sized to the operator's active input/output lengths
  PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
  PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));

  // Flop counting - cached estimates are logged on every MatMult/MatMultTranspose
  {
    CeedSize ceed_flops_estimate = 0;

    PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
    (*ctx)->flops_mult = ceed_flops_estimate;
    if (op_mult_transpose) {
      PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
      (*ctx)->flops_mult_transpose = ceed_flops_estimate;
    }
  }

  // Check sizes - verify operator, DM, and user-supplied local vector dimensions all agree
  if (x_loc_len > 0 || y_loc_len > 0) {
    CeedSize ctx_x_loc_len, ctx_y_loc_len;
    PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
    Vec      dm_X_loc, dm_Y_loc;

    // -- Input
    PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
    PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
    PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
    PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
    if (X_loc) {
      PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
      PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
                 "X_loc (%" PetscInt_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions", X_loc_len, dm_x_loc_len);
    }
    PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions",
               x_loc_len, dm_x_loc_len);
    PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc (%" CeedSize_FMT ") must match op dimensions (%" CeedSize_FMT ")",
               x_loc_len, ctx_x_loc_len);

    // -- Output
    PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
    PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
    PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
    PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
    PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions",
               ctx_y_loc_len, dm_y_loc_len);

    // -- Transpose
    if (Y_loc_transpose) {
      PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
      PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
                 "Y_loc_transpose (%" PetscInt_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions", Y_loc_len, dm_y_loc_len);
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1173 
1174 /**
1175   @brief Increment reference counter for `MATCEED` context.
1176 
1177   Not collective across MPI processes.
1178 
1179   @param[in,out]  ctx  Context data
1180 
1181   @return An error code: 0 - success, otherwise - failure
1182 **/
PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
  PetscFunctionBeginUser;
  // Manual reference counting; pair each reference with a MatCeedContextDestroy()
  ctx->ref_count++;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1188 
1189 /**
1190   @brief Copy reference for `MATCEED`.
1191          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1192 
1193   Not collective across MPI processes.
1194 
1195   @param[in]   ctx       Context data
1196   @param[out]  ctx_copy  Copy of pointer to context data
1197 
1198   @return An error code: 0 - success, otherwise - failure
1199 **/
PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
  PetscFunctionBeginUser;
  // Reference the source before releasing the destination so self-copy is safe
  PetscCall(MatCeedContextReference(ctx));
  PetscCall(MatCeedContextDestroy(*ctx_copy));
  *ctx_copy = ctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1207 
1208 /**
1209   @brief Destroy context data for operator application.
1210 
1211   Collective across MPI processes.
1212 
1213   @param[in,out]  ctx  Context data for operator evaluation
1214 
1215   @return An error code: 0 - success, otherwise - failure
1216 **/
PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
  PetscFunctionBeginUser;
  // Decrement the reference count; only the final reference tears the context down
  if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);

  // PETSc objects
  PetscCall(DMDestroy(&ctx->dm_x));
  PetscCall(DMDestroy(&ctx->dm_y));
  PetscCall(VecDestroy(&ctx->X_loc));
  PetscCall(VecDestroy(&ctx->Y_loc_transpose));
  PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
  PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
  PetscCall(PetscFree(ctx->coo_mat_type));
  PetscCall(PetscFree(ctx->mats_assembled_full));
  PetscCall(PetscFree(ctx->mats_assembled_pbd));

  // libCEED objects - the Ceed itself is released last since PetscCallCeed needs it
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
  PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");

  // Deallocate
  // NOTE(review): the flag is written to memory freed immediately below, so a stale
  // reference reading it is still use-after-free; treat this as best-effort debugging aid only
  ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1246 
1247 /**
1248   @brief Compute the diagonal of an operator via libCEED.
1249 
1250   Collective across MPI processes.
1251 
1252   @param[in]   A  `MATCEED`
1253   @param[out]  D  Vector holding operator diagonal
1254 
1255   @return An error code: 0 - success, otherwise - failure
1256 **/
PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
  PetscMemType   mem_type;
  Vec            D_loc;
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));

  // Place PETSc vector in libCEED vector
  // NOTE(review): the diagonal is gathered via dm_x and the input work vector x_loc,
  // which presumes a square operator (dm_x and dm_y compatible) -- confirm for rectangular use
  PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
  PetscCall(VecPetscToCeed(D_loc, &mem_type, ctx->x_loc));

  // Compute Diagonal
  PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));

  // Restore PETSc vector
  PetscCall(VecCeedToPetsc(ctx->x_loc, mem_type, D_loc));

  // Local-to-Global - zero first since shared dofs are summed with ADD_VALUES
  PetscCall(VecZeroEntries(D));
  PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
  PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1281 
1282 /**
1283   @brief Compute `A X = Y` for a `MATCEED`.
1284 
1285   Collective across MPI processes.
1286 
1287   @param[in]   A  `MATCEED`
1288   @param[in]   X  Input PETSc vector
1289   @param[out]  Y  Output PETSc vector
1290 
1291   @return An error code: 0 - success, otherwise - failure
1292 **/
PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));
  PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));

  {
    PetscMemType x_mem_type, y_mem_type;
    Vec          X_loc = ctx->X_loc, Y_loc;

    // Get local vectors - reuse the user-provided X_loc (MatCeedSetLocalVectors) if set,
    // otherwise borrow a DM work vector
    if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
    PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));

    // Global-to-local
    PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));

    // Setup libCEED vectors - wrap the PETSc arrays without copying
    PetscCall(VecReadPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));
    PetscCall(VecZeroEntries(Y_loc));
    PetscCall(VecPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));

    // Apply libCEED operator - ApplyAdd accumulates into y_loc, which was zeroed above
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
    PetscCall(PetscLogGpuTimeEnd());

    // Restore PETSc vectors
    PetscCall(VecReadCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));
    PetscCall(VecCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));

    // Local-to-global - zero first since shared dofs are summed with ADD_VALUES
    PetscCall(VecZeroEntries(Y));
    PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));

    // Restore local vectors, as needed
    if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
    PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
  }

  // Log flops - on the device or host depending on where the data lives
  if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
  else PetscCall(PetscLogFlops(ctx->flops_mult));

  PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1341 
1342 /**
1343   @brief Compute `A^T Y = X` for a `MATCEED`.
1344 
1345   Collective across MPI processes.
1346 
1347   @param[in]   A  `MATCEED`
1348   @param[in]   Y  Input PETSc vector
1349   @param[out]  X  Output PETSc vector
1350 
1351   @return An error code: 0 - success, otherwise - failure
1352 **/
PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));
  PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));

  {
    PetscMemType x_mem_type, y_mem_type;
    Vec          X_loc, Y_loc = ctx->Y_loc_transpose;

    // Get local vectors - reuse the user-provided Y_loc_transpose (MatCeedSetLocalVectors)
    // if set, otherwise borrow a DM work vector
    if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
    PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));

    // Global-to-local
    PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));

    // Setup libCEED vectors - wrap the PETSc arrays without copying
    PetscCall(VecReadPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));
    PetscCall(VecZeroEntries(X_loc));
    PetscCall(VecPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));

    // Apply libCEED operator - ApplyAdd accumulates into x_loc, which was zeroed above
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
    PetscCall(PetscLogGpuTimeEnd());

    // Restore PETSc vectors
    PetscCall(VecReadCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));
    PetscCall(VecCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));

    // Local-to-global - zero first since shared dofs are summed with ADD_VALUES
    PetscCall(VecZeroEntries(X));
    PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));

    // Restore local vectors, as needed
    if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
    PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
  }

  // Log flops - on the device or host depending on where the data lives
  if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
  else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));

  PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1401