xref: /libCEED/examples/fluids/src/mat-ceed.c (revision 5dd8ea6c4fae62f97ab98a3a828ca9f3db1c59f1)
1 /// @file
2 /// MatCeed and its related operators
3 
4 #include <ceed-utils.h>
5 #include <ceed.h>
6 #include <ceed/backend.h>
7 #include <mat-ceed-impl.h>
8 #include <mat-ceed.h>
9 #include <petscdmplex.h>
10 #include <stdlib.h>
11 #include <string.h>
12 
13 PetscClassId  MATCEED_CLASSID;
14 PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
15 
16 /**
17   @brief Register MATCEED log events.
18 
19   Not collective across MPI processes.
20 
21   @return An error code: 0 - success, otherwise - failure
22 **/
23 static PetscErrorCode MatCeedRegisterLogEvents() {
24   static bool registered = false;
25 
26   PetscFunctionBeginUser;
27   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
28   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
29   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
30   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
31   registered = true;
32   PetscFunctionReturn(PETSC_SUCCESS);
33 }
34 
/**
  @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.

  Collective across MPI processes.

  @param[in]   mat_ceed   `MATCEED` to setup
  @param[out]  mat_inner  Inner `Mat`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // The inner assembled matrix is created from a single DM, so input and output DMs must match
  PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");

  // Check cl mat type
  {
    PetscBool is_internal_mat_type_cl = PETSC_FALSE;
    char      internal_mat_type_cl[64];

    // Check for specific CL inner mat type for this Mat, scoped by this Mat's options prefix
    {
      const char *mat_ceed_prefix = NULL;

      PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
      PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
      PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
                                  internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
      PetscOptionsEnd();
      // A command line option overrides the stored internal mat type
      if (is_internal_mat_type_cl) {
        PetscCall(PetscFree(ctx->internal_mat_type));
        PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
      }
    }
  }

  // Create sparse matrix
  {
    MatType dm_mat_type, dm_mat_type_copy;

    // Temporarily switch the DM's MatType so DMCreateMatrix produces the internal type, then restore it.
    // The current type string is copied first since DMSetMatType may invalidate the string DMGetMatType returned.
    PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
    PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
    PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
87 
/**
  @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
         The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
         The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.

  Collective across MPI processes.

  @param[in]      mat_ceed  `MATCEED` to assemble
  @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // Check if COO pattern set
  {
    PetscInt index = -1;

    // Linear scan over previously assembled mats; the list is expected to stay small
    for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
      if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
    }
    if (index == -1) {
      PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
      CeedInt      *rows_ceed, *cols_ceed;
      PetscCount    num_entries;
      PetscLogStage stage_amg_setup;

      // -- Assemble sparsity pattern if mat hasn't been assembled before
      PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
      if (stage_amg_setup == -1) {
        PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
      }
      PetscCall(PetscLogStagePush(stage_amg_setup));
      // Symbolic assembly yields the (row, col) index pairs of the point block diagonal
      PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
      // Convert CeedInt index arrays to PetscInt; takes ownership of the Ceed arrays
      PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
      PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
      PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
      // NOTE(review): libc free() here implies IntArrayCeedToPetsc allocates with malloc — confirm against its definition
      free(rows_petsc);
      free(cols_petsc);
      // Lazily create the reusable CeedVector that receives assembled values
      if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
      // Grow the bookkeeping list and remember this mat so preallocation is skipped next time
      PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
      ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
      PetscCall(PetscLogStagePop());
    }
  }

  // Assemble mat_ceed
  PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
  {
    const CeedScalar *values;
    MatType           mat_type;
    CeedMemType       mem_type = CEED_MEM_HOST;
    PetscBool         is_spd, is_spd_known;

    // Pick device memory when the target mat type indicates GPU storage
    PetscCall(MatGetType(mat_coo, &mat_type));
    if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
    else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
    else mem_type = CEED_MEM_HOST;

    // Numeric assembly into the cached CeedVector, then transfer values into the COO mat
    PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
    PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
    PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
    // Propagate the SPD flag when known on the source mat
    PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
    if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
    PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
  }
  PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
161 
162 /**
163   @brief Assemble inner `Mat` for diagonal `PC` operations
164 
165   Collective across MPI processes.
166 
167   @param[in]   mat_ceed      `MATCEED` to invert
168   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
169   @param[out]  mat_inner     Inner `Mat` for diagonal operations
170 
171   @return An error code: 0 - success, otherwise - failure
172 **/
173 static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
174   MatCeedContext ctx;
175 
176   PetscFunctionBeginUser;
177   PetscCall(MatShellGetContext(mat_ceed, &ctx));
178   if (use_ceed_pbd) {
179     // Check if COO pattern set
180     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
181 
182     // Assemble mat_assembled_full_internal
183     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
184     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
185   } else {
186     // Check if COO pattern set
187     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
188 
189     // Assemble mat_assembled_full_internal
190     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
191     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
192   }
193   PetscFunctionReturn(PETSC_SUCCESS);
194 }
195 
196 /**
197   @brief Get `MATCEED` diagonal block for Jacobi.
198 
199   Collective across MPI processes.
200 
201   @param[in]   mat_ceed   `MATCEED` to invert
202   @param[out]  mat_block  The diagonal block matrix
203 
204   @return An error code: 0 - success, otherwise - failure
205 **/
206 static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
207   Mat            mat_inner = NULL;
208   MatCeedContext ctx;
209 
210   PetscFunctionBeginUser;
211   PetscCall(MatShellGetContext(mat_ceed, &ctx));
212 
213   // Assemble inner mat if needed
214   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
215 
216   // Get block diagonal
217   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
218   PetscFunctionReturn(PETSC_SUCCESS);
219 }
220 
221 /**
222   @brief Invert `MATCEED` diagonal block for Jacobi.
223 
224   Collective across MPI processes.
225 
226   @param[in]   mat_ceed  `MATCEED` to invert
227   @param[out]  values    The block inverses in column major order
228 
229   @return An error code: 0 - success, otherwise - failure
230 **/
231 static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
232   Mat            mat_inner = NULL;
233   MatCeedContext ctx;
234 
235   PetscFunctionBeginUser;
236   PetscCall(MatShellGetContext(mat_ceed, &ctx));
237 
238   // Assemble inner mat if needed
239   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
240 
241   // Invert PB diagonal
242   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
243   PetscFunctionReturn(PETSC_SUCCESS);
244 }
245 
246 /**
247   @brief Invert `MATCEED` variable diagonal block for Jacobi.
248 
249   Collective across MPI processes.
250 
251   @param[in]   mat_ceed     `MATCEED` to invert
252   @param[in]   num_blocks   The number of blocks on the process
253   @param[in]   block_sizes  The size of each block on the process
254   @param[out]  values       The block inverses in column major order
255 
256   @return An error code: 0 - success, otherwise - failure
257 **/
258 static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
259   Mat            mat_inner = NULL;
260   MatCeedContext ctx;
261 
262   PetscFunctionBeginUser;
263   PetscCall(MatShellGetContext(mat_ceed, &ctx));
264 
265   // Assemble inner mat if needed
266   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
267 
268   // Invert PB diagonal
269   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
273 /**
274   @brief View `MATCEED`.
275 
276   Collective across MPI processes.
277 
278   @param[in]   mat_ceed  `MATCEED` to view
279   @param[in]   viewer    The visualization context
280 
281   @return An error code: 0 - success, otherwise - failure
282 **/
283 static PetscErrorCode MatView_Ceed(Mat mat_ceed, PetscViewer viewer) {
284   PetscBool         is_ascii;
285   PetscViewerFormat format;
286   PetscMPIInt       size;
287   MatCeedContext    ctx;
288 
289   PetscFunctionBeginUser;
290   PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 2);
291   PetscCall(MatShellGetContext(mat_ceed, &ctx));
292   if (!viewer) PetscCall(PetscViewerASCIIGetStdout(PetscObjectComm((PetscObject)mat_ceed), &viewer));
293 
294   PetscCall(PetscViewerGetFormat(viewer, &format));
295   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat_ceed), &size));
296   if (size == 1 && format == PETSC_VIEWER_LOAD_BALANCE) PetscFunctionReturn(PETSC_SUCCESS);
297 
298   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &is_ascii));
299   {
300     FILE *file;
301 
302     PetscCall(PetscViewerASCIIPrintf(viewer, "MatCEED:\n  Internal MatType:%s\n", ctx->internal_mat_type));
303     PetscCall(PetscViewerASCIIGetPointer(viewer, &file));
304     PetscCall(PetscViewerASCIIPrintf(viewer, " libCEED Operator:\n"));
305     PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult, file));
306     if (ctx->op_mult_transpose) {
307       PetscCall(PetscViewerASCIIPrintf(viewer, "  libCEED Transpose Operator:\n"));
308       PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult_transpose, file));
309     }
310   }
311   PetscFunctionReturn(PETSC_SUCCESS);
312 }
313 
314 // -----------------------------------------------------------------------------
315 // MatCeed
316 // -----------------------------------------------------------------------------
317 
/**
  @brief Create PETSc `Mat` from libCEED operators.

  Collective across MPI processes.

  @param[in]   dm_x               Input `DM`
  @param[in]   dm_y               Output `DM`, or NULL to reuse `dm_x`
  @param[in]   op_mult            `CeedOperator` for forward evaluation
  @param[in]   op_mult_transpose  `CeedOperator` for transpose evaluation, or NULL
  @param[out]  mat                New MatCeed

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
  PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
  VecType        vec_type;
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatCeedRegisterLogEvents());

  // Collect context data
  PetscCall(DMGetVecType(dm_x, &vec_type));
  // Global/local sizes of the input space, from a borrowed global vector
  {
    Vec X;

    PetscCall(DMGetGlobalVector(dm_x, &X));
    PetscCall(VecGetSize(X, &X_g_size));
    PetscCall(VecGetLocalSize(X, &X_l_size));
    PetscCall(DMRestoreGlobalVector(dm_x, &X));
  }
  if (dm_y) {
    Vec Y;

    PetscCall(DMGetGlobalVector(dm_y, &Y));
    PetscCall(VecGetSize(Y, &Y_g_size));
    PetscCall(VecGetLocalSize(Y, &Y_l_size));
    PetscCall(DMRestoreGlobalVector(dm_y, &Y));
  } else {
    // Square operator: reuse the input DM and sizes for the output space
    dm_y     = dm_x;
    Y_g_size = X_g_size;
    Y_l_size = X_l_size;
  }
  // Create context
  {
    Vec X_loc, Y_loc_transpose = NULL;

    PetscCall(DMCreateLocalVector(dm_x, &X_loc));
    PetscCall(VecZeroEntries(X_loc));
    // The transpose path needs its own local work vector on the output DM
    if (op_mult_transpose) {
      PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
      PetscCall(VecZeroEntries(Y_loc_transpose));
    }
    // The context takes its own references; drop the local ones here
    PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
    PetscCall(VecDestroy(&X_loc));
    PetscCall(VecDestroy(&Y_loc_transpose));
  }

  // Create mat
  PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
  PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
  // -- Set block and variable block sizes
  if (dm_x == dm_y) {
    MatType dm_mat_type, dm_mat_type_copy;
    Mat     temp_mat;

    // Create a throwaway AIJ matrix just to query the DM's (variable) block sizes,
    // restoring the DM's original MatType afterwards
    PetscCall(DMGetMatType(dm_x, &dm_mat_type));
    PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
    PetscCall(DMSetMatType(dm_x, MATAIJ));
    PetscCall(DMCreateMatrix(dm_x, &temp_mat));
    PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
    PetscCall(PetscFree(dm_mat_type_copy));

    {
      PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
      const PetscInt *vblock_sizes;

      // -- Get block sizes
      PetscCall(MatGetBlockSize(temp_mat, &block_size));
      PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
      // Global maximum variable block size across all ranks
      {
        PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};

        for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
        PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
        max_vblock_size = global_min_max[1];
      }

      // -- Copy block sizes
      if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
      if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));

      // -- Check libCEED compatibility
      {
        bool is_composite;

        // Assume (V)PBD assembly is valid until a (sub-)operator proves otherwise
        ctx->is_ceed_pbd_valid  = PETSC_TRUE;
        ctx->is_ceed_vpbd_valid = PETSC_TRUE;
        PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
        if (is_composite) {
          CeedInt       num_sub_operators;
          CeedOperator *sub_operators;

          PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
          PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
          for (CeedInt i = 0; i < num_sub_operators; i++) {
            CeedInt                  num_bases, num_comp;
            CeedBasis               *active_bases;
            CeedOperatorAssemblyData assembly_data;

            PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
            PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
            PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
            // Multiple active bases make the component count ambiguous; neither PBD nor VPBD applies
            if (num_bases > 1) {
              ctx->is_ceed_pbd_valid  = PETSC_FALSE;
              ctx->is_ceed_vpbd_valid = PETSC_FALSE;
            }
            // PBD requires the component count to match the block size exactly;
            // VPBD requires it to cover the largest variable block
            if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
            if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
        } else {
          // Single (non-composite) operator: same checks, applied once
          // LCOV_EXCL_START
          CeedInt                  num_bases, num_comp;
          CeedBasis               *active_bases;
          CeedOperatorAssemblyData assembly_data;

          PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
          PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
          PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
          if (num_bases > 1) {
            ctx->is_ceed_pbd_valid  = PETSC_FALSE;
            ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          }
          if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
          if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
          // LCOV_EXCL_STOP
        }
        // All ranks must agree on validity: reduce each flag with a global min
        {
          PetscInt local_is_valid[2], global_is_valid[2];

          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_pbd_valid = global_is_valid[0];
          local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
          PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
          ctx->is_ceed_vpbd_valid = global_is_valid[0];
        }
      }
    }
    PetscCall(MatDestroy(&temp_mat));
  }
  // -- Set internal mat type
  {
    VecType vec_type;  // NOTE(review): shadows the outer vec_type used at the end of this function; harmless but consider renaming
    MatType internal_mat_type = MATAIJ;

    // Choose a GPU-capable internal MatType matching the local vector's backend
    PetscCall(VecGetType(ctx->X_loc, &vec_type));
    if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
    else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
    else internal_mat_type = MATAIJ;
    PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
  }
  // -- Set mat operations
  PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
  PetscCall(MatShellSetOperation(*mat, MATOP_VIEW, (void (*)(void))MatView_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
  if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
  PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
  PetscCall(MatShellSetVecType(*mat, vec_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
492 
493 /**
494   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
495 
496   Collective across MPI processes.
497 
498   @param[in]   mat_ceed   `MATCEED` to copy from
499   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
500 
501   @return An error code: 0 - success, otherwise - failure
502 **/
503 PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
504   PetscFunctionBeginUser;
505   PetscCall(MatCeedRegisterLogEvents());
506 
507   // Check type compatibility
508   {
509     MatType mat_type_ceed, mat_type_other;
510 
511     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
512     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
513     PetscCall(MatGetType(mat_ceed, &mat_type_other));
514     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
515                "mat_other must have type " MATCEED " or " MATSHELL);
516   }
517 
518   // Check dimension compatibility
519   {
520     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
521 
522     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
523     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
524     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
525     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
526     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
527                    (X_l_ceed_size == X_l_other_size),
528                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
529                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
530                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
531                ", %" PetscInt_FMT ")",
532                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
533   }
534 
535   // Convert
536   {
537     VecType        vec_type;
538     MatCeedContext ctx;
539 
540     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
541     PetscCall(MatShellGetContext(mat_ceed, &ctx));
542     PetscCall(MatCeedContextReference(ctx));
543     PetscCall(MatShellSetContext(mat_other, ctx));
544     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
545     PetscCall(MatShellSetOperation(mat_other, MATOP_VIEW, (void (*)(void))MatView_Ceed));
546     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
547     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
548     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
549     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
550     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
551     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
552     {
553       PetscInt block_size;
554 
555       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
556       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
557     }
558     {
559       PetscInt        num_blocks;
560       const PetscInt *block_sizes;
561 
562       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
563       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
564     }
565     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
566     PetscCall(MatShellSetVecType(mat_other, vec_type));
567   }
568   PetscFunctionReturn(PETSC_SUCCESS);
569 }
570 
/**
  @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
         The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
         The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.

  Collective across MPI processes.

  @param[in]      mat_ceed  `MATCEED` to assemble
  @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat_ceed, &ctx));

  // Check if COO pattern set
  {
    PetscInt index = -1;

    // Linear scan over previously assembled mats; the list is expected to stay small
    for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
      if (ctx->mats_assembled_full[i] == mat_coo) index = i;
    }
    if (index == -1) {
      PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
      CeedInt      *rows_ceed, *cols_ceed;
      PetscCount    num_entries;
      PetscLogStage stage_amg_setup;

      // -- Assemble sparsity pattern if mat hasn't been assembled before
      PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
      if (stage_amg_setup == -1) {
        PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
      }
      PetscCall(PetscLogStagePush(stage_amg_setup));
      // Symbolic assembly yields the (row, col) index pairs of the full operator
      PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
      // Convert CeedInt index arrays to PetscInt; takes ownership of the Ceed arrays
      PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
      PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
      PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
      // NOTE(review): libc free() here implies IntArrayCeedToPetsc allocates with malloc — confirm against its definition
      free(rows_petsc);
      free(cols_petsc);
      // Lazily create the reusable CeedVector that receives assembled values
      if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
      // Grow the bookkeeping list and remember this mat so preallocation is skipped next time
      PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
      ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
      PetscCall(PetscLogStagePop());
    }
  }

  // Assemble mat_ceed
  PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
  {
    const CeedScalar *values;
    MatType           mat_type;
    CeedMemType       mem_type = CEED_MEM_HOST;
    PetscBool         is_spd, is_spd_known;

    // Pick device memory when the target mat type indicates GPU storage
    PetscCall(MatGetType(mat_coo, &mat_type));
    if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
    else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
    else mem_type = CEED_MEM_HOST;

    // Numeric assembly into the cached CeedVector, then transfer values into the COO mat
    PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
    PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
    PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
    // Propagate the SPD flag when known on the source mat
    PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
    if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
    PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
  }
  PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
644 
645 /**
646   @brief Set user context for a `MATCEED`.
647 
648   Collective across MPI processes.
649 
650   @param[in,out]  mat  `MATCEED`
651   @param[in]      f    The context destroy function, or NULL
652   @param[in]      ctx  User context, or NULL to unset
653 
654   @return An error code: 0 - success, otherwise - failure
655 **/
656 PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
657   PetscContainer user_ctx = NULL;
658 
659   PetscFunctionBeginUser;
660   if (ctx) {
661     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
662     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
663     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
664   }
665   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
666   PetscCall(PetscContainerDestroy(&user_ctx));
667   PetscFunctionReturn(PETSC_SUCCESS);
668 }
669 
670 /**
671   @brief Retrieve the user context for a `MATCEED`.
672 
673   Collective across MPI processes.
674 
675   @param[in,out]  mat  `MATCEED`
676   @param[in]      ctx  User context
677 
678   @return An error code: 0 - success, otherwise - failure
679 **/
680 PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
681   PetscContainer user_ctx;
682 
683   PetscFunctionBeginUser;
684   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
685   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
686   else *(void **)ctx = NULL;
687   PetscFunctionReturn(PETSC_SUCCESS);
688 }
689 
/**
  @brief Sets the inner matrix type as a string from the `MATCEED`.

  Collective across MPI processes.

  @param[in,out]  mat   `MATCEED`
  @param[in]      type  Inner `MatType` to set

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(mat, &ctx));
  // Check if same; early return avoids destroying and rebuilding the cached internal mats
  {
    size_t    len_old, len_new;
    PetscBool is_same = PETSC_FALSE;

    PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
    PetscCall(PetscStrlen(type, &len_new));
    if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
    if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
  }
  // Clean up old mats in different format
  // LCOV_EXCL_START
  if (ctx->mat_assembled_full_internal) {
    // Remove the internal mat from the assembled-mats list, compacting in place
    // (the internal mat appears at most once, so the shifted tail is not re-scanned)
    for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
      if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
        for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
          ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
        }
        ctx->num_mats_assembled_full--;
        // Note: we'll realloc this array again, so no need to shrink the allocation
        PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
      }
    }
  }
  if (ctx->mat_assembled_pbd_internal) {
    // Same compaction for the point-block-diagonal assembled-mats list
    for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
      if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
        for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
          ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
        }
        // Note: we'll realloc this array again, so no need to shrink the allocation
        ctx->num_mats_assembled_pbd--;
        PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
      }
    }
  }
  // Store a private copy of the new type string
  PetscCall(PetscFree(ctx->internal_mat_type));
  PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
  PetscFunctionReturn(PETSC_SUCCESS);
  // LCOV_EXCL_STOP
}
746 
747 /**
748   @brief Gets the inner matrix type as a string from the `MATCEED`.
749 
750   Collective across MPI processes.
751 
752   @param[in,out]  mat   `MATCEED`
  @param[out]     type  Inner `MatType`
754 
755   @return An error code: 0 - success, otherwise - failure
756 **/
757 PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
758   MatCeedContext ctx;
759 
760   PetscFunctionBeginUser;
761   PetscCall(MatShellGetContext(mat, &ctx));
762   *type = ctx->internal_mat_type;
763   PetscFunctionReturn(PETSC_SUCCESS);
764 }
765 
766 /**
767   @brief Set a user defined matrix operation for a `MATCEED` matrix.
768 
769   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
770 `MatCeedSetContext()`.
771 
772   Collective across MPI processes.
773 
774   @param[in,out]  mat  `MATCEED`
775   @param[in]      op   Name of the `MatOperation`
776   @param[in]      g    Function that provides the operation
777 
778   @return An error code: 0 - success, otherwise - failure
779 **/
780 PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
781   PetscFunctionBeginUser;
782   PetscCall(MatShellSetOperation(mat, op, g));
783   PetscFunctionReturn(PETSC_SUCCESS);
784 }
785 
786 /**
787   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
788 
789   Not collective across MPI processes.
790 
791   @param[in,out]  mat              `MATCEED`
792   @param[in]      X_loc            Input PETSc local vector, or NULL
793   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
794 
795   @return An error code: 0 - success, otherwise - failure
796 **/
797 PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
798   MatCeedContext ctx;
799 
800   PetscFunctionBeginUser;
801   PetscCall(MatShellGetContext(mat, &ctx));
802   if (X_loc) {
803     PetscInt len_old, len_new;
804 
805     PetscCall(VecGetSize(ctx->X_loc, &len_old));
806     PetscCall(VecGetSize(X_loc, &len_new));
807     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
808                len_new, len_old);
809     PetscCall(VecDestroy(&ctx->X_loc));
810     ctx->X_loc = X_loc;
811     PetscCall(PetscObjectReference((PetscObject)X_loc));
812   }
813   if (Y_loc_transpose) {
814     PetscInt len_old, len_new;
815 
816     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
817     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
818     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
819                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
820     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
821     ctx->Y_loc_transpose = Y_loc_transpose;
822     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
823   }
824   PetscFunctionReturn(PETSC_SUCCESS);
825 }
826 
827 /**
828   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
829 
830   Not collective across MPI processes.
831 
832   @param[in,out]  mat              `MATCEED`
833   @param[out]     X_loc            Input PETSc local vector, or NULL
834   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
835 
836   @return An error code: 0 - success, otherwise - failure
837 **/
838 PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
839   MatCeedContext ctx;
840 
841   PetscFunctionBeginUser;
842   PetscCall(MatShellGetContext(mat, &ctx));
843   if (X_loc) {
844     *X_loc = ctx->X_loc;
845     PetscCall(PetscObjectReference((PetscObject)*X_loc));
846   }
847   if (Y_loc_transpose) {
848     *Y_loc_transpose = ctx->Y_loc_transpose;
849     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
850   }
851   PetscFunctionReturn(PETSC_SUCCESS);
852 }
853 
854 /**
855   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
856 
857   Not collective across MPI processes.
858 
859   @param[in,out]  mat              MatCeed
860   @param[out]     X_loc            Input PETSc local vector, or NULL
861   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
862 
863   @return An error code: 0 - success, otherwise - failure
864 **/
865 PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
866   PetscFunctionBeginUser;
867   if (X_loc) PetscCall(VecDestroy(X_loc));
868   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
869   PetscFunctionReturn(PETSC_SUCCESS);
870 }
871 
872 /**
873   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
874 
875   Not collective across MPI processes.
876 
877   @param[in,out]  mat                MatCeed
878   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
879   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
880 
881   @return An error code: 0 - success, otherwise - failure
882 **/
883 PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
884   MatCeedContext ctx;
885 
886   PetscFunctionBeginUser;
887   PetscCall(MatShellGetContext(mat, &ctx));
888   if (op_mult) {
889     *op_mult = NULL;
890     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
891   }
892   if (op_mult_transpose) {
893     *op_mult_transpose = NULL;
894     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
895   }
896   PetscFunctionReturn(PETSC_SUCCESS);
897 }
898 
899 /**
900   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
901 
902   Not collective across MPI processes.
903 
904   @param[in,out]  mat                MatCeed
905   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
906   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
907 
908   @return An error code: 0 - success, otherwise - failure
909 **/
910 PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
911   MatCeedContext ctx;
912 
913   PetscFunctionBeginUser;
914   PetscCall(MatShellGetContext(mat, &ctx));
915   if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult));
916   if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
917   PetscFunctionReturn(PETSC_SUCCESS);
918 }
919 
920 /**
921   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
922 
923   Not collective across MPI processes.
924 
  @param[in,out]  mat                       `MATCEED`
  @param[in]      log_event_mult            `PetscLogEvent` for forward evaluation, or 0 to keep the current event
  @param[in]      log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or 0 to keep the current event
928 
929   @return An error code: 0 - success, otherwise - failure
930 **/
931 PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
932   MatCeedContext ctx;
933 
934   PetscFunctionBeginUser;
935   PetscCall(MatShellGetContext(mat, &ctx));
936   if (log_event_mult) ctx->log_event_mult = log_event_mult;
937   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
938   PetscFunctionReturn(PETSC_SUCCESS);
939 }
940 
941 /**
942   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
943 
944   Not collective across MPI processes.
945 
946   @param[in,out]  mat                       MatCeed
947   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
948   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
949 
950   @return An error code: 0 - success, otherwise - failure
951 **/
952 PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
953   MatCeedContext ctx;
954 
955   PetscFunctionBeginUser;
956   PetscCall(MatShellGetContext(mat, &ctx));
957   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
958   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
959   PetscFunctionReturn(PETSC_SUCCESS);
960 }
961 
962 // -----------------------------------------------------------------------------
963 // Operator context data
964 // -----------------------------------------------------------------------------
965 
966 /**
967   @brief Setup context data for operator application.
968 
969   Collective across MPI processes.
970 
971   @param[in]   dm_x                      Input `DM`
972   @param[in]   dm_y                      Output `DM`
973   @param[in]   X_loc                     Input PETSc local vector, or NULL
974   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
975   @param[in]   op_mult                   `CeedOperator` for forward evaluation
976   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
977   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
978   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
979   @param[out]  ctx                       Context data for operator evaluation
980 
981   @return An error code: 0 - success, otherwise - failure
982 **/
PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
                                    PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
  CeedSize x_loc_len, y_loc_len;

  PetscFunctionBeginUser;

  // Allocate (zero-initialized) and start the manual reference count at 1 for the caller
  PetscCall(PetscNew(ctx));
  (*ctx)->ref_count = 1;

  // Logging
  (*ctx)->log_event_mult           = log_event_mult;
  (*ctx)->log_event_mult_transpose = log_event_mult_transpose;

  // PETSc objects; the context takes its own reference to everything it stores
  PetscCall(PetscObjectReference((PetscObject)dm_x));
  (*ctx)->dm_x = dm_x;
  PetscCall(PetscObjectReference((PetscObject)dm_y));
  (*ctx)->dm_y = dm_y;
  if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
  (*ctx)->X_loc = X_loc;
  if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
  (*ctx)->Y_loc_transpose = Y_loc_transpose;

  // Memtype: probe a scratch local vector from dm_x to learn where its data lives (host vs device)
  {
    const PetscScalar *x;
    Vec                X;

    PetscCall(DMGetLocalVector(dm_x, &X));
    PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
    PetscCall(VecRestoreArrayReadAndMemType(X, &x));
    PetscCall(DMRestoreLocalVector(dm_x, &X));
  }

  // libCEED objects; hold an extra Ceed reference so the context outlives the caller's handle
  PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
             "retrieving Ceed context object failed");
  PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed));
  PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
  PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
  if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
  // Work vectors sized to the operator's active input/output lengths
  PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
  PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));

  // Flop counting; cached estimates are logged on every MatMult/MatMultTranspose
  {
    CeedSize ceed_flops_estimate = 0;

    PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
    (*ctx)->flops_mult = ceed_flops_estimate;
    if (op_mult_transpose) {
      PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
      (*ctx)->flops_mult_transpose = ceed_flops_estimate;
    }
  }

  // Check sizes; skipped when the operator reports no active vector lengths
  if (x_loc_len > 0 || y_loc_len > 0) {
    CeedSize ctx_x_loc_len, ctx_y_loc_len;
    PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
    Vec      dm_X_loc, dm_Y_loc;

    // -- Input: user vector, operator, and libCEED work vector must all match dm_x's local size
    PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
    PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
    PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
    PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
    if (X_loc) {
      PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
      PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
                 "X_loc (%" PetscInt_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions", X_loc_len, dm_x_loc_len);
    }
    PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions",
               x_loc_len, dm_x_loc_len);
    PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc (%" CeedSize_FMT ") must match op dimensions (%" CeedSize_FMT ")",
               x_loc_len, ctx_x_loc_len);

    // -- Output: libCEED work vector must match dm_y's local size
    PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
    PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
    PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
    PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
    PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions",
               ctx_y_loc_len, dm_y_loc_len);

    // -- Transpose: optional user vector for the transpose action must match dm_y's local size
    if (Y_loc_transpose) {
      PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
      PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
                 "Y_loc_transpose (%" PetscInt_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions", Y_loc_len, dm_y_loc_len);
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1078 
1079 /**
1080   @brief Increment reference counter for `MATCEED` context.
1081 
1082   Not collective across MPI processes.
1083 
1084   @param[in,out]  ctx  Context data
1085 
1086   @return An error code: 0 - success, otherwise - failure
1087 **/
1088 PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
1089   PetscFunctionBeginUser;
1090   ctx->ref_count++;
1091   PetscFunctionReturn(PETSC_SUCCESS);
1092 }
1093 
1094 /**
1095   @brief Copy reference for `MATCEED`.
1096          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1097 
1098   Not collective across MPI processes.
1099 
1100   @param[in]   ctx       Context data
1101   @param[out]  ctx_copy  Copy of pointer to context data
1102 
1103   @return An error code: 0 - success, otherwise - failure
1104 **/
1105 PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
1106   PetscFunctionBeginUser;
1107   PetscCall(MatCeedContextReference(ctx));
1108   PetscCall(MatCeedContextDestroy(*ctx_copy));
1109   *ctx_copy = ctx;
1110   PetscFunctionReturn(PETSC_SUCCESS);
1111 }
1112 
1113 /**
1114   @brief Destroy context data for operator application.
1115 
1116   Collective across MPI processes.
1117 
1118   @param[in,out]  ctx  Context data for operator evaluation
1119 
1120   @return An error code: 0 - success, otherwise - failure
1121 **/
PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
  PetscFunctionBeginUser;
  // Decrement the manual reference count; only the last reference performs the teardown
  if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);

  // PETSc objects
  PetscCall(DMDestroy(&ctx->dm_x));
  PetscCall(DMDestroy(&ctx->dm_y));
  PetscCall(VecDestroy(&ctx->X_loc));
  PetscCall(VecDestroy(&ctx->Y_loc_transpose));
  PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
  PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
  PetscCall(PetscFree(ctx->internal_mat_type));
  // Tracking arrays hold borrowed Mat pointers, so only the arrays themselves are freed here
  PetscCall(PetscFree(ctx->mats_assembled_full));
  PetscCall(PetscFree(ctx->mats_assembled_pbd));

  // libCEED objects; the Ceed itself is destroyed last since PetscCallCeed needs it for error reporting
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
  PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");

  // Deallocate
  ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1151 
