1*58600ac3SJames Wright /// @file 2*58600ac3SJames Wright /// MatCeed and it's related operators 3*58600ac3SJames Wright 4*58600ac3SJames Wright #include <ceed.h> 5*58600ac3SJames Wright #include <ceed/backend.h> 6*58600ac3SJames Wright #include <mat-ceed-impl.h> 7*58600ac3SJames Wright #include <mat-ceed.h> 8*58600ac3SJames Wright #include <petscdmplex.h> 9*58600ac3SJames Wright #include <stdlib.h> 10*58600ac3SJames Wright #include <string.h> 11*58600ac3SJames Wright 12*58600ac3SJames Wright PetscClassId MATCEED_CLASSID; 13*58600ac3SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE; 14*58600ac3SJames Wright 15*58600ac3SJames Wright /** 16*58600ac3SJames Wright @brief Register MATCEED log events. 17*58600ac3SJames Wright 18*58600ac3SJames Wright Not collective across MPI processes. 19*58600ac3SJames Wright 20*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 21*58600ac3SJames Wright **/ 22*58600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() { 23*58600ac3SJames Wright static bool registered = false; 24*58600ac3SJames Wright 25*58600ac3SJames Wright PetscFunctionBeginUser; 26*58600ac3SJames Wright if (registered) PetscFunctionReturn(PETSC_SUCCESS); 27*58600ac3SJames Wright PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID)); 28*58600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT)); 29*58600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE)); 30*58600ac3SJames Wright registered = true; 31*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 32*58600ac3SJames Wright } 33*58600ac3SJames Wright 34*58600ac3SJames Wright /** 35*58600ac3SJames Wright @brief Translate PetscMemType to CeedMemType 36*58600ac3SJames Wright 37*58600ac3SJames Wright @param[in] mem_type PetscMemType 38*58600ac3SJames Wright 39*58600ac3SJames Wright @return Equivalent CeedMemType 40*58600ac3SJames Wright **/ 41*58600ac3SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; } 42*58600ac3SJames Wright 43*58600ac3SJames Wright /** 44*58600ac3SJames Wright @brief Translate array of `CeedInt` to `PetscInt`. 45*58600ac3SJames Wright If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`. 46*58600ac3SJames Wright Caller is responsible for freeing `array_petsc` with `free()`. 47*58600ac3SJames Wright 48*58600ac3SJames Wright Not collective across MPI processes. 49*58600ac3SJames Wright 50*58600ac3SJames Wright @param[in] num_entries Number of array entries 51*58600ac3SJames Wright @param[in,out] array_ceed Array of `CeedInt` 52*58600ac3SJames Wright @param[out] array_petsc Array of `PetscInt` 53*58600ac3SJames Wright 54*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 55*58600ac3SJames Wright **/ 56*58600ac3SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) { 57*58600ac3SJames Wright const CeedInt int_c = 0; 58*58600ac3SJames Wright const PetscInt int_p = 0; 59*58600ac3SJames Wright 60*58600ac3SJames Wright PetscFunctionBeginUser; 61*58600ac3SJames Wright if (sizeof(int_c) == sizeof(int_p)) { 62*58600ac3SJames Wright *array_petsc = (PetscInt *)*array_ceed; 63*58600ac3SJames Wright } else { 64*58600ac3SJames Wright *array_petsc = malloc(num_entries * sizeof(PetscInt)); 65*58600ac3SJames Wright for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i]; 66*58600ac3SJames Wright free(*array_ceed); 67*58600ac3SJames Wright } 68*58600ac3SJames Wright *array_ceed = NULL; 69*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 70*58600ac3SJames Wright } 71*58600ac3SJames Wright 72*58600ac3SJames Wright /** 73*58600ac3SJames Wright @brief Transfer array from PETSc `Vec` to `CeedVector`. 74*58600ac3SJames Wright 75*58600ac3SJames Wright Collective across MPI processes. 76*58600ac3SJames Wright 77*58600ac3SJames Wright @param[in] ceed libCEED context 78*58600ac3SJames Wright @param[in] X_petsc PETSc `Vec` 79*58600ac3SJames Wright @param[out] mem_type PETSc `MemType` 80*58600ac3SJames Wright @param[out] x_ceed `CeedVector` 81*58600ac3SJames Wright 82*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 83*58600ac3SJames Wright **/ 84*58600ac3SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) { 85*58600ac3SJames Wright PetscScalar *x; 86*58600ac3SJames Wright 87*58600ac3SJames Wright PetscFunctionBeginUser; 88*58600ac3SJames Wright PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type)); 89*58600ac3SJames Wright PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x)); 90*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 91*58600ac3SJames Wright } 92*58600ac3SJames Wright 93*58600ac3SJames Wright /** 94*58600ac3SJames Wright @brief Transfer array from `CeedVector` to PETSc `Vec`. 95*58600ac3SJames Wright 96*58600ac3SJames Wright Collective across MPI processes. 97*58600ac3SJames Wright 98*58600ac3SJames Wright @param[in] ceed libCEED context 99*58600ac3SJames Wright @param[in] x_ceed `CeedVector` 100*58600ac3SJames Wright @param[in] mem_type PETSc `MemType` 101*58600ac3SJames Wright @param[out] X_petsc PETSc `Vec` 102*58600ac3SJames Wright 103*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 104*58600ac3SJames Wright **/ 105*58600ac3SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) { 106*58600ac3SJames Wright PetscScalar *x; 107*58600ac3SJames Wright 108*58600ac3SJames Wright PetscFunctionBeginUser; 109*58600ac3SJames Wright PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x)); 110*58600ac3SJames Wright PetscCall(VecRestoreArrayAndMemType(X_petsc, &x)); 111*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 112*58600ac3SJames Wright } 113*58600ac3SJames Wright 114*58600ac3SJames Wright /** 115*58600ac3SJames Wright @brief Transfer read only array from PETSc `Vec` to `CeedVector`. 116*58600ac3SJames Wright 117*58600ac3SJames Wright Collective across MPI processes. 118*58600ac3SJames Wright 119*58600ac3SJames Wright @param[in] ceed libCEED context 120*58600ac3SJames Wright @param[in] X_petsc PETSc `Vec` 121*58600ac3SJames Wright @param[out] mem_type PETSc `MemType` 122*58600ac3SJames Wright @param[out] x_ceed `CeedVector` 123*58600ac3SJames Wright 124*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 125*58600ac3SJames Wright **/ 126*58600ac3SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) { 127*58600ac3SJames Wright PetscScalar *x; 128*58600ac3SJames Wright 129*58600ac3SJames Wright PetscFunctionBeginUser; 130*58600ac3SJames Wright PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type)); 131*58600ac3SJames Wright PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x)); 132*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 133*58600ac3SJames Wright } 134*58600ac3SJames Wright 135*58600ac3SJames Wright /** 136*58600ac3SJames Wright @brief Transfer read only array from `CeedVector` to PETSc `Vec`. 137*58600ac3SJames Wright 138*58600ac3SJames Wright Collective across MPI processes. 139*58600ac3SJames Wright 140*58600ac3SJames Wright @param[in] ceed libCEED context 141*58600ac3SJames Wright @param[in] x_ceed `CeedVector` 142*58600ac3SJames Wright @param[in] mem_type PETSc `MemType` 143*58600ac3SJames Wright @param[out] X_petsc PETSc `Vec` 144*58600ac3SJames Wright 145*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 146*58600ac3SJames Wright **/ 147*58600ac3SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) { 148*58600ac3SJames Wright PetscScalar *x; 149*58600ac3SJames Wright 150*58600ac3SJames Wright PetscFunctionBeginUser; 151*58600ac3SJames Wright PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x)); 152*58600ac3SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x)); 153*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 154*58600ac3SJames Wright } 155*58600ac3SJames Wright 156*58600ac3SJames Wright /** 157*58600ac3SJames Wright @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED. 158*58600ac3SJames Wright 159*58600ac3SJames Wright Collective across MPI processes. 160*58600ac3SJames Wright 161*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to setup 162*58600ac3SJames Wright @param[out] mat_inner Inner `Mat` 163*58600ac3SJames Wright 164*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 165*58600ac3SJames Wright **/ 166*58600ac3SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) { 167*58600ac3SJames Wright MatCeedContext ctx; 168*58600ac3SJames Wright 169*58600ac3SJames Wright PetscFunctionBeginUser; 170*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 171*58600ac3SJames Wright 172*58600ac3SJames Wright PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM"); 173*58600ac3SJames Wright 174*58600ac3SJames Wright // Check cl mat type 175*58600ac3SJames Wright { 176*58600ac3SJames Wright PetscBool is_internal_mat_type_cl = PETSC_FALSE; 177*58600ac3SJames Wright char internal_mat_type_cl[64]; 178*58600ac3SJames Wright 179*58600ac3SJames Wright // Check for specific CL inner mat type for this Mat 180*58600ac3SJames Wright { 181*58600ac3SJames Wright const char *mat_ceed_prefix = NULL; 182*58600ac3SJames Wright 183*58600ac3SJames Wright PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix)); 184*58600ac3SJames Wright PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL); 185*58600ac3SJames Wright PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl, 186*58600ac3SJames Wright internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl)); 187*58600ac3SJames Wright PetscOptionsEnd(); 188*58600ac3SJames Wright if (is_internal_mat_type_cl) { 189*58600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 190*58600ac3SJames Wright PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type)); 191*58600ac3SJames Wright } 192*58600ac3SJames Wright } 193*58600ac3SJames Wright } 194*58600ac3SJames Wright 195*58600ac3SJames Wright // Create sparse matrix 196*58600ac3SJames Wright { 197*58600ac3SJames Wright MatType dm_mat_type, dm_mat_type_copy; 198*58600ac3SJames Wright 199*58600ac3SJames Wright PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type)); 200*58600ac3SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 201*58600ac3SJames Wright PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type)); 202*58600ac3SJames Wright PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner)); 203*58600ac3SJames Wright PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy)); 204*58600ac3SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 205*58600ac3SJames Wright } 206*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 207*58600ac3SJames Wright } 208*58600ac3SJames Wright 209*58600ac3SJames Wright /** 210*58600ac3SJames Wright @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar. 211*58600ac3SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 212*58600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 213*58600ac3SJames Wright 214*58600ac3SJames Wright Collective across MPI processes. 215*58600ac3SJames Wright 216*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 217*58600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 218*58600ac3SJames Wright 219*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 220*58600ac3SJames Wright **/ 221*58600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) { 222*58600ac3SJames Wright MatCeedContext ctx; 223*58600ac3SJames Wright 224*58600ac3SJames Wright PetscFunctionBeginUser; 225*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 226*58600ac3SJames Wright 227*58600ac3SJames Wright // Check if COO pattern set 228*58600ac3SJames Wright { 229*58600ac3SJames Wright PetscInt index = -1; 230*58600ac3SJames Wright 231*58600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 232*58600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == mat_coo) index = i; 233*58600ac3SJames Wright } 234*58600ac3SJames Wright if (index == -1) { 235*58600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 236*58600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 237*58600ac3SJames Wright PetscCount num_entries; 238*58600ac3SJames Wright PetscLogStage stage_amg_setup; 239*58600ac3SJames Wright 240*58600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 241*58600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 242*58600ac3SJames Wright if (stage_amg_setup == -1) { 243*58600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 244*58600ac3SJames Wright } 245*58600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 246*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 247*58600ac3SJames Wright PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc)); 248*58600ac3SJames Wright PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc)); 249*58600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 250*58600ac3SJames Wright free(rows_petsc); 251*58600ac3SJames Wright free(cols_petsc); 252*58600ac3SJames Wright if (!ctx->coo_values_pbd) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd)); 253*58600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd)); 254*58600ac3SJames Wright ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo; 255*58600ac3SJames Wright PetscCall(PetscLogStagePop()); 256*58600ac3SJames Wright } 257*58600ac3SJames Wright } 258*58600ac3SJames Wright 259*58600ac3SJames Wright // Assemble mat_ceed 260*58600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 261*58600ac3SJames Wright { 262*58600ac3SJames Wright const CeedScalar *values; 263*58600ac3SJames Wright MatType mat_type; 264*58600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 265*58600ac3SJames Wright PetscBool is_spd, is_spd_known; 266*58600ac3SJames Wright 267*58600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 268*58600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 269*58600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 270*58600ac3SJames Wright else mem_type = CEED_MEM_HOST; 271*58600ac3SJames Wright 272*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE)); 273*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values)); 274*58600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 275*58600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 276*58600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 277*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values)); 278*58600ac3SJames Wright } 279*58600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 280*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 281*58600ac3SJames Wright } 282*58600ac3SJames Wright 283*58600ac3SJames Wright /** 284*58600ac3SJames Wright @brief Assemble inner `Mat` for diagonal `PC` operations 285*58600ac3SJames Wright 286*58600ac3SJames Wright Collective across MPI processes. 287*58600ac3SJames Wright 288*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 289*58600ac3SJames Wright @param[in] use_ceed_pbd Boolean flag to use libCEED PBD assembly 290*58600ac3SJames Wright @param[out] mat_inner Inner `Mat` for diagonal operations 291*58600ac3SJames Wright 292*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 293*58600ac3SJames Wright **/ 294*58600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) { 295*58600ac3SJames Wright MatCeedContext ctx; 296*58600ac3SJames Wright 297*58600ac3SJames Wright PetscFunctionBeginUser; 298*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 299*58600ac3SJames Wright if (use_ceed_pbd) { 300*58600ac3SJames Wright // Check if COO pattern set 301*58600ac3SJames Wright if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal)); 302*58600ac3SJames Wright 303*58600ac3SJames Wright // Assemble mat_assembled_full_internal 304*58600ac3SJames Wright PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal)); 305*58600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal; 306*58600ac3SJames Wright } else { 307*58600ac3SJames Wright // Check if COO pattern set 308*58600ac3SJames Wright if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal)); 309*58600ac3SJames Wright 310*58600ac3SJames Wright // Assemble mat_assembled_full_internal 311*58600ac3SJames Wright PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal)); 312*58600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal; 313*58600ac3SJames Wright } 314*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 315*58600ac3SJames Wright } 316*58600ac3SJames Wright 317*58600ac3SJames Wright /** 318*58600ac3SJames Wright @brief Get `MATCEED` diagonal block for Jacobi. 319*58600ac3SJames Wright 320*58600ac3SJames Wright Collective across MPI processes. 321*58600ac3SJames Wright 322*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 323*58600ac3SJames Wright @param[out] mat_block The diagonal block matrix 324*58600ac3SJames Wright 325*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 326*58600ac3SJames Wright **/ 327*58600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) { 328*58600ac3SJames Wright Mat mat_inner = NULL; 329*58600ac3SJames Wright MatCeedContext ctx; 330*58600ac3SJames Wright 331*58600ac3SJames Wright PetscFunctionBeginUser; 332*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 333*58600ac3SJames Wright 334*58600ac3SJames Wright // Assemble inner mat if needed 335*58600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 336*58600ac3SJames Wright 337*58600ac3SJames Wright // Get block diagonal 338*58600ac3SJames Wright PetscCall(MatGetDiagonalBlock(mat_inner, mat_block)); 339*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 340*58600ac3SJames Wright } 341*58600ac3SJames Wright 342*58600ac3SJames Wright /** 343*58600ac3SJames Wright @brief Invert `MATCEED` diagonal block for Jacobi. 344*58600ac3SJames Wright 345*58600ac3SJames Wright Collective across MPI processes. 346*58600ac3SJames Wright 347*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 348*58600ac3SJames Wright @param[out] values The block inverses in column major order 349*58600ac3SJames Wright 350*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 351*58600ac3SJames Wright **/ 352*58600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) { 353*58600ac3SJames Wright Mat mat_inner = NULL; 354*58600ac3SJames Wright MatCeedContext ctx; 355*58600ac3SJames Wright 356*58600ac3SJames Wright PetscFunctionBeginUser; 357*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 358*58600ac3SJames Wright 359*58600ac3SJames Wright // Assemble inner mat if needed 360*58600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 361*58600ac3SJames Wright 362*58600ac3SJames Wright // Invert PB diagonal 363*58600ac3SJames Wright PetscCall(MatInvertBlockDiagonal(mat_inner, values)); 364*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 365*58600ac3SJames Wright } 366*58600ac3SJames Wright 367*58600ac3SJames Wright /** 368*58600ac3SJames Wright @brief Invert `MATCEED` variable diagonal block for Jacobi. 369*58600ac3SJames Wright 370*58600ac3SJames Wright Collective across MPI processes. 371*58600ac3SJames Wright 372*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 373*58600ac3SJames Wright @param[in] num_blocks The number of blocks on the process 374*58600ac3SJames Wright @param[in] block_sizes The size of each block on the process 375*58600ac3SJames Wright @param[out] values The block inverses in column major order 376*58600ac3SJames Wright 377*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 378*58600ac3SJames Wright **/ 379*58600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) { 380*58600ac3SJames Wright Mat mat_inner = NULL; 381*58600ac3SJames Wright MatCeedContext ctx; 382*58600ac3SJames Wright 383*58600ac3SJames Wright PetscFunctionBeginUser; 384*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 385*58600ac3SJames Wright 386*58600ac3SJames Wright // Assemble inner mat if needed 387*58600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner)); 388*58600ac3SJames Wright 389*58600ac3SJames Wright // Invert PB diagonal 390*58600ac3SJames Wright PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values)); 391*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 392*58600ac3SJames Wright } 393*58600ac3SJames Wright 394*58600ac3SJames Wright // ----------------------------------------------------------------------------- 395*58600ac3SJames Wright // MatCeed 396*58600ac3SJames Wright // ----------------------------------------------------------------------------- 397*58600ac3SJames Wright 398*58600ac3SJames Wright /** 399*58600ac3SJames Wright @brief Create PETSc `Mat` from libCEED operators. 400*58600ac3SJames Wright 401*58600ac3SJames Wright Collective across MPI processes. 402*58600ac3SJames Wright 403*58600ac3SJames Wright @param[in] dm_x Input `DM` 404*58600ac3SJames Wright @param[in] dm_y Output `DM` 405*58600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 406*58600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 407*58600ac3SJames Wright @param[out] mat New MatCeed 408*58600ac3SJames Wright 409*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 410*58600ac3SJames Wright **/ 411*58600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) { 412*58600ac3SJames Wright PetscInt X_l_size, X_g_size, Y_l_size, Y_g_size; 413*58600ac3SJames Wright VecType vec_type; 414*58600ac3SJames Wright MatCeedContext ctx; 415*58600ac3SJames Wright 416*58600ac3SJames Wright PetscFunctionBeginUser; 417*58600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 418*58600ac3SJames Wright 419*58600ac3SJames Wright // Collect context data 420*58600ac3SJames Wright PetscCall(DMGetVecType(dm_x, &vec_type)); 421*58600ac3SJames Wright { 422*58600ac3SJames Wright Vec X; 423*58600ac3SJames Wright 424*58600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_x, &X)); 425*58600ac3SJames Wright PetscCall(VecGetSize(X, &X_g_size)); 426*58600ac3SJames Wright PetscCall(VecGetLocalSize(X, &X_l_size)); 427*58600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_x, &X)); 428*58600ac3SJames Wright } 429*58600ac3SJames Wright if (dm_y) { 430*58600ac3SJames Wright Vec Y; 431*58600ac3SJames Wright 432*58600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_y, &Y)); 433*58600ac3SJames Wright PetscCall(VecGetSize(Y, &Y_g_size)); 434*58600ac3SJames Wright PetscCall(VecGetLocalSize(Y, &Y_l_size)); 435*58600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_y, &Y)); 436*58600ac3SJames Wright } else { 437*58600ac3SJames Wright dm_y = dm_x; 438*58600ac3SJames Wright Y_g_size = X_g_size; 439*58600ac3SJames Wright Y_l_size = X_l_size; 440*58600ac3SJames Wright } 441*58600ac3SJames Wright // Create context 442*58600ac3SJames Wright { 443*58600ac3SJames Wright Vec X_loc, Y_loc_transpose = NULL; 444*58600ac3SJames Wright 445*58600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_x, &X_loc)); 446*58600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 447*58600ac3SJames Wright if (op_mult_transpose) { 448*58600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose)); 449*58600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc_transpose)); 450*58600ac3SJames Wright } 451*58600ac3SJames Wright PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx)); 452*58600ac3SJames Wright PetscCall(VecDestroy(&X_loc)); 453*58600ac3SJames Wright PetscCall(VecDestroy(&Y_loc_transpose)); 454*58600ac3SJames Wright } 455*58600ac3SJames Wright 456*58600ac3SJames Wright // Create mat 457*58600ac3SJames Wright PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat)); 458*58600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED)); 459*58600ac3SJames Wright // -- Set block and variable block sizes 460*58600ac3SJames Wright if (dm_x == dm_y) { 461*58600ac3SJames Wright MatType dm_mat_type, dm_mat_type_copy; 462*58600ac3SJames Wright Mat temp_mat; 463*58600ac3SJames Wright 464*58600ac3SJames Wright PetscCall(DMGetMatType(dm_x, &dm_mat_type)); 465*58600ac3SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 466*58600ac3SJames Wright PetscCall(DMSetMatType(dm_x, MATAIJ)); 467*58600ac3SJames Wright PetscCall(DMCreateMatrix(dm_x, &temp_mat)); 468*58600ac3SJames Wright PetscCall(DMSetMatType(dm_x, dm_mat_type_copy)); 469*58600ac3SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 470*58600ac3SJames Wright 471*58600ac3SJames Wright { 472*58600ac3SJames Wright PetscInt block_size, num_blocks, max_vblock_size = PETSC_INT_MAX; 473*58600ac3SJames Wright const PetscInt *vblock_sizes; 474*58600ac3SJames Wright 475*58600ac3SJames Wright // -- Get block sizes 476*58600ac3SJames Wright PetscCall(MatGetBlockSize(temp_mat, &block_size)); 477*58600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes)); 478*58600ac3SJames Wright { 479*58600ac3SJames Wright PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX}; 480*58600ac3SJames Wright 481*58600ac3SJames Wright for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]); 482*58600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max)); 483*58600ac3SJames Wright max_vblock_size = global_min_max[1]; 484*58600ac3SJames Wright } 485*58600ac3SJames Wright 486*58600ac3SJames Wright // -- Copy block sizes 487*58600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size)); 488*58600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes)); 489*58600ac3SJames Wright 490*58600ac3SJames Wright // -- Check libCEED compatibility 491*58600ac3SJames Wright { 492*58600ac3SJames Wright bool is_composite; 493*58600ac3SJames Wright 494*58600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_TRUE; 495*58600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_TRUE; 496*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite)); 497*58600ac3SJames Wright if (is_composite) { 498*58600ac3SJames Wright CeedInt num_sub_operators; 499*58600ac3SJames Wright CeedOperator *sub_operators; 500*58600ac3SJames Wright 501*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators)); 502*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators)); 503*58600ac3SJames Wright for (CeedInt i = 0; i < num_sub_operators; i++) { 504*58600ac3SJames Wright CeedInt num_bases, num_comp; 505*58600ac3SJames Wright CeedBasis *active_bases; 506*58600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 507*58600ac3SJames Wright 508*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data)); 509*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 510*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 511*58600ac3SJames Wright if (num_bases > 1) { 512*58600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 513*58600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 514*58600ac3SJames Wright } 515*58600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 516*58600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 517*58600ac3SJames Wright } 518*58600ac3SJames Wright } else { 519*58600ac3SJames Wright // LCOV_EXCL_START 520*58600ac3SJames Wright CeedInt num_bases, num_comp; 521*58600ac3SJames Wright CeedBasis *active_bases; 522*58600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 523*58600ac3SJames Wright 524*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data)); 525*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 526*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 527*58600ac3SJames Wright if (num_bases > 1) { 528*58600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 529*58600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 530*58600ac3SJames Wright } 531*58600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 532*58600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 533*58600ac3SJames Wright // LCOV_EXCL_STOP 534*58600ac3SJames Wright } 535*58600ac3SJames Wright { 536*58600ac3SJames Wright PetscInt local_is_valid[2], global_is_valid[2]; 537*58600ac3SJames Wright 538*58600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid; 539*58600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 540*58600ac3SJames Wright ctx->is_ceed_pbd_valid = global_is_valid[0]; 541*58600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid; 542*58600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 543*58600ac3SJames Wright ctx->is_ceed_vpbd_valid = global_is_valid[0]; 544*58600ac3SJames Wright } 545*58600ac3SJames Wright } 546*58600ac3SJames Wright } 547*58600ac3SJames Wright PetscCall(MatDestroy(&temp_mat)); 548*58600ac3SJames Wright } 549*58600ac3SJames Wright // -- Set internal mat type 550*58600ac3SJames Wright { 551*58600ac3SJames Wright VecType vec_type; 552*58600ac3SJames Wright MatType internal_mat_type = MATAIJ; 553*58600ac3SJames Wright 554*58600ac3SJames Wright PetscCall(VecGetType(ctx->X_loc, &vec_type)); 555*58600ac3SJames Wright if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE; 556*58600ac3SJames Wright else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS; 557*58600ac3SJames Wright else internal_mat_type = MATAIJ; 558*58600ac3SJames Wright PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type)); 559*58600ac3SJames Wright } 560*58600ac3SJames Wright // -- Set mat operations 561*58600ac3SJames Wright PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 562*58600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 563*58600ac3SJames Wright if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 564*58600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 565*58600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 566*58600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 567*58600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 568*58600ac3SJames Wright PetscCall(MatShellSetVecType(*mat, vec_type)); 569*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 570*58600ac3SJames Wright } 571*58600ac3SJames Wright 572*58600ac3SJames Wright /** 573*58600ac3SJames Wright @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`. 574*58600ac3SJames Wright 575*58600ac3SJames Wright Collective across MPI processes. 576*58600ac3SJames Wright 577*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to copy from 578*58600ac3SJames Wright @param[out] mat_other `MatShell` or `MATCEED` to copy into 579*58600ac3SJames Wright 580*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 581*58600ac3SJames Wright **/ 582*58600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) { 583*58600ac3SJames Wright PetscFunctionBeginUser; 584*58600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 585*58600ac3SJames Wright 586*58600ac3SJames Wright // Check type compatibility 587*58600ac3SJames Wright { 588*58600ac3SJames Wright MatType mat_type_ceed, mat_type_other; 589*58600ac3SJames Wright 590*58600ac3SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_ceed)); 591*58600ac3SJames Wright PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED); 592*58600ac3SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_other)); 593*58600ac3SJames Wright PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB, 594*58600ac3SJames Wright "mat_other must have type " MATCEED " or " MATSHELL); 595*58600ac3SJames Wright } 596*58600ac3SJames Wright 597*58600ac3SJames Wright // Check dimension compatibility 598*58600ac3SJames Wright { 599*58600ac3SJames Wright PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size; 600*58600ac3SJames Wright 601*58600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size)); 602*58600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size)); 603*58600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size)); 604*58600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size)); 605*58600ac3SJames Wright PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) && 606*58600ac3SJames Wright (X_l_ceed_size == X_l_other_size), 607*58600ac3SJames Wright PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, 608*58600ac3SJames Wright "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT 609*58600ac3SJames Wright "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT 610*58600ac3SJames Wright ", %" PetscInt_FMT ")", 611*58600ac3SJames Wright Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size); 612*58600ac3SJames Wright } 613*58600ac3SJames Wright 614*58600ac3SJames Wright // Convert 615*58600ac3SJames Wright { 616*58600ac3SJames Wright VecType vec_type; 617*58600ac3SJames Wright MatCeedContext ctx; 618*58600ac3SJames Wright 619*58600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED)); 620*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 621*58600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 622*58600ac3SJames Wright PetscCall(MatShellSetContext(mat_other, ctx)); 623*58600ac3SJames Wright PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 624*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 625*58600ac3SJames Wright if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 626*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 627*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 628*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 629*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 630*58600ac3SJames Wright { 631*58600ac3SJames Wright PetscInt block_size; 632*58600ac3SJames Wright 633*58600ac3SJames Wright PetscCall(MatGetBlockSize(mat_ceed, &block_size)); 634*58600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size)); 635*58600ac3SJames Wright } 636*58600ac3SJames Wright { 637*58600ac3SJames Wright PetscInt num_blocks; 638*58600ac3SJames Wright const PetscInt *block_sizes; 639*58600ac3SJames Wright 640*58600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes)); 641*58600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes)); 642*58600ac3SJames Wright } 643*58600ac3SJames Wright PetscCall(DMGetVecType(ctx->dm_x, &vec_type)); 644*58600ac3SJames Wright PetscCall(MatShellSetVecType(mat_other, vec_type)); 645*58600ac3SJames Wright } 646*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 647*58600ac3SJames Wright } 648*58600ac3SJames Wright 649*58600ac3SJames Wright /** 650*58600ac3SJames Wright @brief Assemble a `MATCEED` into a `MATAIJ` or similar. 651*58600ac3SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 652*58600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 653*58600ac3SJames Wright 654*58600ac3SJames Wright Collective across MPI processes. 655*58600ac3SJames Wright 656*58600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 657*58600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 658*58600ac3SJames Wright 659*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 660*58600ac3SJames Wright **/ 661*58600ac3SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) { 662*58600ac3SJames Wright MatCeedContext ctx; 663*58600ac3SJames Wright 664*58600ac3SJames Wright PetscFunctionBeginUser; 665*58600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 666*58600ac3SJames Wright 667*58600ac3SJames Wright // Check if COO pattern set 668*58600ac3SJames Wright { 669*58600ac3SJames Wright PetscInt index = -1; 670*58600ac3SJames Wright 671*58600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 672*58600ac3SJames Wright if (ctx->mats_assembled_full[i] == mat_coo) index = i; 673*58600ac3SJames Wright } 674*58600ac3SJames Wright if (index == -1) { 675*58600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 676*58600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 677*58600ac3SJames Wright PetscCount num_entries; 678*58600ac3SJames Wright PetscLogStage stage_amg_setup; 679*58600ac3SJames Wright 680*58600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 681*58600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 682*58600ac3SJames Wright if (stage_amg_setup == -1) { 683*58600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 684*58600ac3SJames Wright } 685*58600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 686*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 687*58600ac3SJames Wright PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc)); 688*58600ac3SJames Wright PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc)); 689*58600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 690*58600ac3SJames Wright free(rows_petsc); 691*58600ac3SJames Wright free(cols_petsc); 692*58600ac3SJames Wright if (!ctx->coo_values_full) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full)); 693*58600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full)); 694*58600ac3SJames Wright ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo; 695*58600ac3SJames Wright PetscCall(PetscLogStagePop()); 696*58600ac3SJames Wright } 697*58600ac3SJames Wright } 698*58600ac3SJames Wright 699*58600ac3SJames Wright // Assemble mat_ceed 700*58600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 701*58600ac3SJames Wright { 702*58600ac3SJames Wright const CeedScalar *values; 703*58600ac3SJames Wright MatType mat_type; 704*58600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 705*58600ac3SJames Wright PetscBool is_spd, is_spd_known; 706*58600ac3SJames Wright 707*58600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 708*58600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 709*58600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 710*58600ac3SJames Wright else mem_type = CEED_MEM_HOST; 711*58600ac3SJames Wright 712*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full)); 713*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values)); 714*58600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 715*58600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 716*58600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 717*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values)); 718*58600ac3SJames Wright } 719*58600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 720*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 721*58600ac3SJames Wright } 722*58600ac3SJames Wright 723*58600ac3SJames Wright /** 724*58600ac3SJames Wright @brief Set user context for a `MATCEED`. 725*58600ac3SJames Wright 726*58600ac3SJames Wright Collective across MPI processes. 727*58600ac3SJames Wright 728*58600ac3SJames Wright @param[in,out] mat `MATCEED` 729*58600ac3SJames Wright @param[in] f The context destroy function, or NULL 730*58600ac3SJames Wright @param[in] ctx User context, or NULL to unset 731*58600ac3SJames Wright 732*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 733*58600ac3SJames Wright **/ 734*58600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) { 735*58600ac3SJames Wright PetscContainer user_ctx = NULL; 736*58600ac3SJames Wright 737*58600ac3SJames Wright PetscFunctionBeginUser; 738*58600ac3SJames Wright if (ctx) { 739*58600ac3SJames Wright PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx)); 740*58600ac3SJames Wright PetscCall(PetscContainerSetPointer(user_ctx, ctx)); 741*58600ac3SJames Wright PetscCall(PetscContainerSetUserDestroy(user_ctx, f)); 742*58600ac3SJames Wright } 743*58600ac3SJames Wright PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx)); 744*58600ac3SJames Wright PetscCall(PetscContainerDestroy(&user_ctx)); 745*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 746*58600ac3SJames Wright } 747*58600ac3SJames Wright 748*58600ac3SJames Wright /** 749*58600ac3SJames Wright @brief Retrieve the user context for a `MATCEED`. 750*58600ac3SJames Wright 751*58600ac3SJames Wright Collective across MPI processes. 752*58600ac3SJames Wright 753*58600ac3SJames Wright @param[in,out] mat `MATCEED` 754*58600ac3SJames Wright @param[in] ctx User context 755*58600ac3SJames Wright 756*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 757*58600ac3SJames Wright **/ 758*58600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) { 759*58600ac3SJames Wright PetscContainer user_ctx; 760*58600ac3SJames Wright 761*58600ac3SJames Wright PetscFunctionBeginUser; 762*58600ac3SJames Wright PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx)); 763*58600ac3SJames Wright if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx)); 764*58600ac3SJames Wright else *(void **)ctx = NULL; 765*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 766*58600ac3SJames Wright } 767*58600ac3SJames Wright 768*58600ac3SJames Wright /** 769*58600ac3SJames Wright @brief Sets the inner matrix type as a string from the `MATCEED`. 770*58600ac3SJames Wright 771*58600ac3SJames Wright Collective across MPI processes. 772*58600ac3SJames Wright 773*58600ac3SJames Wright @param[in,out] mat `MATCEED` 774*58600ac3SJames Wright @param[in] type Inner `MatType` to set 775*58600ac3SJames Wright 776*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 777*58600ac3SJames Wright **/ 778*58600ac3SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) { 779*58600ac3SJames Wright MatCeedContext ctx; 780*58600ac3SJames Wright 781*58600ac3SJames Wright PetscFunctionBeginUser; 782*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 783*58600ac3SJames Wright // Check if same 784*58600ac3SJames Wright { 785*58600ac3SJames Wright size_t len_old, len_new; 786*58600ac3SJames Wright PetscBool is_same = PETSC_FALSE; 787*58600ac3SJames Wright 788*58600ac3SJames Wright PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old)); 789*58600ac3SJames Wright PetscCall(PetscStrlen(type, &len_new)); 790*58600ac3SJames Wright if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same)); 791*58600ac3SJames Wright if (is_same) PetscFunctionReturn(PETSC_SUCCESS); 792*58600ac3SJames Wright } 793*58600ac3SJames Wright // Clean up old mats in different format 794*58600ac3SJames Wright // LCOV_EXCL_START 795*58600ac3SJames Wright if (ctx->mat_assembled_full_internal) { 796*58600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 797*58600ac3SJames Wright if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) { 798*58600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) { 799*58600ac3SJames Wright ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j]; 800*58600ac3SJames Wright } 801*58600ac3SJames Wright ctx->num_mats_assembled_full--; 802*58600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 803*58600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 804*58600ac3SJames Wright } 805*58600ac3SJames Wright } 806*58600ac3SJames Wright } 807*58600ac3SJames Wright if (ctx->mat_assembled_pbd_internal) { 808*58600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 809*58600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) { 810*58600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) { 811*58600ac3SJames Wright ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j]; 812*58600ac3SJames Wright } 813*58600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 814*58600ac3SJames Wright ctx->num_mats_assembled_pbd--; 815*58600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 816*58600ac3SJames Wright } 817*58600ac3SJames Wright } 818*58600ac3SJames Wright } 819*58600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 820*58600ac3SJames Wright PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type)); 821*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 822*58600ac3SJames Wright // LCOV_EXCL_STOP 823*58600ac3SJames Wright } 824*58600ac3SJames Wright 825*58600ac3SJames Wright /** 826*58600ac3SJames Wright @brief Gets the inner matrix type as a string from the `MATCEED`. 827*58600ac3SJames Wright 828*58600ac3SJames Wright Collective across MPI processes. 829*58600ac3SJames Wright 830*58600ac3SJames Wright @param[in,out] mat `MATCEED` 831*58600ac3SJames Wright @param[in] type Inner `MatType` 832*58600ac3SJames Wright 833*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 834*58600ac3SJames Wright **/ 835*58600ac3SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) { 836*58600ac3SJames Wright MatCeedContext ctx; 837*58600ac3SJames Wright 838*58600ac3SJames Wright PetscFunctionBeginUser; 839*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 840*58600ac3SJames Wright *type = ctx->internal_mat_type; 841*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 842*58600ac3SJames Wright } 843*58600ac3SJames Wright 844*58600ac3SJames Wright /** 845*58600ac3SJames Wright @brief Set a user defined matrix operation for a `MATCEED` matrix. 846*58600ac3SJames Wright 847*58600ac3SJames Wright Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by 848*58600ac3SJames Wright `MatCeedSetContext()`. 849*58600ac3SJames Wright 850*58600ac3SJames Wright Collective across MPI processes. 851*58600ac3SJames Wright 852*58600ac3SJames Wright @param[in,out] mat `MATCEED` 853*58600ac3SJames Wright @param[in] op Name of the `MatOperation` 854*58600ac3SJames Wright @param[in] g Function that provides the operation 855*58600ac3SJames Wright 856*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 857*58600ac3SJames Wright **/ 858*58600ac3SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) { 859*58600ac3SJames Wright PetscFunctionBeginUser; 860*58600ac3SJames Wright PetscCall(MatShellSetOperation(mat, op, g)); 861*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 862*58600ac3SJames Wright } 863*58600ac3SJames Wright 864*58600ac3SJames Wright /** 865*58600ac3SJames Wright @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 866*58600ac3SJames Wright 867*58600ac3SJames Wright Not collective across MPI processes. 868*58600ac3SJames Wright 869*58600ac3SJames Wright @param[in,out] mat `MATCEED` 870*58600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 871*58600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 872*58600ac3SJames Wright 873*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 874*58600ac3SJames Wright **/ 875*58600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) { 876*58600ac3SJames Wright MatCeedContext ctx; 877*58600ac3SJames Wright 878*58600ac3SJames Wright PetscFunctionBeginUser; 879*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 880*58600ac3SJames Wright if (X_loc) { 881*58600ac3SJames Wright PetscInt len_old, len_new; 882*58600ac3SJames Wright 883*58600ac3SJames Wright PetscCall(VecGetSize(ctx->X_loc, &len_old)); 884*58600ac3SJames Wright PetscCall(VecGetSize(X_loc, &len_new)); 885*58600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT, 886*58600ac3SJames Wright len_new, len_old); 887*58600ac3SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 888*58600ac3SJames Wright ctx->X_loc = X_loc; 889*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)X_loc)); 890*58600ac3SJames Wright } 891*58600ac3SJames Wright if (Y_loc_transpose) { 892*58600ac3SJames Wright PetscInt len_old, len_new; 893*58600ac3SJames Wright 894*58600ac3SJames Wright PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old)); 895*58600ac3SJames Wright PetscCall(VecGetSize(Y_loc_transpose, &len_new)); 896*58600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, 897*58600ac3SJames Wright "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old); 898*58600ac3SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 899*58600ac3SJames Wright ctx->Y_loc_transpose = Y_loc_transpose; 900*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 901*58600ac3SJames Wright } 902*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 903*58600ac3SJames Wright } 904*58600ac3SJames Wright 905*58600ac3SJames Wright /** 906*58600ac3SJames Wright @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 907*58600ac3SJames Wright 908*58600ac3SJames Wright Not collective across MPI processes. 909*58600ac3SJames Wright 910*58600ac3SJames Wright @param[in,out] mat `MATCEED` 911*58600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 912*58600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 913*58600ac3SJames Wright 914*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 915*58600ac3SJames Wright **/ 916*58600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 917*58600ac3SJames Wright MatCeedContext ctx; 918*58600ac3SJames Wright 919*58600ac3SJames Wright PetscFunctionBeginUser; 920*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 921*58600ac3SJames Wright if (X_loc) { 922*58600ac3SJames Wright *X_loc = ctx->X_loc; 923*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)*X_loc)); 924*58600ac3SJames Wright } 925*58600ac3SJames Wright if (Y_loc_transpose) { 926*58600ac3SJames Wright *Y_loc_transpose = ctx->Y_loc_transpose; 927*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose)); 928*58600ac3SJames Wright } 929*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 930*58600ac3SJames Wright } 931*58600ac3SJames Wright 932*58600ac3SJames Wright /** 933*58600ac3SJames Wright @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 934*58600ac3SJames Wright 935*58600ac3SJames Wright Not collective across MPI processes. 936*58600ac3SJames Wright 937*58600ac3SJames Wright @param[in,out] mat MatCeed 938*58600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 939*58600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 940*58600ac3SJames Wright 941*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 942*58600ac3SJames Wright **/ 943*58600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 944*58600ac3SJames Wright PetscFunctionBeginUser; 945*58600ac3SJames Wright if (X_loc) PetscCall(VecDestroy(X_loc)); 946*58600ac3SJames Wright if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose)); 947*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 948*58600ac3SJames Wright } 949*58600ac3SJames Wright 950*58600ac3SJames Wright /** 951*58600ac3SJames Wright @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 952*58600ac3SJames Wright 953*58600ac3SJames Wright Not collective across MPI processes. 954*58600ac3SJames Wright 955*58600ac3SJames Wright @param[in,out] mat MatCeed 956*58600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 957*58600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 958*58600ac3SJames Wright 959*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 960*58600ac3SJames Wright **/ 961*58600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 962*58600ac3SJames Wright MatCeedContext ctx; 963*58600ac3SJames Wright 964*58600ac3SJames Wright PetscFunctionBeginUser; 965*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 966*58600ac3SJames Wright if (op_mult) { 967*58600ac3SJames Wright *op_mult = NULL; 968*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult)); 969*58600ac3SJames Wright } 970*58600ac3SJames Wright if (op_mult_transpose) { 971*58600ac3SJames Wright *op_mult_transpose = NULL; 972*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose)); 973*58600ac3SJames Wright } 974*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 975*58600ac3SJames Wright } 976*58600ac3SJames Wright 977*58600ac3SJames Wright /** 978*58600ac3SJames Wright @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 979*58600ac3SJames Wright 980*58600ac3SJames Wright Not collective across MPI processes. 981*58600ac3SJames Wright 982*58600ac3SJames Wright @param[in,out] mat MatCeed 983*58600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 984*58600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 985*58600ac3SJames Wright 986*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 987*58600ac3SJames Wright **/ 988*58600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 989*58600ac3SJames Wright MatCeedContext ctx; 990*58600ac3SJames Wright 991*58600ac3SJames Wright PetscFunctionBeginUser; 992*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 993*58600ac3SJames Wright if (op_mult) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult)); 994*58600ac3SJames Wright if (op_mult_transpose) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult_transpose)); 995*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 996*58600ac3SJames Wright } 997*58600ac3SJames Wright 998*58600ac3SJames Wright /** 999*58600ac3SJames Wright @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 1000*58600ac3SJames Wright 1001*58600ac3SJames Wright Not collective across MPI processes. 1002*58600ac3SJames Wright 1003*58600ac3SJames Wright @param[in,out] mat MatCeed 1004*58600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 1005*58600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 1006*58600ac3SJames Wright 1007*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1008*58600ac3SJames Wright **/ 1009*58600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) { 1010*58600ac3SJames Wright MatCeedContext ctx; 1011*58600ac3SJames Wright 1012*58600ac3SJames Wright PetscFunctionBeginUser; 1013*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 1014*58600ac3SJames Wright if (log_event_mult) ctx->log_event_mult = log_event_mult; 1015*58600ac3SJames Wright if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose; 1016*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1017*58600ac3SJames Wright } 1018*58600ac3SJames Wright 1019*58600ac3SJames Wright /** 1020*58600ac3SJames Wright @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 1021*58600ac3SJames Wright 1022*58600ac3SJames Wright Not collective across MPI processes. 1023*58600ac3SJames Wright 1024*58600ac3SJames Wright @param[in,out] mat MatCeed 1025*58600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 1026*58600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 1027*58600ac3SJames Wright 1028*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1029*58600ac3SJames Wright **/ 1030*58600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) { 1031*58600ac3SJames Wright MatCeedContext ctx; 1032*58600ac3SJames Wright 1033*58600ac3SJames Wright PetscFunctionBeginUser; 1034*58600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 1035*58600ac3SJames Wright if (log_event_mult) *log_event_mult = ctx->log_event_mult; 1036*58600ac3SJames Wright if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose; 1037*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1038*58600ac3SJames Wright } 1039*58600ac3SJames Wright 1040*58600ac3SJames Wright // ----------------------------------------------------------------------------- 1041*58600ac3SJames Wright // Operator context data 1042*58600ac3SJames Wright // ----------------------------------------------------------------------------- 1043*58600ac3SJames Wright 1044*58600ac3SJames Wright /** 1045*58600ac3SJames Wright @brief Setup context data for operator application. 1046*58600ac3SJames Wright 1047*58600ac3SJames Wright Collective across MPI processes. 1048*58600ac3SJames Wright 1049*58600ac3SJames Wright @param[in] dm_x Input `DM` 1050*58600ac3SJames Wright @param[in] dm_y Output `DM` 1051*58600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 1052*58600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 1053*58600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 1054*58600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 1055*58600ac3SJames Wright @param[in] log_event_mult `PetscLogEvent` for forward evaluation 1056*58600ac3SJames Wright @param[in] log_event_mult_transpose `PetscLogEvent` for transpose evaluation 1057*58600ac3SJames Wright @param[out] ctx Context data for operator evaluation 1058*58600ac3SJames Wright 1059*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1060*58600ac3SJames Wright **/ 1061*58600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose, 1062*58600ac3SJames Wright PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) { 1063*58600ac3SJames Wright CeedSize x_loc_len, y_loc_len; 1064*58600ac3SJames Wright 1065*58600ac3SJames Wright PetscFunctionBeginUser; 1066*58600ac3SJames Wright 1067*58600ac3SJames Wright // Allocate 1068*58600ac3SJames Wright PetscCall(PetscNew(ctx)); 1069*58600ac3SJames Wright (*ctx)->ref_count = 1; 1070*58600ac3SJames Wright 1071*58600ac3SJames Wright // Logging 1072*58600ac3SJames Wright (*ctx)->log_event_mult = log_event_mult; 1073*58600ac3SJames Wright (*ctx)->log_event_mult_transpose = log_event_mult_transpose; 1074*58600ac3SJames Wright 1075*58600ac3SJames Wright // PETSc objects 1076*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_x)); 1077*58600ac3SJames Wright (*ctx)->dm_x = dm_x; 1078*58600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_y)); 1079*58600ac3SJames Wright (*ctx)->dm_y = dm_y; 1080*58600ac3SJames Wright if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc)); 1081*58600ac3SJames Wright (*ctx)->X_loc = X_loc; 1082*58600ac3SJames Wright if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 1083*58600ac3SJames Wright (*ctx)->Y_loc_transpose = Y_loc_transpose; 1084*58600ac3SJames Wright 1085*58600ac3SJames Wright // Memtype 1086*58600ac3SJames Wright { 1087*58600ac3SJames Wright const PetscScalar *x; 1088*58600ac3SJames Wright Vec X; 1089*58600ac3SJames Wright 1090*58600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &X)); 1091*58600ac3SJames Wright PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type)); 1092*58600ac3SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 1093*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &X)); 1094*58600ac3SJames Wright } 1095*58600ac3SJames Wright 1096*58600ac3SJames Wright // libCEED objects 1097*58600ac3SJames Wright PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, 1098*58600ac3SJames Wright "retrieving Ceed context object failed"); 1099*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedReference((*ctx)->ceed)); 1100*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len)); 1101*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult)); 1102*58600ac3SJames Wright if (op_mult_transpose) PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose)); 1103*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc)); 1104*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc)); 1105*58600ac3SJames Wright 1106*58600ac3SJames Wright // Flop counting 1107*58600ac3SJames Wright { 1108*58600ac3SJames Wright CeedSize ceed_flops_estimate = 0; 1109*58600ac3SJames Wright 1110*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate)); 1111*58600ac3SJames Wright (*ctx)->flops_mult = ceed_flops_estimate; 1112*58600ac3SJames Wright if (op_mult_transpose) { 1113*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate)); 1114*58600ac3SJames Wright (*ctx)->flops_mult_transpose = ceed_flops_estimate; 1115*58600ac3SJames Wright } 1116*58600ac3SJames Wright } 1117*58600ac3SJames Wright 1118*58600ac3SJames Wright // Check sizes 1119*58600ac3SJames Wright if (x_loc_len > 0 || y_loc_len > 0) { 1120*58600ac3SJames Wright CeedSize ctx_x_loc_len, ctx_y_loc_len; 1121*58600ac3SJames Wright PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len; 1122*58600ac3SJames Wright Vec dm_X_loc, dm_Y_loc; 1123*58600ac3SJames Wright 1124*58600ac3SJames Wright // -- Input 1125*58600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &dm_X_loc)); 1126*58600ac3SJames Wright PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len)); 1127*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc)); 1128*58600ac3SJames Wright if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len)); 1129*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len)); 1130*58600ac3SJames Wright if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions"); 1131*58600ac3SJames Wright PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions"); 1132*58600ac3SJames Wright PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions"); 1133*58600ac3SJames Wright 1134*58600ac3SJames Wright // -- Output 1135*58600ac3SJames Wright PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc)); 1136*58600ac3SJames Wright PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len)); 1137*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc)); 1138*58600ac3SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len)); 1139*58600ac3SJames Wright PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions"); 1140*58600ac3SJames Wright 1141*58600ac3SJames Wright // -- Transpose 1142*58600ac3SJames Wright if (Y_loc_transpose) { 1143*58600ac3SJames Wright PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len)); 1144*58600ac3SJames Wright PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions"); 1145*58600ac3SJames Wright } 1146*58600ac3SJames Wright } 1147*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1148*58600ac3SJames Wright } 1149*58600ac3SJames Wright 1150*58600ac3SJames Wright /** 1151*58600ac3SJames Wright @brief Increment reference counter for `MATCEED` context. 1152*58600ac3SJames Wright 1153*58600ac3SJames Wright Not collective across MPI processes. 1154*58600ac3SJames Wright 1155*58600ac3SJames Wright @param[in,out] ctx Context data 1156*58600ac3SJames Wright 1157*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1158*58600ac3SJames Wright **/ 1159*58600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) { 1160*58600ac3SJames Wright PetscFunctionBeginUser; 1161*58600ac3SJames Wright ctx->ref_count++; 1162*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1163*58600ac3SJames Wright } 1164*58600ac3SJames Wright 1165*58600ac3SJames Wright /** 1166*58600ac3SJames Wright @brief Copy reference for `MATCEED`. 1167*58600ac3SJames Wright Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`. 1168*58600ac3SJames Wright 1169*58600ac3SJames Wright Not collective across MPI processes. 1170*58600ac3SJames Wright 1171*58600ac3SJames Wright @param[in] ctx Context data 1172*58600ac3SJames Wright @param[out] ctx_copy Copy of pointer to context data 1173*58600ac3SJames Wright 1174*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1175*58600ac3SJames Wright **/ 1176*58600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) { 1177*58600ac3SJames Wright PetscFunctionBeginUser; 1178*58600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 1179*58600ac3SJames Wright PetscCall(MatCeedContextDestroy(*ctx_copy)); 1180*58600ac3SJames Wright *ctx_copy = ctx; 1181*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1182*58600ac3SJames Wright } 1183*58600ac3SJames Wright 1184*58600ac3SJames Wright /** 1185*58600ac3SJames Wright @brief Destroy context data for operator application. 1186*58600ac3SJames Wright 1187*58600ac3SJames Wright Collective across MPI processes. 1188*58600ac3SJames Wright 1189*58600ac3SJames Wright @param[in,out] ctx Context data for operator evaluation 1190*58600ac3SJames Wright 1191*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1192*58600ac3SJames Wright **/ 1193*58600ac3SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) { 1194*58600ac3SJames Wright PetscFunctionBeginUser; 1195*58600ac3SJames Wright if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS); 1196*58600ac3SJames Wright 1197*58600ac3SJames Wright // PETSc objects 1198*58600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_x)); 1199*58600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_y)); 1200*58600ac3SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 1201*58600ac3SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 1202*58600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 1203*58600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 1204*58600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 1205*58600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_full)); 1206*58600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_pbd)); 1207*58600ac3SJames Wright 1208*58600ac3SJames Wright // libCEED objects 1209*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->x_loc)); 1210*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->y_loc)); 1211*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full)); 1212*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd)); 1213*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult)); 1214*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose)); 1215*58600ac3SJames Wright PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed"); 1216*58600ac3SJames Wright 1217*58600ac3SJames Wright // Deallocate 1218*58600ac3SJames Wright ctx->is_destroyed = PETSC_TRUE; // Flag as destroyed in case someone has stale ref 1219*58600ac3SJames Wright PetscCall(PetscFree(ctx)); 1220*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1221*58600ac3SJames Wright } 1222*58600ac3SJames Wright 1223*58600ac3SJames Wright /** 1224*58600ac3SJames Wright @brief Compute the diagonal of an operator via libCEED. 1225*58600ac3SJames Wright 1226*58600ac3SJames Wright Collective across MPI processes. 1227*58600ac3SJames Wright 1228*58600ac3SJames Wright @param[in] A `MATCEED` 1229*58600ac3SJames Wright @param[out] D Vector holding operator diagonal 1230*58600ac3SJames Wright 1231*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1232*58600ac3SJames Wright **/ 1233*58600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) { 1234*58600ac3SJames Wright PetscMemType mem_type; 1235*58600ac3SJames Wright Vec D_loc; 1236*58600ac3SJames Wright MatCeedContext ctx; 1237*58600ac3SJames Wright 1238*58600ac3SJames Wright PetscFunctionBeginUser; 1239*58600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1240*58600ac3SJames Wright 1241*58600ac3SJames Wright // Place PETSc vector in libCEED vector 1242*58600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc)); 1243*58600ac3SJames Wright PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc)); 1244*58600ac3SJames Wright 1245*58600ac3SJames Wright // Compute Diagonal 1246*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 1247*58600ac3SJames Wright 1248*58600ac3SJames Wright // Restore PETSc vector 1249*58600ac3SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc)); 1250*58600ac3SJames Wright 1251*58600ac3SJames Wright // Local-to-Global 1252*58600ac3SJames Wright PetscCall(VecZeroEntries(D)); 1253*58600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D)); 1254*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc)); 1255*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1256*58600ac3SJames Wright } 1257*58600ac3SJames Wright 1258*58600ac3SJames Wright /** 1259*58600ac3SJames Wright @brief Compute `A X = Y` for a `MATCEED`. 1260*58600ac3SJames Wright 1261*58600ac3SJames Wright Collective across MPI processes. 1262*58600ac3SJames Wright 1263*58600ac3SJames Wright @param[in] A `MATCEED` 1264*58600ac3SJames Wright @param[in] X Input PETSc vector 1265*58600ac3SJames Wright @param[out] Y Output PETSc vector 1266*58600ac3SJames Wright 1267*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1268*58600ac3SJames Wright **/ 1269*58600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) { 1270*58600ac3SJames Wright MatCeedContext ctx; 1271*58600ac3SJames Wright 1272*58600ac3SJames Wright PetscFunctionBeginUser; 1273*58600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1274*58600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0)); 1275*58600ac3SJames Wright 1276*58600ac3SJames Wright { 1277*58600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 1278*58600ac3SJames Wright Vec X_loc = ctx->X_loc, Y_loc; 1279*58600ac3SJames Wright 1280*58600ac3SJames Wright // Get local vectors 1281*58600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 1282*58600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 1283*58600ac3SJames Wright 1284*58600ac3SJames Wright // Global-to-local 1285*58600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc)); 1286*58600ac3SJames Wright 1287*58600ac3SJames Wright // Setup libCEED vectors 1288*58600ac3SJames Wright PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc)); 1289*58600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc)); 1290*58600ac3SJames Wright PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc)); 1291*58600ac3SJames Wright 1292*58600ac3SJames Wright // Apply libCEED operator 1293*58600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 1294*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE)); 1295*58600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 1296*58600ac3SJames Wright 1297*58600ac3SJames Wright // Restore PETSc vectors 1298*58600ac3SJames Wright PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc)); 1299*58600ac3SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc)); 1300*58600ac3SJames Wright 1301*58600ac3SJames Wright // Local-to-global 1302*58600ac3SJames Wright PetscCall(VecZeroEntries(Y)); 1303*58600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y)); 1304*58600ac3SJames Wright 1305*58600ac3SJames Wright // Restore local vectors, as needed 1306*58600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 1307*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 1308*58600ac3SJames Wright } 1309*58600ac3SJames Wright 1310*58600ac3SJames Wright // Log flops 1311*58600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult)); 1312*58600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult)); 1313*58600ac3SJames Wright 1314*58600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0)); 1315*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1316*58600ac3SJames Wright } 1317*58600ac3SJames Wright 1318*58600ac3SJames Wright /** 1319*58600ac3SJames Wright @brief Compute `A^T Y = X` for a `MATCEED`. 1320*58600ac3SJames Wright 1321*58600ac3SJames Wright Collective across MPI processes. 1322*58600ac3SJames Wright 1323*58600ac3SJames Wright @param[in] A `MATCEED` 1324*58600ac3SJames Wright @param[in] Y Input PETSc vector 1325*58600ac3SJames Wright @param[out] X Output PETSc vector 1326*58600ac3SJames Wright 1327*58600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 1328*58600ac3SJames Wright **/ 1329*58600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) { 1330*58600ac3SJames Wright MatCeedContext ctx; 1331*58600ac3SJames Wright 1332*58600ac3SJames Wright PetscFunctionBeginUser; 1333*58600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1334*58600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0)); 1335*58600ac3SJames Wright 1336*58600ac3SJames Wright { 1337*58600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 1338*58600ac3SJames Wright Vec X_loc, Y_loc = ctx->Y_loc_transpose; 1339*58600ac3SJames Wright 1340*58600ac3SJames Wright // Get local vectors 1341*58600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 1342*58600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 1343*58600ac3SJames Wright 1344*58600ac3SJames Wright // Global-to-local 1345*58600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc)); 1346*58600ac3SJames Wright 1347*58600ac3SJames Wright // Setup libCEED vectors 1348*58600ac3SJames Wright PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc)); 1349*58600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 1350*58600ac3SJames Wright PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc)); 1351*58600ac3SJames Wright 1352*58600ac3SJames Wright // Apply libCEED operator 1353*58600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 1354*58600ac3SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 1355*58600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 1356*58600ac3SJames Wright 1357*58600ac3SJames Wright // Restore PETSc vectors 1358*58600ac3SJames Wright PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc)); 1359*58600ac3SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc)); 1360*58600ac3SJames Wright 1361*58600ac3SJames Wright // Local-to-global 1362*58600ac3SJames Wright PetscCall(VecZeroEntries(X)); 1363*58600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X)); 1364*58600ac3SJames Wright 1365*58600ac3SJames Wright // Restore local vectors, as needed 1366*58600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 1367*58600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 1368*58600ac3SJames Wright } 1369*58600ac3SJames Wright 1370*58600ac3SJames Wright // Log flops 1371*58600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose)); 1372*58600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult_transpose)); 1373*58600ac3SJames Wright 1374*58600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0)); 1375*58600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1376*58600ac3SJames Wright } 1377