xref: /honee/src/mat-ceed.c (revision e90c2cee6032d3255f027bc24e44037dbb4ab623)
/// @file
/// MatCeed and its related operators
358600ac3SJames Wright 
4a7dac1d5SJames Wright #include <ceed-utils.h>
558600ac3SJames Wright #include <ceed.h>
658600ac3SJames Wright #include <ceed/backend.h>
758600ac3SJames Wright #include <mat-ceed-impl.h>
858600ac3SJames Wright #include <mat-ceed.h>
958600ac3SJames Wright #include <petscdmplex.h>
1058600ac3SJames Wright #include <stdlib.h>
1158600ac3SJames Wright #include <string.h>
1258600ac3SJames Wright 
1358600ac3SJames Wright PetscClassId  MATCEED_CLASSID;
1458600ac3SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
1558600ac3SJames Wright 
1658600ac3SJames Wright /**
1758600ac3SJames Wright   @brief Register MATCEED log events.
1858600ac3SJames Wright 
1958600ac3SJames Wright   Not collective across MPI processes.
2058600ac3SJames Wright 
2158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
2258600ac3SJames Wright **/
2358600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() {
2458600ac3SJames Wright   static bool registered = false;
2558600ac3SJames Wright 
2658600ac3SJames Wright   PetscFunctionBeginUser;
2758600ac3SJames Wright   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
2858600ac3SJames Wright   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
2958600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
3058600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
3158600ac3SJames Wright   registered = true;
3258600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
3358600ac3SJames Wright }
3458600ac3SJames Wright 
3558600ac3SJames Wright /**
3658600ac3SJames Wright   @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.
3758600ac3SJames Wright 
3858600ac3SJames Wright   Collective across MPI processes.
3958600ac3SJames Wright 
4058600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to setup
4158600ac3SJames Wright   @param[out]  mat_inner  Inner `Mat`
4258600ac3SJames Wright 
4358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
4458600ac3SJames Wright **/
4558600ac3SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
4658600ac3SJames Wright   MatCeedContext ctx;
4758600ac3SJames Wright 
4858600ac3SJames Wright   PetscFunctionBeginUser;
4958600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
5058600ac3SJames Wright 
5158600ac3SJames Wright   PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");
5258600ac3SJames Wright 
5358600ac3SJames Wright   // Check cl mat type
5458600ac3SJames Wright   {
5558600ac3SJames Wright     PetscBool is_internal_mat_type_cl = PETSC_FALSE;
5658600ac3SJames Wright     char      internal_mat_type_cl[64];
5758600ac3SJames Wright 
5858600ac3SJames Wright     // Check for specific CL inner mat type for this Mat
5958600ac3SJames Wright     {
6058600ac3SJames Wright       const char *mat_ceed_prefix = NULL;
6158600ac3SJames Wright 
6258600ac3SJames Wright       PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
6358600ac3SJames Wright       PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
6458600ac3SJames Wright       PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
6558600ac3SJames Wright                                   internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
6658600ac3SJames Wright       PetscOptionsEnd();
6758600ac3SJames Wright       if (is_internal_mat_type_cl) {
6858600ac3SJames Wright         PetscCall(PetscFree(ctx->internal_mat_type));
6958600ac3SJames Wright         PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
7058600ac3SJames Wright       }
7158600ac3SJames Wright     }
7258600ac3SJames Wright   }
7358600ac3SJames Wright 
7458600ac3SJames Wright   // Create sparse matrix
7558600ac3SJames Wright   {
7658600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
7758600ac3SJames Wright 
7858600ac3SJames Wright     PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
7958600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
8058600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
8158600ac3SJames Wright     PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
8258600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
8358600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
8458600ac3SJames Wright   }
8558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
8658600ac3SJames Wright }
8758600ac3SJames Wright 
8858600ac3SJames Wright /**
8958600ac3SJames Wright   @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
9058600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
9158600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
9258600ac3SJames Wright 
9358600ac3SJames Wright   Collective across MPI processes.
9458600ac3SJames Wright 
9558600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
9658600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
9758600ac3SJames Wright 
9858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
9958600ac3SJames Wright **/
10058600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
10158600ac3SJames Wright   MatCeedContext ctx;
10258600ac3SJames Wright 
10358600ac3SJames Wright   PetscFunctionBeginUser;
10458600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
10558600ac3SJames Wright 
10658600ac3SJames Wright   // Check if COO pattern set
10758600ac3SJames Wright   {
10858600ac3SJames Wright     PetscInt index = -1;
10958600ac3SJames Wright 
11058600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
11158600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
11258600ac3SJames Wright     }
11358600ac3SJames Wright     if (index == -1) {
11458600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
11558600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
11658600ac3SJames Wright       PetscCount    num_entries;
11758600ac3SJames Wright       PetscLogStage stage_amg_setup;
11858600ac3SJames Wright 
11958600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
12058600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
12158600ac3SJames Wright       if (stage_amg_setup == -1) {
12258600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
12358600ac3SJames Wright       }
12458600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
12550f50432SJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
126a7dac1d5SJames Wright       PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
127a7dac1d5SJames Wright       PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
12858600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
12958600ac3SJames Wright       free(rows_petsc);
13058600ac3SJames Wright       free(cols_petsc);
13150f50432SJames Wright       if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
13258600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
13358600ac3SJames Wright       ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
13458600ac3SJames Wright       PetscCall(PetscLogStagePop());
13558600ac3SJames Wright     }
13658600ac3SJames Wright   }
13758600ac3SJames Wright 
13858600ac3SJames Wright   // Assemble mat_ceed
13958600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
14058600ac3SJames Wright   {
14158600ac3SJames Wright     const CeedScalar *values;
14258600ac3SJames Wright     MatType           mat_type;
14358600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
14458600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
14558600ac3SJames Wright 
14658600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
14758600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
14858600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
14958600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
15058600ac3SJames Wright 
15150f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
15250f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
15358600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
15458600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
15558600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
15650f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
15758600ac3SJames Wright   }
15858600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
15958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
16058600ac3SJames Wright }
16158600ac3SJames Wright 
16258600ac3SJames Wright /**
16358600ac3SJames Wright   @brief Assemble inner `Mat` for diagonal `PC` operations
16458600ac3SJames Wright 
16558600ac3SJames Wright   Collective across MPI processes.
16658600ac3SJames Wright 
16758600ac3SJames Wright   @param[in]   mat_ceed      `MATCEED` to invert
16858600ac3SJames Wright   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
16958600ac3SJames Wright   @param[out]  mat_inner     Inner `Mat` for diagonal operations
17058600ac3SJames Wright 
17158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
17258600ac3SJames Wright **/
17358600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
17458600ac3SJames Wright   MatCeedContext ctx;
17558600ac3SJames Wright 
17658600ac3SJames Wright   PetscFunctionBeginUser;
17758600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
17858600ac3SJames Wright   if (use_ceed_pbd) {
17958600ac3SJames Wright     // Check if COO pattern set
18058600ac3SJames Wright     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
18158600ac3SJames Wright 
18258600ac3SJames Wright     // Assemble mat_assembled_full_internal
18358600ac3SJames Wright     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
18458600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
18558600ac3SJames Wright   } else {
18658600ac3SJames Wright     // Check if COO pattern set
18758600ac3SJames Wright     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
18858600ac3SJames Wright 
18958600ac3SJames Wright     // Assemble mat_assembled_full_internal
19058600ac3SJames Wright     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
19158600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
19258600ac3SJames Wright   }
19358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
19458600ac3SJames Wright }
19558600ac3SJames Wright 
19658600ac3SJames Wright /**
19758600ac3SJames Wright   @brief Get `MATCEED` diagonal block for Jacobi.
19858600ac3SJames Wright 
19958600ac3SJames Wright   Collective across MPI processes.
20058600ac3SJames Wright 
20158600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to invert
20258600ac3SJames Wright   @param[out]  mat_block  The diagonal block matrix
20358600ac3SJames Wright 
20458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
20558600ac3SJames Wright **/
20658600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
20758600ac3SJames Wright   Mat            mat_inner = NULL;
20858600ac3SJames Wright   MatCeedContext ctx;
20958600ac3SJames Wright 
21058600ac3SJames Wright   PetscFunctionBeginUser;
21158600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
21258600ac3SJames Wright 
21358600ac3SJames Wright   // Assemble inner mat if needed
21458600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
21558600ac3SJames Wright 
21658600ac3SJames Wright   // Get block diagonal
21758600ac3SJames Wright   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
21858600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
21958600ac3SJames Wright }
22058600ac3SJames Wright 
22158600ac3SJames Wright /**
22258600ac3SJames Wright   @brief Invert `MATCEED` diagonal block for Jacobi.
22358600ac3SJames Wright 
22458600ac3SJames Wright   Collective across MPI processes.
22558600ac3SJames Wright 
22658600ac3SJames Wright   @param[in]   mat_ceed  `MATCEED` to invert
22758600ac3SJames Wright   @param[out]  values    The block inverses in column major order
22858600ac3SJames Wright 
22958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
23058600ac3SJames Wright **/
23158600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
23258600ac3SJames Wright   Mat            mat_inner = NULL;
23358600ac3SJames Wright   MatCeedContext ctx;
23458600ac3SJames Wright 
23558600ac3SJames Wright   PetscFunctionBeginUser;
23658600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
23758600ac3SJames Wright 
23858600ac3SJames Wright   // Assemble inner mat if needed
23958600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
24058600ac3SJames Wright 
24158600ac3SJames Wright   // Invert PB diagonal
24258600ac3SJames Wright   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
24358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
24458600ac3SJames Wright }
24558600ac3SJames Wright 
24658600ac3SJames Wright /**
24758600ac3SJames Wright   @brief Invert `MATCEED` variable diagonal block for Jacobi.
24858600ac3SJames Wright 
24958600ac3SJames Wright   Collective across MPI processes.
25058600ac3SJames Wright 
25158600ac3SJames Wright   @param[in]   mat_ceed     `MATCEED` to invert
25258600ac3SJames Wright   @param[in]   num_blocks   The number of blocks on the process
25358600ac3SJames Wright   @param[in]   block_sizes  The size of each block on the process
25458600ac3SJames Wright   @param[out]  values       The block inverses in column major order
25558600ac3SJames Wright 
25658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
25758600ac3SJames Wright **/
25858600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
25958600ac3SJames Wright   Mat            mat_inner = NULL;
26058600ac3SJames Wright   MatCeedContext ctx;
26158600ac3SJames Wright 
26258600ac3SJames Wright   PetscFunctionBeginUser;
26358600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
26458600ac3SJames Wright 
26558600ac3SJames Wright   // Assemble inner mat if needed
26658600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
26758600ac3SJames Wright 
26858600ac3SJames Wright   // Invert PB diagonal
26958600ac3SJames Wright   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
27058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
27158600ac3SJames Wright }
27258600ac3SJames Wright 
273*e90c2ceeSJames Wright /**
274*e90c2ceeSJames Wright   @brief View `MATCEED`.
275*e90c2ceeSJames Wright 
276*e90c2ceeSJames Wright   Collective across MPI processes.
277*e90c2ceeSJames Wright 
278*e90c2ceeSJames Wright   @param[in]   mat_ceed  `MATCEED` to view
279*e90c2ceeSJames Wright   @param[in]   viewer    The visualization context
280*e90c2ceeSJames Wright 
281*e90c2ceeSJames Wright   @return An error code: 0 - success, otherwise - failure
282*e90c2ceeSJames Wright **/
283*e90c2ceeSJames Wright static PetscErrorCode MatView_Ceed(Mat mat_ceed, PetscViewer viewer) {
284*e90c2ceeSJames Wright   PetscBool         is_ascii;
285*e90c2ceeSJames Wright   PetscViewerFormat format;
286*e90c2ceeSJames Wright   PetscMPIInt       size;
287*e90c2ceeSJames Wright   MatCeedContext    ctx;
288*e90c2ceeSJames Wright 
289*e90c2ceeSJames Wright   PetscFunctionBeginUser;
290*e90c2ceeSJames Wright   PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 2);
291*e90c2ceeSJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
292*e90c2ceeSJames Wright   if (!viewer) PetscCall(PetscViewerASCIIGetStdout(PetscObjectComm((PetscObject)mat_ceed), &viewer));
293*e90c2ceeSJames Wright 
294*e90c2ceeSJames Wright   PetscCall(PetscViewerGetFormat(viewer, &format));
295*e90c2ceeSJames Wright   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat_ceed), &size));
296*e90c2ceeSJames Wright   if (size == 1 && format == PETSC_VIEWER_LOAD_BALANCE) PetscFunctionReturn(PETSC_SUCCESS);
297*e90c2ceeSJames Wright 
298*e90c2ceeSJames Wright   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &is_ascii));
299*e90c2ceeSJames Wright   {
300*e90c2ceeSJames Wright     FILE *file;
301*e90c2ceeSJames Wright 
302*e90c2ceeSJames Wright     PetscCall(PetscViewerASCIIPrintf(viewer, "MatCEED:\n  Internal MatType:%s\n", ctx->internal_mat_type));
303*e90c2ceeSJames Wright     PetscCall(PetscViewerASCIIGetPointer(viewer, &file));
304*e90c2ceeSJames Wright     PetscCall(PetscViewerASCIIPrintf(viewer, " libCEED Operator:\n"));
305*e90c2ceeSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult, file));
306*e90c2ceeSJames Wright     if (ctx->op_mult_transpose) {
307*e90c2ceeSJames Wright       PetscCall(PetscViewerASCIIPrintf(viewer, "  libCEED Transpose Operator:\n"));
308*e90c2ceeSJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult_transpose, file));
309*e90c2ceeSJames Wright     }
310*e90c2ceeSJames Wright   }
311*e90c2ceeSJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
312*e90c2ceeSJames Wright }
313*e90c2ceeSJames Wright 
31458600ac3SJames Wright // -----------------------------------------------------------------------------
31558600ac3SJames Wright // MatCeed
31658600ac3SJames Wright // -----------------------------------------------------------------------------
31758600ac3SJames Wright 
31858600ac3SJames Wright /**
31958600ac3SJames Wright   @brief Create PETSc `Mat` from libCEED operators.
32058600ac3SJames Wright 
32158600ac3SJames Wright   Collective across MPI processes.
32258600ac3SJames Wright 
32358600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
32458600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
32558600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
32658600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
32758600ac3SJames Wright   @param[out]  mat                        New MatCeed
32858600ac3SJames Wright 
32958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
33058600ac3SJames Wright **/
33158600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
33258600ac3SJames Wright   PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
33358600ac3SJames Wright   VecType        vec_type;
33458600ac3SJames Wright   MatCeedContext ctx;
33558600ac3SJames Wright 
33658600ac3SJames Wright   PetscFunctionBeginUser;
33758600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
33858600ac3SJames Wright 
33958600ac3SJames Wright   // Collect context data
34058600ac3SJames Wright   PetscCall(DMGetVecType(dm_x, &vec_type));
34158600ac3SJames Wright   {
34258600ac3SJames Wright     Vec X;
34358600ac3SJames Wright 
34458600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_x, &X));
34558600ac3SJames Wright     PetscCall(VecGetSize(X, &X_g_size));
34658600ac3SJames Wright     PetscCall(VecGetLocalSize(X, &X_l_size));
34758600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_x, &X));
34858600ac3SJames Wright   }
34958600ac3SJames Wright   if (dm_y) {
35058600ac3SJames Wright     Vec Y;
35158600ac3SJames Wright 
35258600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_y, &Y));
35358600ac3SJames Wright     PetscCall(VecGetSize(Y, &Y_g_size));
35458600ac3SJames Wright     PetscCall(VecGetLocalSize(Y, &Y_l_size));
35558600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_y, &Y));
35658600ac3SJames Wright   } else {
35758600ac3SJames Wright     dm_y     = dm_x;
35858600ac3SJames Wright     Y_g_size = X_g_size;
35958600ac3SJames Wright     Y_l_size = X_l_size;
36058600ac3SJames Wright   }
36158600ac3SJames Wright   // Create context
36258600ac3SJames Wright   {
36358600ac3SJames Wright     Vec X_loc, Y_loc_transpose = NULL;
36458600ac3SJames Wright 
36558600ac3SJames Wright     PetscCall(DMCreateLocalVector(dm_x, &X_loc));
36658600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
36758600ac3SJames Wright     if (op_mult_transpose) {
36858600ac3SJames Wright       PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
36958600ac3SJames Wright       PetscCall(VecZeroEntries(Y_loc_transpose));
37058600ac3SJames Wright     }
37158600ac3SJames Wright     PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
37258600ac3SJames Wright     PetscCall(VecDestroy(&X_loc));
37358600ac3SJames Wright     PetscCall(VecDestroy(&Y_loc_transpose));
37458600ac3SJames Wright   }
37558600ac3SJames Wright 
37658600ac3SJames Wright   // Create mat
37758600ac3SJames Wright   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
37858600ac3SJames Wright   PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
37958600ac3SJames Wright   // -- Set block and variable block sizes
38058600ac3SJames Wright   if (dm_x == dm_y) {
38158600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
38258600ac3SJames Wright     Mat     temp_mat;
38358600ac3SJames Wright 
38458600ac3SJames Wright     PetscCall(DMGetMatType(dm_x, &dm_mat_type));
38558600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
38658600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, MATAIJ));
38758600ac3SJames Wright     PetscCall(DMCreateMatrix(dm_x, &temp_mat));
38858600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
38958600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
39058600ac3SJames Wright 
39158600ac3SJames Wright     {
39258600ac3SJames Wright       PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
39358600ac3SJames Wright       const PetscInt *vblock_sizes;
39458600ac3SJames Wright 
39558600ac3SJames Wright       // -- Get block sizes
39658600ac3SJames Wright       PetscCall(MatGetBlockSize(temp_mat, &block_size));
39758600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
39858600ac3SJames Wright       {
39958600ac3SJames Wright         PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};
40058600ac3SJames Wright 
40158600ac3SJames Wright         for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
40258600ac3SJames Wright         PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
40358600ac3SJames Wright         max_vblock_size = global_min_max[1];
40458600ac3SJames Wright       }
40558600ac3SJames Wright 
40658600ac3SJames Wright       // -- Copy block sizes
40758600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
40858600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));
40958600ac3SJames Wright 
41058600ac3SJames Wright       // -- Check libCEED compatibility
41158600ac3SJames Wright       {
41258600ac3SJames Wright         bool is_composite;
41358600ac3SJames Wright 
41458600ac3SJames Wright         ctx->is_ceed_pbd_valid  = PETSC_TRUE;
41558600ac3SJames Wright         ctx->is_ceed_vpbd_valid = PETSC_TRUE;
41650f50432SJames Wright         PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
41758600ac3SJames Wright         if (is_composite) {
41858600ac3SJames Wright           CeedInt       num_sub_operators;
41958600ac3SJames Wright           CeedOperator *sub_operators;
42058600ac3SJames Wright 
42150f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
42250f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
42358600ac3SJames Wright           for (CeedInt i = 0; i < num_sub_operators; i++) {
42458600ac3SJames Wright             CeedInt                  num_bases, num_comp;
42558600ac3SJames Wright             CeedBasis               *active_bases;
42658600ac3SJames Wright             CeedOperatorAssemblyData assembly_data;
42758600ac3SJames Wright 
42850f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
42950f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
43050f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
43158600ac3SJames Wright             if (num_bases > 1) {
43258600ac3SJames Wright               ctx->is_ceed_pbd_valid  = PETSC_FALSE;
43358600ac3SJames Wright               ctx->is_ceed_vpbd_valid = PETSC_FALSE;
43458600ac3SJames Wright             }
43558600ac3SJames Wright             if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
43658600ac3SJames Wright             if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
43758600ac3SJames Wright           }
43858600ac3SJames Wright         } else {
43958600ac3SJames Wright           // LCOV_EXCL_START
44058600ac3SJames Wright           CeedInt                  num_bases, num_comp;
44158600ac3SJames Wright           CeedBasis               *active_bases;
44258600ac3SJames Wright           CeedOperatorAssemblyData assembly_data;
44358600ac3SJames Wright 
44450f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
44550f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
44650f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
44758600ac3SJames Wright           if (num_bases > 1) {
44858600ac3SJames Wright             ctx->is_ceed_pbd_valid  = PETSC_FALSE;
44958600ac3SJames Wright             ctx->is_ceed_vpbd_valid = PETSC_FALSE;
45058600ac3SJames Wright           }
45158600ac3SJames Wright           if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
45258600ac3SJames Wright           if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
45358600ac3SJames Wright           // LCOV_EXCL_STOP
45458600ac3SJames Wright         }
45558600ac3SJames Wright         {
45658600ac3SJames Wright           PetscInt local_is_valid[2], global_is_valid[2];
45758600ac3SJames Wright 
45858600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
45958600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
46058600ac3SJames Wright           ctx->is_ceed_pbd_valid = global_is_valid[0];
46158600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
46258600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
46358600ac3SJames Wright           ctx->is_ceed_vpbd_valid = global_is_valid[0];
46458600ac3SJames Wright         }
46558600ac3SJames Wright       }
46658600ac3SJames Wright     }
46758600ac3SJames Wright     PetscCall(MatDestroy(&temp_mat));
46858600ac3SJames Wright   }
46958600ac3SJames Wright   // -- Set internal mat type
47058600ac3SJames Wright   {
47158600ac3SJames Wright     VecType vec_type;
47258600ac3SJames Wright     MatType internal_mat_type = MATAIJ;
47358600ac3SJames Wright 
47458600ac3SJames Wright     PetscCall(VecGetType(ctx->X_loc, &vec_type));
47558600ac3SJames Wright     if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
47658600ac3SJames Wright     else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
47758600ac3SJames Wright     else internal_mat_type = MATAIJ;
47858600ac3SJames Wright     PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
47958600ac3SJames Wright   }
48058600ac3SJames Wright   // -- Set mat operations
48158600ac3SJames Wright   PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
482*e90c2ceeSJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_VIEW, (void (*)(void))MatView_Ceed));
48358600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
48458600ac3SJames Wright   if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
48558600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
48658600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
48758600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
48858600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
48958600ac3SJames Wright   PetscCall(MatShellSetVecType(*mat, vec_type));
49058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
49158600ac3SJames Wright }
49258600ac3SJames Wright 
49358600ac3SJames Wright /**
49458600ac3SJames Wright   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
49558600ac3SJames Wright 
49658600ac3SJames Wright   Collective across MPI processes.
49758600ac3SJames Wright 
49858600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to copy from
49958600ac3SJames Wright   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
50058600ac3SJames Wright 
50158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
50258600ac3SJames Wright **/
50358600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
50458600ac3SJames Wright   PetscFunctionBeginUser;
50558600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
50658600ac3SJames Wright 
50758600ac3SJames Wright   // Check type compatibility
50858600ac3SJames Wright   {
50958600ac3SJames Wright     MatType mat_type_ceed, mat_type_other;
51058600ac3SJames Wright 
51158600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
51258600ac3SJames Wright     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
51358600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_other));
51458600ac3SJames Wright     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
51558600ac3SJames Wright                "mat_other must have type " MATCEED " or " MATSHELL);
51658600ac3SJames Wright   }
51758600ac3SJames Wright 
51858600ac3SJames Wright   // Check dimension compatibility
51958600ac3SJames Wright   {
52058600ac3SJames Wright     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
52158600ac3SJames Wright 
52258600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
52358600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
52458600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
52558600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
52658600ac3SJames Wright     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
52758600ac3SJames Wright                    (X_l_ceed_size == X_l_other_size),
52858600ac3SJames Wright                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
52958600ac3SJames Wright                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
53058600ac3SJames Wright                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
53158600ac3SJames Wright                ", %" PetscInt_FMT ")",
53258600ac3SJames Wright                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
53358600ac3SJames Wright   }
53458600ac3SJames Wright 
53558600ac3SJames Wright   // Convert
53658600ac3SJames Wright   {
53758600ac3SJames Wright     VecType        vec_type;
53858600ac3SJames Wright     MatCeedContext ctx;
53958600ac3SJames Wright 
54058600ac3SJames Wright     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
54158600ac3SJames Wright     PetscCall(MatShellGetContext(mat_ceed, &ctx));
54258600ac3SJames Wright     PetscCall(MatCeedContextReference(ctx));
54358600ac3SJames Wright     PetscCall(MatShellSetContext(mat_other, ctx));
54458600ac3SJames Wright     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
545*e90c2ceeSJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_VIEW, (void (*)(void))MatView_Ceed));
54658600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
54758600ac3SJames Wright     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
54858600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
54958600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
55058600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
55158600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
55258600ac3SJames Wright     {
55358600ac3SJames Wright       PetscInt block_size;
55458600ac3SJames Wright 
55558600ac3SJames Wright       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
55658600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
55758600ac3SJames Wright     }
55858600ac3SJames Wright     {
55958600ac3SJames Wright       PetscInt        num_blocks;
56058600ac3SJames Wright       const PetscInt *block_sizes;
56158600ac3SJames Wright 
56258600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
56358600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
56458600ac3SJames Wright     }
56558600ac3SJames Wright     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
56658600ac3SJames Wright     PetscCall(MatShellSetVecType(mat_other, vec_type));
56758600ac3SJames Wright   }
56858600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
56958600ac3SJames Wright }
57058600ac3SJames Wright 
57158600ac3SJames Wright /**
57258600ac3SJames Wright   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
57358600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
57458600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
57558600ac3SJames Wright 
57658600ac3SJames Wright   Collective across MPI processes.
57758600ac3SJames Wright 
57858600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
57958600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
58058600ac3SJames Wright 
58158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
58258600ac3SJames Wright **/
58358600ac3SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
58458600ac3SJames Wright   MatCeedContext ctx;
58558600ac3SJames Wright 
58658600ac3SJames Wright   PetscFunctionBeginUser;
58758600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
58858600ac3SJames Wright 
58958600ac3SJames Wright   // Check if COO pattern set
59058600ac3SJames Wright   {
59158600ac3SJames Wright     PetscInt index = -1;
59258600ac3SJames Wright 
59358600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
59458600ac3SJames Wright       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
59558600ac3SJames Wright     }
59658600ac3SJames Wright     if (index == -1) {
59758600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
59858600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
59958600ac3SJames Wright       PetscCount    num_entries;
60058600ac3SJames Wright       PetscLogStage stage_amg_setup;
60158600ac3SJames Wright 
60258600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
60358600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
60458600ac3SJames Wright       if (stage_amg_setup == -1) {
60558600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
60658600ac3SJames Wright       }
60758600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
60850f50432SJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
609a7dac1d5SJames Wright       PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc));
610a7dac1d5SJames Wright       PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc));
61158600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
61258600ac3SJames Wright       free(rows_petsc);
61358600ac3SJames Wright       free(cols_petsc);
61450f50432SJames Wright       if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
61558600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
61658600ac3SJames Wright       ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
61758600ac3SJames Wright       PetscCall(PetscLogStagePop());
61858600ac3SJames Wright     }
61958600ac3SJames Wright   }
62058600ac3SJames Wright 
62158600ac3SJames Wright   // Assemble mat_ceed
62258600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
62358600ac3SJames Wright   {
62458600ac3SJames Wright     const CeedScalar *values;
62558600ac3SJames Wright     MatType           mat_type;
62658600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
62758600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
62858600ac3SJames Wright 
62958600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
63058600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
63158600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
63258600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
63358600ac3SJames Wright 
63450f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
63550f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
63658600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
63758600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
63858600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
63950f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
64058600ac3SJames Wright   }
64158600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
64258600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
64358600ac3SJames Wright }
64458600ac3SJames Wright 
64558600ac3SJames Wright /**
64658600ac3SJames Wright   @brief Set user context for a `MATCEED`.
64758600ac3SJames Wright 
64858600ac3SJames Wright   Collective across MPI processes.
64958600ac3SJames Wright 
65058600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
65158600ac3SJames Wright   @param[in]      f    The context destroy function, or NULL
65258600ac3SJames Wright   @param[in]      ctx  User context, or NULL to unset
65358600ac3SJames Wright 
65458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
65558600ac3SJames Wright **/
65658600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
65758600ac3SJames Wright   PetscContainer user_ctx = NULL;
65858600ac3SJames Wright 
65958600ac3SJames Wright   PetscFunctionBeginUser;
66058600ac3SJames Wright   if (ctx) {
66158600ac3SJames Wright     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
66258600ac3SJames Wright     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
66358600ac3SJames Wright     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
66458600ac3SJames Wright   }
66558600ac3SJames Wright   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
66658600ac3SJames Wright   PetscCall(PetscContainerDestroy(&user_ctx));
66758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
66858600ac3SJames Wright }
66958600ac3SJames Wright 
67058600ac3SJames Wright /**
67158600ac3SJames Wright   @brief Retrieve the user context for a `MATCEED`.
67258600ac3SJames Wright 
67358600ac3SJames Wright   Collective across MPI processes.
67458600ac3SJames Wright 
67558600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
67658600ac3SJames Wright   @param[in]      ctx  User context
67758600ac3SJames Wright 
67858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
67958600ac3SJames Wright **/
68058600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
68158600ac3SJames Wright   PetscContainer user_ctx;
68258600ac3SJames Wright 
68358600ac3SJames Wright   PetscFunctionBeginUser;
68458600ac3SJames Wright   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
68558600ac3SJames Wright   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
68658600ac3SJames Wright   else *(void **)ctx = NULL;
68758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
68858600ac3SJames Wright }
68958600ac3SJames Wright 
69058600ac3SJames Wright /**
69158600ac3SJames Wright   @brief Sets the inner matrix type as a string from the `MATCEED`.
69258600ac3SJames Wright 
69358600ac3SJames Wright   Collective across MPI processes.
69458600ac3SJames Wright 
69558600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
69658600ac3SJames Wright   @param[in]      type  Inner `MatType` to set
69758600ac3SJames Wright 
69858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
69958600ac3SJames Wright **/
70058600ac3SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
70158600ac3SJames Wright   MatCeedContext ctx;
70258600ac3SJames Wright 
70358600ac3SJames Wright   PetscFunctionBeginUser;
70458600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
70558600ac3SJames Wright   // Check if same
70658600ac3SJames Wright   {
70758600ac3SJames Wright     size_t    len_old, len_new;
70858600ac3SJames Wright     PetscBool is_same = PETSC_FALSE;
70958600ac3SJames Wright 
71058600ac3SJames Wright     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
71158600ac3SJames Wright     PetscCall(PetscStrlen(type, &len_new));
71258600ac3SJames Wright     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
71358600ac3SJames Wright     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
71458600ac3SJames Wright   }
71558600ac3SJames Wright   // Clean up old mats in different format
71658600ac3SJames Wright   // LCOV_EXCL_START
71758600ac3SJames Wright   if (ctx->mat_assembled_full_internal) {
71858600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
71958600ac3SJames Wright       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
72058600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
72158600ac3SJames Wright           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
72258600ac3SJames Wright         }
72358600ac3SJames Wright         ctx->num_mats_assembled_full--;
72458600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
72558600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
72658600ac3SJames Wright       }
72758600ac3SJames Wright     }
72858600ac3SJames Wright   }
72958600ac3SJames Wright   if (ctx->mat_assembled_pbd_internal) {
73058600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
73158600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
73258600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
73358600ac3SJames Wright           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
73458600ac3SJames Wright         }
73558600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
73658600ac3SJames Wright         ctx->num_mats_assembled_pbd--;
73758600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
73858600ac3SJames Wright       }
73958600ac3SJames Wright     }
74058600ac3SJames Wright   }
74158600ac3SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
74258600ac3SJames Wright   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
74358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
74458600ac3SJames Wright   // LCOV_EXCL_STOP
74558600ac3SJames Wright }
74658600ac3SJames Wright 
74758600ac3SJames Wright /**
74858600ac3SJames Wright   @brief Gets the inner matrix type as a string from the `MATCEED`.
74958600ac3SJames Wright 
75058600ac3SJames Wright   Collective across MPI processes.
75158600ac3SJames Wright 
75258600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
75358600ac3SJames Wright   @param[in]      type  Inner `MatType`
75458600ac3SJames Wright 
75558600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
75658600ac3SJames Wright **/
75758600ac3SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
75858600ac3SJames Wright   MatCeedContext ctx;
75958600ac3SJames Wright 
76058600ac3SJames Wright   PetscFunctionBeginUser;
76158600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
76258600ac3SJames Wright   *type = ctx->internal_mat_type;
76358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
76458600ac3SJames Wright }
76558600ac3SJames Wright 
76658600ac3SJames Wright /**
76758600ac3SJames Wright   @brief Set a user defined matrix operation for a `MATCEED` matrix.
76858600ac3SJames Wright 
76958600ac3SJames Wright   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
77058600ac3SJames Wright `MatCeedSetContext()`.
77158600ac3SJames Wright 
77258600ac3SJames Wright   Collective across MPI processes.
77358600ac3SJames Wright 
77458600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
77558600ac3SJames Wright   @param[in]      op   Name of the `MatOperation`
77658600ac3SJames Wright   @param[in]      g    Function that provides the operation
77758600ac3SJames Wright 
77858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
77958600ac3SJames Wright **/
78058600ac3SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
78158600ac3SJames Wright   PetscFunctionBeginUser;
78258600ac3SJames Wright   PetscCall(MatShellSetOperation(mat, op, g));
78358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
78458600ac3SJames Wright }
78558600ac3SJames Wright 
78658600ac3SJames Wright /**
78758600ac3SJames Wright   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
78858600ac3SJames Wright 
78958600ac3SJames Wright   Not collective across MPI processes.
79058600ac3SJames Wright 
79158600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
79258600ac3SJames Wright   @param[in]      X_loc            Input PETSc local vector, or NULL
79358600ac3SJames Wright   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
79458600ac3SJames Wright 
79558600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
79658600ac3SJames Wright **/
79758600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
79858600ac3SJames Wright   MatCeedContext ctx;
79958600ac3SJames Wright 
80058600ac3SJames Wright   PetscFunctionBeginUser;
80158600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
80258600ac3SJames Wright   if (X_loc) {
80358600ac3SJames Wright     PetscInt len_old, len_new;
80458600ac3SJames Wright 
80558600ac3SJames Wright     PetscCall(VecGetSize(ctx->X_loc, &len_old));
80658600ac3SJames Wright     PetscCall(VecGetSize(X_loc, &len_new));
80758600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
80858600ac3SJames Wright                len_new, len_old);
80958600ac3SJames Wright     PetscCall(VecDestroy(&ctx->X_loc));
81058600ac3SJames Wright     ctx->X_loc = X_loc;
81158600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)X_loc));
81258600ac3SJames Wright   }
81358600ac3SJames Wright   if (Y_loc_transpose) {
81458600ac3SJames Wright     PetscInt len_old, len_new;
81558600ac3SJames Wright 
81658600ac3SJames Wright     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
81758600ac3SJames Wright     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
81858600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
81958600ac3SJames Wright                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
82058600ac3SJames Wright     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
82158600ac3SJames Wright     ctx->Y_loc_transpose = Y_loc_transpose;
82258600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
82358600ac3SJames Wright   }
82458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
82558600ac3SJames Wright }
82658600ac3SJames Wright 
82758600ac3SJames Wright /**
82858600ac3SJames Wright   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
82958600ac3SJames Wright 
83058600ac3SJames Wright   Not collective across MPI processes.
83158600ac3SJames Wright 
83258600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
83358600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
83458600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
83558600ac3SJames Wright 
83658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
83758600ac3SJames Wright **/
83858600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
83958600ac3SJames Wright   MatCeedContext ctx;
84058600ac3SJames Wright 
84158600ac3SJames Wright   PetscFunctionBeginUser;
84258600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
84358600ac3SJames Wright   if (X_loc) {
84458600ac3SJames Wright     *X_loc = ctx->X_loc;
84558600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*X_loc));
84658600ac3SJames Wright   }
84758600ac3SJames Wright   if (Y_loc_transpose) {
84858600ac3SJames Wright     *Y_loc_transpose = ctx->Y_loc_transpose;
84958600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
85058600ac3SJames Wright   }
85158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
85258600ac3SJames Wright }
85358600ac3SJames Wright 
85458600ac3SJames Wright /**
85558600ac3SJames Wright   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
85658600ac3SJames Wright 
85758600ac3SJames Wright   Not collective across MPI processes.
85858600ac3SJames Wright 
85958600ac3SJames Wright   @param[in,out]  mat              MatCeed
86058600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
86158600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
86258600ac3SJames Wright 
86358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
86458600ac3SJames Wright **/
86558600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
86658600ac3SJames Wright   PetscFunctionBeginUser;
86758600ac3SJames Wright   if (X_loc) PetscCall(VecDestroy(X_loc));
86858600ac3SJames Wright   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
86958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
87058600ac3SJames Wright }
87158600ac3SJames Wright 
87258600ac3SJames Wright /**
87358600ac3SJames Wright   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
87458600ac3SJames Wright 
87558600ac3SJames Wright   Not collective across MPI processes.
87658600ac3SJames Wright 
87758600ac3SJames Wright   @param[in,out]  mat                MatCeed
87858600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
87958600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
88058600ac3SJames Wright 
88158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
88258600ac3SJames Wright **/
88358600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
88458600ac3SJames Wright   MatCeedContext ctx;
88558600ac3SJames Wright 
88658600ac3SJames Wright   PetscFunctionBeginUser;
88758600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
88858600ac3SJames Wright   if (op_mult) {
88958600ac3SJames Wright     *op_mult = NULL;
89050f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
89158600ac3SJames Wright   }
89258600ac3SJames Wright   if (op_mult_transpose) {
89358600ac3SJames Wright     *op_mult_transpose = NULL;
89450f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
89558600ac3SJames Wright   }
89658600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
89758600ac3SJames Wright }
89858600ac3SJames Wright 
89958600ac3SJames Wright /**
90058600ac3SJames Wright   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
90158600ac3SJames Wright 
90258600ac3SJames Wright   Not collective across MPI processes.
90358600ac3SJames Wright 
90458600ac3SJames Wright   @param[in,out]  mat                MatCeed
90558600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
90658600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
90758600ac3SJames Wright 
90858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
90958600ac3SJames Wright **/
91058600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
91158600ac3SJames Wright   MatCeedContext ctx;
91258600ac3SJames Wright 
91358600ac3SJames Wright   PetscFunctionBeginUser;
91458600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
91550f50432SJames Wright   if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult));
91650f50432SJames Wright   if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
91758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
91858600ac3SJames Wright }
91958600ac3SJames Wright 
92058600ac3SJames Wright /**
92158600ac3SJames Wright   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
92258600ac3SJames Wright 
92358600ac3SJames Wright   Not collective across MPI processes.
92458600ac3SJames Wright 
92558600ac3SJames Wright   @param[in,out]  mat                       MatCeed
92658600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
92758600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
92858600ac3SJames Wright 
92958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
93058600ac3SJames Wright **/
93158600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
93258600ac3SJames Wright   MatCeedContext ctx;
93358600ac3SJames Wright 
93458600ac3SJames Wright   PetscFunctionBeginUser;
93558600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
93658600ac3SJames Wright   if (log_event_mult) ctx->log_event_mult = log_event_mult;
93758600ac3SJames Wright   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
93858600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
93958600ac3SJames Wright }
94058600ac3SJames Wright 
94158600ac3SJames Wright /**
94258600ac3SJames Wright   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
94358600ac3SJames Wright 
94458600ac3SJames Wright   Not collective across MPI processes.
94558600ac3SJames Wright 
94658600ac3SJames Wright   @param[in,out]  mat                       MatCeed
94758600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
94858600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
94958600ac3SJames Wright 
95058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
95158600ac3SJames Wright **/
95258600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
95358600ac3SJames Wright   MatCeedContext ctx;
95458600ac3SJames Wright 
95558600ac3SJames Wright   PetscFunctionBeginUser;
95658600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
95758600ac3SJames Wright   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
95858600ac3SJames Wright   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
95958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
96058600ac3SJames Wright }
96158600ac3SJames Wright 
96258600ac3SJames Wright // -----------------------------------------------------------------------------
96358600ac3SJames Wright // Operator context data
96458600ac3SJames Wright // -----------------------------------------------------------------------------
96558600ac3SJames Wright 
96658600ac3SJames Wright /**
96758600ac3SJames Wright   @brief Setup context data for operator application.
96858600ac3SJames Wright 
96958600ac3SJames Wright   Collective across MPI processes.
97058600ac3SJames Wright 
97158600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
97258600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
97358600ac3SJames Wright   @param[in]   X_loc                     Input PETSc local vector, or NULL
97458600ac3SJames Wright   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
97558600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
97658600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
97758600ac3SJames Wright   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
97858600ac3SJames Wright   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
97958600ac3SJames Wright   @param[out]  ctx                       Context data for operator evaluation
98058600ac3SJames Wright 
98158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
98258600ac3SJames Wright **/
98358600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
98458600ac3SJames Wright                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
98558600ac3SJames Wright   CeedSize x_loc_len, y_loc_len;
98658600ac3SJames Wright 
98758600ac3SJames Wright   PetscFunctionBeginUser;
98858600ac3SJames Wright 
98958600ac3SJames Wright   // Allocate
99058600ac3SJames Wright   PetscCall(PetscNew(ctx));
99158600ac3SJames Wright   (*ctx)->ref_count = 1;
99258600ac3SJames Wright 
99358600ac3SJames Wright   // Logging
99458600ac3SJames Wright   (*ctx)->log_event_mult           = log_event_mult;
99558600ac3SJames Wright   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
99658600ac3SJames Wright 
99758600ac3SJames Wright   // PETSc objects
99858600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_x));
99958600ac3SJames Wright   (*ctx)->dm_x = dm_x;
100058600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_y));
100158600ac3SJames Wright   (*ctx)->dm_y = dm_y;
100258600ac3SJames Wright   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
100358600ac3SJames Wright   (*ctx)->X_loc = X_loc;
100458600ac3SJames Wright   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
100558600ac3SJames Wright   (*ctx)->Y_loc_transpose = Y_loc_transpose;
100658600ac3SJames Wright 
100758600ac3SJames Wright   // Memtype
100858600ac3SJames Wright   {
100958600ac3SJames Wright     const PetscScalar *x;
101058600ac3SJames Wright     Vec                X;
101158600ac3SJames Wright 
101258600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &X));
101358600ac3SJames Wright     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
101458600ac3SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
101558600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &X));
101658600ac3SJames Wright   }
101758600ac3SJames Wright 
101858600ac3SJames Wright   // libCEED objects
101958600ac3SJames Wright   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
102058600ac3SJames Wright              "retrieving Ceed context object failed");
102150f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed));
102250f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
102350f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
102450f50432SJames Wright   if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
102550f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
102650f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
102758600ac3SJames Wright 
102858600ac3SJames Wright   // Flop counting
102958600ac3SJames Wright   {
103058600ac3SJames Wright     CeedSize ceed_flops_estimate = 0;
103158600ac3SJames Wright 
103250f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
103358600ac3SJames Wright     (*ctx)->flops_mult = ceed_flops_estimate;
103458600ac3SJames Wright     if (op_mult_transpose) {
103550f50432SJames Wright       PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
103658600ac3SJames Wright       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
103758600ac3SJames Wright     }
103858600ac3SJames Wright   }
103958600ac3SJames Wright 
104058600ac3SJames Wright   // Check sizes
104158600ac3SJames Wright   if (x_loc_len > 0 || y_loc_len > 0) {
104258600ac3SJames Wright     CeedSize ctx_x_loc_len, ctx_y_loc_len;
104358600ac3SJames Wright     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
104458600ac3SJames Wright     Vec      dm_X_loc, dm_Y_loc;
104558600ac3SJames Wright 
104658600ac3SJames Wright     // -- Input
104758600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
104858600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
104958600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
105050f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
10514c17272bSJames Wright     if (X_loc) {
10524c17272bSJames Wright       PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
10534c17272bSJames Wright       PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
10544c17272bSJames Wright                  "X_loc (%" PetscInt_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions", X_loc_len, dm_x_loc_len);
10554c17272bSJames Wright     }
10564c17272bSJames Wright     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions",
10574c17272bSJames Wright                x_loc_len, dm_x_loc_len);
10584c17272bSJames Wright     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc (%" CeedSize_FMT ") must match op dimensions (%" CeedSize_FMT ")",
10594c17272bSJames Wright                x_loc_len, ctx_x_loc_len);
106058600ac3SJames Wright 
106158600ac3SJames Wright     // -- Output
106258600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
106358600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
106458600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
106550f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
10664c17272bSJames Wright     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions",
10674c17272bSJames Wright                ctx_y_loc_len, dm_y_loc_len);
106858600ac3SJames Wright 
106958600ac3SJames Wright     // -- Transpose
107058600ac3SJames Wright     if (Y_loc_transpose) {
107158600ac3SJames Wright       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
10724c17272bSJames Wright       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB,
10734c17272bSJames Wright                  "Y_loc_transpose (%" PetscInt_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions", Y_loc_len, dm_y_loc_len);
107458600ac3SJames Wright     }
107558600ac3SJames Wright   }
107658600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
107758600ac3SJames Wright }
107858600ac3SJames Wright 
107958600ac3SJames Wright /**
108058600ac3SJames Wright   @brief Increment reference counter for `MATCEED` context.
108158600ac3SJames Wright 
108258600ac3SJames Wright   Not collective across MPI processes.
108358600ac3SJames Wright 
108458600ac3SJames Wright   @param[in,out]  ctx  Context data
108558600ac3SJames Wright 
108658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
108758600ac3SJames Wright **/
108858600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
108958600ac3SJames Wright   PetscFunctionBeginUser;
109058600ac3SJames Wright   ctx->ref_count++;
109158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
109258600ac3SJames Wright }
109358600ac3SJames Wright 
109458600ac3SJames Wright /**
109558600ac3SJames Wright   @brief Copy reference for `MATCEED`.
109658600ac3SJames Wright          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
109758600ac3SJames Wright 
109858600ac3SJames Wright   Not collective across MPI processes.
109958600ac3SJames Wright 
110058600ac3SJames Wright   @param[in]   ctx       Context data
110158600ac3SJames Wright   @param[out]  ctx_copy  Copy of pointer to context data
110258600ac3SJames Wright 
110358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
110458600ac3SJames Wright **/
110558600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
110658600ac3SJames Wright   PetscFunctionBeginUser;
110758600ac3SJames Wright   PetscCall(MatCeedContextReference(ctx));
110858600ac3SJames Wright   PetscCall(MatCeedContextDestroy(*ctx_copy));
110958600ac3SJames Wright   *ctx_copy = ctx;
111058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
111158600ac3SJames Wright }
111258600ac3SJames Wright 
111358600ac3SJames Wright /**
111458600ac3SJames Wright   @brief Destroy context data for operator application.
111558600ac3SJames Wright 
111658600ac3SJames Wright   Collective across MPI processes.
111758600ac3SJames Wright 
111858600ac3SJames Wright   @param[in,out]  ctx  Context data for operator evaluation
111958600ac3SJames Wright 
112058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
112158600ac3SJames Wright **/
112258600ac3SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
112358600ac3SJames Wright   PetscFunctionBeginUser;
112458600ac3SJames Wright   if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);
112558600ac3SJames Wright 
112658600ac3SJames Wright   // PETSc objects
112758600ac3SJames Wright   PetscCall(DMDestroy(&ctx->dm_x));
112858600ac3SJames Wright   PetscCall(DMDestroy(&ctx->dm_y));
112958600ac3SJames Wright   PetscCall(VecDestroy(&ctx->X_loc));
113058600ac3SJames Wright   PetscCall(VecDestroy(&ctx->Y_loc_transpose));
113158600ac3SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
113258600ac3SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
113358600ac3SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
113458600ac3SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_full));
113558600ac3SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_pbd));
113658600ac3SJames Wright 
113758600ac3SJames Wright   // libCEED objects
113850f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
113950f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
114050f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
114150f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
114250f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
114350f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
114458600ac3SJames Wright   PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");
114558600ac3SJames Wright 
114658600ac3SJames Wright   // Deallocate
114758600ac3SJames Wright   ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
114858600ac3SJames Wright   PetscCall(PetscFree(ctx));
114958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
115058600ac3SJames Wright }
115158600ac3SJames Wright 
115258600ac3SJames Wright /**
115358600ac3SJames Wright   @brief Compute the diagonal of an operator via libCEED.
115458600ac3SJames Wright 
115558600ac3SJames Wright   Collective across MPI processes.
115658600ac3SJames Wright 
115758600ac3SJames Wright   @param[in]   A  `MATCEED`
115858600ac3SJames Wright   @param[out]  D  Vector holding operator diagonal
115958600ac3SJames Wright 
116058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
116158600ac3SJames Wright **/
116258600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
116358600ac3SJames Wright   PetscMemType   mem_type;
116458600ac3SJames Wright   Vec            D_loc;
116558600ac3SJames Wright   MatCeedContext ctx;
116658600ac3SJames Wright 
116758600ac3SJames Wright   PetscFunctionBeginUser;
116858600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
116958600ac3SJames Wright 
117058600ac3SJames Wright   // Place PETSc vector in libCEED vector
117158600ac3SJames Wright   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
1172a7dac1d5SJames Wright   PetscCall(VecPetscToCeed(D_loc, &mem_type, ctx->x_loc));
117358600ac3SJames Wright 
117458600ac3SJames Wright   // Compute Diagonal
117550f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
117658600ac3SJames Wright 
117758600ac3SJames Wright   // Restore PETSc vector
1178a7dac1d5SJames Wright   PetscCall(VecCeedToPetsc(ctx->x_loc, mem_type, D_loc));
117958600ac3SJames Wright 
118058600ac3SJames Wright   // Local-to-Global
118158600ac3SJames Wright   PetscCall(VecZeroEntries(D));
118258600ac3SJames Wright   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
118358600ac3SJames Wright   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
118458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
118558600ac3SJames Wright }
118658600ac3SJames Wright 
118758600ac3SJames Wright /**
118858600ac3SJames Wright   @brief Compute `A X = Y` for a `MATCEED`.
118958600ac3SJames Wright 
119058600ac3SJames Wright   Collective across MPI processes.
119158600ac3SJames Wright 
119258600ac3SJames Wright   @param[in]   A  `MATCEED`
119358600ac3SJames Wright   @param[in]   X  Input PETSc vector
119458600ac3SJames Wright   @param[out]  Y  Output PETSc vector
119558600ac3SJames Wright 
119658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
119758600ac3SJames Wright **/
119858600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
119958600ac3SJames Wright   MatCeedContext ctx;
120058600ac3SJames Wright 
120158600ac3SJames Wright   PetscFunctionBeginUser;
120258600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
120358600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
120458600ac3SJames Wright 
120558600ac3SJames Wright   {
120658600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
120758600ac3SJames Wright     Vec          X_loc = ctx->X_loc, Y_loc;
120858600ac3SJames Wright 
120958600ac3SJames Wright     // Get local vectors
121058600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
121158600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
121258600ac3SJames Wright 
121358600ac3SJames Wright     // Global-to-local
121458600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
121558600ac3SJames Wright 
121658600ac3SJames Wright     // Setup libCEED vectors
1217a7dac1d5SJames Wright     PetscCall(VecReadPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));
121858600ac3SJames Wright     PetscCall(VecZeroEntries(Y_loc));
1219a7dac1d5SJames Wright     PetscCall(VecPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));
122058600ac3SJames Wright 
122158600ac3SJames Wright     // Apply libCEED operator
122258600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
122350f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
122458600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
122558600ac3SJames Wright 
122658600ac3SJames Wright     // Restore PETSc vectors
1227a7dac1d5SJames Wright     PetscCall(VecReadCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));
1228a7dac1d5SJames Wright     PetscCall(VecCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));
122958600ac3SJames Wright 
123058600ac3SJames Wright     // Local-to-global
123158600ac3SJames Wright     PetscCall(VecZeroEntries(Y));
123258600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
123358600ac3SJames Wright 
123458600ac3SJames Wright     // Restore local vectors, as needed
123558600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
123658600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
123758600ac3SJames Wright   }
123858600ac3SJames Wright 
123958600ac3SJames Wright   // Log flops
124058600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
124158600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult));
124258600ac3SJames Wright 
124358600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
124458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
124558600ac3SJames Wright }
124658600ac3SJames Wright 
124758600ac3SJames Wright /**
124858600ac3SJames Wright   @brief Compute `A^T Y = X` for a `MATCEED`.
124958600ac3SJames Wright 
125058600ac3SJames Wright   Collective across MPI processes.
125158600ac3SJames Wright 
125258600ac3SJames Wright   @param[in]   A  `MATCEED`
125358600ac3SJames Wright   @param[in]   Y  Input PETSc vector
125458600ac3SJames Wright   @param[out]  X  Output PETSc vector
125558600ac3SJames Wright 
125658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
125758600ac3SJames Wright **/
125858600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
125958600ac3SJames Wright   MatCeedContext ctx;
126058600ac3SJames Wright 
126158600ac3SJames Wright   PetscFunctionBeginUser;
126258600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
126358600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
126458600ac3SJames Wright 
126558600ac3SJames Wright   {
126658600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
126758600ac3SJames Wright     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
126858600ac3SJames Wright 
126958600ac3SJames Wright     // Get local vectors
127058600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
127158600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
127258600ac3SJames Wright 
127358600ac3SJames Wright     // Global-to-local
127458600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
127558600ac3SJames Wright 
127658600ac3SJames Wright     // Setup libCEED vectors
1277a7dac1d5SJames Wright     PetscCall(VecReadPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc));
127858600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
1279a7dac1d5SJames Wright     PetscCall(VecPetscToCeed(X_loc, &x_mem_type, ctx->x_loc));
128058600ac3SJames Wright 
128158600ac3SJames Wright     // Apply libCEED operator
128258600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
128350f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
128458600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
128558600ac3SJames Wright 
128658600ac3SJames Wright     // Restore PETSc vectors
1287a7dac1d5SJames Wright     PetscCall(VecReadCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc));
1288a7dac1d5SJames Wright     PetscCall(VecCeedToPetsc(ctx->x_loc, x_mem_type, X_loc));
128958600ac3SJames Wright 
129058600ac3SJames Wright     // Local-to-global
129158600ac3SJames Wright     PetscCall(VecZeroEntries(X));
129258600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
129358600ac3SJames Wright 
129458600ac3SJames Wright     // Restore local vectors, as needed
129558600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
129658600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
129758600ac3SJames Wright   }
129858600ac3SJames Wright 
129958600ac3SJames Wright   // Log flops
130058600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
130158600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
130258600ac3SJames Wright 
130358600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
130458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
130558600ac3SJames Wright }
1306