1152 /**
1153   @brief Compute the diagonal of an operator via libCEED.
1154 
1155   Collective across MPI processes.
1156 
1157   @param[in]   A  `MATCEED`
1158   @param[out]  D  Vector holding operator diagonal
1159 
1160   @return An error code: 0 - success, otherwise - failure
1161 **/
PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
  PetscMemType   mem_type;
  Vec            D_loc;
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));

  // Place PETSc vector in libCEED vector; ctx->x_loc wraps D_loc's storage so no copy is made
  PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
  PetscCall(VecPetscToCeed(D_loc, &mem_type, ctx->x_loc));

  // Compute Diagonal of op_mult directly into the wrapped local vector
  PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));

  // Restore PETSc vector
  PetscCall(VecCeedToPetsc(ctx->x_loc, mem_type, D_loc));

  // Local-to-Global; ADD_VALUES accumulates shared (e.g. partition boundary) entries
  PetscCall(VecZeroEntries(D));
  PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
  PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1186 
1187 /**
1188   @brief Compute `A X = Y` for a `MATCEED`.
1189 
1190   Collective across MPI processes.
1191 
1192   @param[in]   A  `MATCEED`
1193   @param[in]   X  Input PETSc vector
1194   @param[out]  Y  Output PETSc vector
1195 
1196   @return An error code: 0 - success, otherwise - failure
1197 **/
PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));
  PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));

  {
    PetscMemType x_mem_type, y_mem_type;
    Vec          X_loc = ctx->X_loc, Y_loc;

    // Get local vectors; use the user-provided X_loc if set, otherwise borrow a scratch vector from dm_x
    if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
    PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));

    // Global-to-local
    PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));

    // Setup libCEED vectors; wrap the PETSc arrays in place (no copies)
    PetscCall(VecReadPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));
    // Y_loc is zeroed because the operator below adds into it
    PetscCall(VecZeroEntries(Y_loc));
    PetscCall(VecPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));

    // Apply libCEED operator
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
    PetscCall(PetscLogGpuTimeEnd());

    // Restore PETSc vectors; must happen before the local-to-global below touches Y_loc
    PetscCall(VecReadCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));
    PetscCall(VecCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));

    // Local-to-global; ADD_VALUES accumulates contributions from shared local entries
    PetscCall(VecZeroEntries(Y));
    PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));

    // Restore local vectors, as needed
    if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
    PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
  }

  // Log flops; use the GPU counter when the data lives in device memory
  if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
  else PetscCall(PetscLogFlops(ctx->flops_mult));

  PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1246 
1247 /**
1248   @brief Compute `A^T Y = X` for a `MATCEED`.
1249 
1250   Collective across MPI processes.
1251 
1252   @param[in]   A  `MATCEED`
1253   @param[in]   Y  Input PETSc vector
1254   @param[out]  X  Output PETSc vector
1255 
1256   @return An error code: 0 - success, otherwise - failure
1257 **/
PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
  MatCeedContext ctx;

  PetscFunctionBeginUser;
  PetscCall(MatShellGetContext(A, &ctx));
  PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));

  {
    PetscMemType x_mem_type, y_mem_type;
    Vec          X_loc, Y_loc = ctx->Y_loc_transpose;

    // Get local vectors; use the user-provided Y_loc_transpose if set, otherwise borrow from dm_y
    if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
    PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));

    // Global-to-local
    PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));

    // Setup libCEED vectors; wrap the PETSc arrays in place (no copies)
    PetscCall(VecReadPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));
    // X_loc is zeroed because the operator below adds into it
    PetscCall(VecZeroEntries(X_loc));
    PetscCall(VecPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));

    // Apply libCEED operator; note y_loc is the input and x_loc the output for the transpose
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
    PetscCall(PetscLogGpuTimeEnd());

    // Restore PETSc vectors; must happen before the local-to-global below touches X_loc
    PetscCall(VecReadCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));
    PetscCall(VecCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));

    // Local-to-global; ADD_VALUES accumulates contributions from shared local entries
    PetscCall(VecZeroEntries(X));
    PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));

    // Restore local vectors, as needed
    if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
    PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
  }

  // Log flops; use the GPU counter when the data lives in device memory
  if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
  else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));

  PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1306