xref: /libCEED/examples/fluids/src/mat-ceed.c (revision 54831c5f47ae63ad598bd81054291277b4e7e451)
1c8564c30SJames Wright /// @file
/// MatCeed and its related operators
3c8564c30SJames Wright 
4c8564c30SJames Wright #include <ceed.h>
5c8564c30SJames Wright #include <ceed/backend.h>
6c8564c30SJames Wright #include <mat-ceed-impl.h>
7c8564c30SJames Wright #include <mat-ceed.h>
8c8564c30SJames Wright #include <petscdmplex.h>
9c8564c30SJames Wright #include <stdlib.h>
10c8564c30SJames Wright #include <string.h>
11c8564c30SJames Wright 
12c8564c30SJames Wright PetscClassId  MATCEED_CLASSID;
13c8564c30SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
14c8564c30SJames Wright 
15c8564c30SJames Wright /**
16c8564c30SJames Wright   @brief Register MATCEED log events.
17c8564c30SJames Wright 
18c8564c30SJames Wright   Not collective across MPI processes.
19c8564c30SJames Wright 
20c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
21c8564c30SJames Wright **/
22c8564c30SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() {
23c8564c30SJames Wright   static bool registered = false;
24c8564c30SJames Wright 
25c8564c30SJames Wright   PetscFunctionBeginUser;
26c8564c30SJames Wright   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
27c8564c30SJames Wright   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
28c8564c30SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
29c8564c30SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
30c8564c30SJames Wright   registered = true;
31c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
32c8564c30SJames Wright }
33c8564c30SJames Wright 
34c8564c30SJames Wright /**
35c8564c30SJames Wright   @brief Translate PetscMemType to CeedMemType
36c8564c30SJames Wright 
37c8564c30SJames Wright   @param[in]  mem_type  PetscMemType
38c8564c30SJames Wright 
39c8564c30SJames Wright   @return Equivalent CeedMemType
40c8564c30SJames Wright **/
41c8564c30SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
42c8564c30SJames Wright 
43c8564c30SJames Wright /**
44c8564c30SJames Wright   @brief Translate array of `CeedInt` to `PetscInt`.
45c8564c30SJames Wright     If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`.
46c8564c30SJames Wright     Caller is responsible for freeing `array_petsc` with `free()`.
47c8564c30SJames Wright 
48c8564c30SJames Wright   Not collective across MPI processes.
49c8564c30SJames Wright 
50c8564c30SJames Wright   @param[in]      num_entries  Number of array entries
51c8564c30SJames Wright   @param[in,out]  array_ceed   Array of `CeedInt`
52c8564c30SJames Wright   @param[out]     array_petsc  Array of `PetscInt`
53c8564c30SJames Wright 
54c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
55c8564c30SJames Wright **/
56c8564c30SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) {
57c8564c30SJames Wright   const CeedInt  int_c = 0;
58c8564c30SJames Wright   const PetscInt int_p = 0;
59c8564c30SJames Wright 
60c8564c30SJames Wright   PetscFunctionBeginUser;
61c8564c30SJames Wright   if (sizeof(int_c) == sizeof(int_p)) {
62c8564c30SJames Wright     *array_petsc = (PetscInt *)*array_ceed;
63c8564c30SJames Wright   } else {
64c8564c30SJames Wright     *array_petsc = malloc(num_entries * sizeof(PetscInt));
65c8564c30SJames Wright     for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i];
66c8564c30SJames Wright     free(*array_ceed);
67c8564c30SJames Wright   }
68c8564c30SJames Wright   *array_ceed = NULL;
69c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
70c8564c30SJames Wright }
71c8564c30SJames Wright 
72c8564c30SJames Wright /**
73c8564c30SJames Wright   @brief Transfer array from PETSc `Vec` to `CeedVector`.
74c8564c30SJames Wright 
75c8564c30SJames Wright   Collective across MPI processes.
76c8564c30SJames Wright 
77c8564c30SJames Wright   @param[in]   ceed      libCEED context
78c8564c30SJames Wright   @param[in]   X_petsc   PETSc `Vec`
79c8564c30SJames Wright   @param[out]  mem_type  PETSc `MemType`
80c8564c30SJames Wright   @param[out]  x_ceed    `CeedVector`
81c8564c30SJames Wright 
82c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
83c8564c30SJames Wright **/
84c8564c30SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
85c8564c30SJames Wright   PetscScalar *x;
86c8564c30SJames Wright 
87c8564c30SJames Wright   PetscFunctionBeginUser;
88c8564c30SJames Wright   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
89*54831c5fSJames Wright   PetscCallCeed(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
90c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
91c8564c30SJames Wright }
92c8564c30SJames Wright 
93c8564c30SJames Wright /**
94c8564c30SJames Wright   @brief Transfer array from `CeedVector` to PETSc `Vec`.
95c8564c30SJames Wright 
96c8564c30SJames Wright   Collective across MPI processes.
97c8564c30SJames Wright 
98c8564c30SJames Wright   @param[in]   ceed      libCEED context
99c8564c30SJames Wright   @param[in]   x_ceed    `CeedVector`
100c8564c30SJames Wright   @param[in]   mem_type  PETSc `MemType`
101c8564c30SJames Wright   @param[out]  X_petsc   PETSc `Vec`
102c8564c30SJames Wright 
103c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
104c8564c30SJames Wright **/
105c8564c30SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
106c8564c30SJames Wright   PetscScalar *x;
107c8564c30SJames Wright 
108c8564c30SJames Wright   PetscFunctionBeginUser;
109*54831c5fSJames Wright   PetscCallCeed(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
110c8564c30SJames Wright   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
111c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
112c8564c30SJames Wright }
113c8564c30SJames Wright 
114c8564c30SJames Wright /**
115c8564c30SJames Wright   @brief Transfer read only array from PETSc `Vec` to `CeedVector`.
116c8564c30SJames Wright 
117c8564c30SJames Wright   Collective across MPI processes.
118c8564c30SJames Wright 
119c8564c30SJames Wright   @param[in]   ceed      libCEED context
120c8564c30SJames Wright   @param[in]   X_petsc   PETSc `Vec`
121c8564c30SJames Wright   @param[out]  mem_type  PETSc `MemType`
122c8564c30SJames Wright   @param[out]  x_ceed    `CeedVector`
123c8564c30SJames Wright 
124c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
125c8564c30SJames Wright **/
126c8564c30SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
127c8564c30SJames Wright   PetscScalar *x;
128c8564c30SJames Wright 
129c8564c30SJames Wright   PetscFunctionBeginUser;
130c8564c30SJames Wright   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
131*54831c5fSJames Wright   PetscCallCeed(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
132c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
133c8564c30SJames Wright }
134c8564c30SJames Wright 
135c8564c30SJames Wright /**
136c8564c30SJames Wright   @brief Transfer read only array from `CeedVector` to PETSc `Vec`.
137c8564c30SJames Wright 
138c8564c30SJames Wright   Collective across MPI processes.
139c8564c30SJames Wright 
140c8564c30SJames Wright   @param[in]   ceed      libCEED context
141c8564c30SJames Wright   @param[in]   x_ceed    `CeedVector`
142c8564c30SJames Wright   @param[in]   mem_type  PETSc `MemType`
143c8564c30SJames Wright   @param[out]  X_petsc   PETSc `Vec`
144c8564c30SJames Wright 
145c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
146c8564c30SJames Wright **/
147c8564c30SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
148c8564c30SJames Wright   PetscScalar *x;
149c8564c30SJames Wright 
150c8564c30SJames Wright   PetscFunctionBeginUser;
151*54831c5fSJames Wright   PetscCallCeed(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
152c8564c30SJames Wright   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
153c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
154c8564c30SJames Wright }
155c8564c30SJames Wright 
156c8564c30SJames Wright /**
157c8564c30SJames Wright   @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.
158c8564c30SJames Wright 
159c8564c30SJames Wright   Collective across MPI processes.
160c8564c30SJames Wright 
161c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to setup
162c8564c30SJames Wright   @param[out]  mat_inner  Inner `Mat`
163c8564c30SJames Wright 
164c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
165c8564c30SJames Wright **/
166c8564c30SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
167c8564c30SJames Wright   MatCeedContext ctx;
168c8564c30SJames Wright 
169c8564c30SJames Wright   PetscFunctionBeginUser;
170c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
171c8564c30SJames Wright 
172c8564c30SJames Wright   PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");
173c8564c30SJames Wright 
174c8564c30SJames Wright   // Check cl mat type
175c8564c30SJames Wright   {
176c8564c30SJames Wright     PetscBool is_internal_mat_type_cl = PETSC_FALSE;
177c8564c30SJames Wright     char      internal_mat_type_cl[64];
178c8564c30SJames Wright 
179c8564c30SJames Wright     // Check for specific CL inner mat type for this Mat
180c8564c30SJames Wright     {
181c8564c30SJames Wright       const char *mat_ceed_prefix = NULL;
182c8564c30SJames Wright 
183c8564c30SJames Wright       PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
184c8564c30SJames Wright       PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
185c8564c30SJames Wright       PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
186c8564c30SJames Wright                                   internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
187c8564c30SJames Wright       PetscOptionsEnd();
188c8564c30SJames Wright       if (is_internal_mat_type_cl) {
189c8564c30SJames Wright         PetscCall(PetscFree(ctx->internal_mat_type));
190c8564c30SJames Wright         PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
191c8564c30SJames Wright       }
192c8564c30SJames Wright     }
193c8564c30SJames Wright   }
194c8564c30SJames Wright 
195c8564c30SJames Wright   // Create sparse matrix
196c8564c30SJames Wright   {
197c8564c30SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
198c8564c30SJames Wright 
199c8564c30SJames Wright     PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
200c8564c30SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
201c8564c30SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
202c8564c30SJames Wright     PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
203c8564c30SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
204c8564c30SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
205c8564c30SJames Wright   }
206c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
207c8564c30SJames Wright }
208c8564c30SJames Wright 
209c8564c30SJames Wright /**
210c8564c30SJames Wright   @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
211c8564c30SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
212c8564c30SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
213c8564c30SJames Wright 
214c8564c30SJames Wright   Collective across MPI processes.
215c8564c30SJames Wright 
216c8564c30SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
217c8564c30SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
218c8564c30SJames Wright 
219c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
220c8564c30SJames Wright **/
221c8564c30SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
222c8564c30SJames Wright   MatCeedContext ctx;
223c8564c30SJames Wright 
224c8564c30SJames Wright   PetscFunctionBeginUser;
225c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
226c8564c30SJames Wright 
227c8564c30SJames Wright   // Check if COO pattern set
228c8564c30SJames Wright   {
229c8564c30SJames Wright     PetscInt index = -1;
230c8564c30SJames Wright 
231c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
232c8564c30SJames Wright       if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
233c8564c30SJames Wright     }
234c8564c30SJames Wright     if (index == -1) {
235c8564c30SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
236c8564c30SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
237c8564c30SJames Wright       PetscCount    num_entries;
238c8564c30SJames Wright       PetscLogStage stage_amg_setup;
239c8564c30SJames Wright 
240c8564c30SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
241c8564c30SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
242c8564c30SJames Wright       if (stage_amg_setup == -1) {
243c8564c30SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
244c8564c30SJames Wright       }
245c8564c30SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
246*54831c5fSJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
247c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
248c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
249c8564c30SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
250c8564c30SJames Wright       free(rows_petsc);
251c8564c30SJames Wright       free(cols_petsc);
252*54831c5fSJames Wright       if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
253c8564c30SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
254c8564c30SJames Wright       ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
255c8564c30SJames Wright       PetscCall(PetscLogStagePop());
256c8564c30SJames Wright     }
257c8564c30SJames Wright   }
258c8564c30SJames Wright 
259c8564c30SJames Wright   // Assemble mat_ceed
260c8564c30SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
261c8564c30SJames Wright   {
262c8564c30SJames Wright     const CeedScalar *values;
263c8564c30SJames Wright     MatType           mat_type;
264c8564c30SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
265c8564c30SJames Wright     PetscBool         is_spd, is_spd_known;
266c8564c30SJames Wright 
267c8564c30SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
268c8564c30SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
269c8564c30SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
270c8564c30SJames Wright     else mem_type = CEED_MEM_HOST;
271c8564c30SJames Wright 
272*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
273*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
274c8564c30SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
275c8564c30SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
276c8564c30SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
277*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
278c8564c30SJames Wright   }
279c8564c30SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
280c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
281c8564c30SJames Wright }
282c8564c30SJames Wright 
283c8564c30SJames Wright /**
284c8564c30SJames Wright   @brief Assemble inner `Mat` for diagonal `PC` operations
285c8564c30SJames Wright 
286c8564c30SJames Wright   Collective across MPI processes.
287c8564c30SJames Wright 
288c8564c30SJames Wright   @param[in]   mat_ceed      `MATCEED` to invert
289c8564c30SJames Wright   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
290c8564c30SJames Wright   @param[out]  mat_inner     Inner `Mat` for diagonal operations
291c8564c30SJames Wright 
292c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
293c8564c30SJames Wright **/
294c8564c30SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
295c8564c30SJames Wright   MatCeedContext ctx;
296c8564c30SJames Wright 
297c8564c30SJames Wright   PetscFunctionBeginUser;
298c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
299c8564c30SJames Wright   if (use_ceed_pbd) {
300c8564c30SJames Wright     // Check if COO pattern set
301c8564c30SJames Wright     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
302c8564c30SJames Wright 
303c8564c30SJames Wright     // Assemble mat_assembled_full_internal
304c8564c30SJames Wright     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
305c8564c30SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
306c8564c30SJames Wright   } else {
307c8564c30SJames Wright     // Check if COO pattern set
308c8564c30SJames Wright     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
309c8564c30SJames Wright 
310c8564c30SJames Wright     // Assemble mat_assembled_full_internal
311c8564c30SJames Wright     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
312c8564c30SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
313c8564c30SJames Wright   }
314c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
315c8564c30SJames Wright }
316c8564c30SJames Wright 
317c8564c30SJames Wright /**
318c8564c30SJames Wright   @brief Get `MATCEED` diagonal block for Jacobi.
319c8564c30SJames Wright 
320c8564c30SJames Wright   Collective across MPI processes.
321c8564c30SJames Wright 
322c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to invert
323c8564c30SJames Wright   @param[out]  mat_block  The diagonal block matrix
324c8564c30SJames Wright 
325c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
326c8564c30SJames Wright **/
327c8564c30SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
328c8564c30SJames Wright   Mat            mat_inner = NULL;
329c8564c30SJames Wright   MatCeedContext ctx;
330c8564c30SJames Wright 
331c8564c30SJames Wright   PetscFunctionBeginUser;
332c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
333c8564c30SJames Wright 
334c8564c30SJames Wright   // Assemble inner mat if needed
335c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
336c8564c30SJames Wright 
337c8564c30SJames Wright   // Get block diagonal
338c8564c30SJames Wright   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
339c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
340c8564c30SJames Wright }
341c8564c30SJames Wright 
342c8564c30SJames Wright /**
343c8564c30SJames Wright   @brief Invert `MATCEED` diagonal block for Jacobi.
344c8564c30SJames Wright 
345c8564c30SJames Wright   Collective across MPI processes.
346c8564c30SJames Wright 
347c8564c30SJames Wright   @param[in]   mat_ceed  `MATCEED` to invert
348c8564c30SJames Wright   @param[out]  values    The block inverses in column major order
349c8564c30SJames Wright 
350c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
351c8564c30SJames Wright **/
352c8564c30SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
353c8564c30SJames Wright   Mat            mat_inner = NULL;
354c8564c30SJames Wright   MatCeedContext ctx;
355c8564c30SJames Wright 
356c8564c30SJames Wright   PetscFunctionBeginUser;
357c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
358c8564c30SJames Wright 
359c8564c30SJames Wright   // Assemble inner mat if needed
360c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
361c8564c30SJames Wright 
362c8564c30SJames Wright   // Invert PB diagonal
363c8564c30SJames Wright   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
364c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
365c8564c30SJames Wright }
366c8564c30SJames Wright 
367c8564c30SJames Wright /**
368c8564c30SJames Wright   @brief Invert `MATCEED` variable diagonal block for Jacobi.
369c8564c30SJames Wright 
370c8564c30SJames Wright   Collective across MPI processes.
371c8564c30SJames Wright 
372c8564c30SJames Wright   @param[in]   mat_ceed     `MATCEED` to invert
373c8564c30SJames Wright   @param[in]   num_blocks   The number of blocks on the process
374c8564c30SJames Wright   @param[in]   block_sizes  The size of each block on the process
375c8564c30SJames Wright   @param[out]  values       The block inverses in column major order
376c8564c30SJames Wright 
377c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
378c8564c30SJames Wright **/
379c8564c30SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
380c8564c30SJames Wright   Mat            mat_inner = NULL;
381c8564c30SJames Wright   MatCeedContext ctx;
382c8564c30SJames Wright 
383c8564c30SJames Wright   PetscFunctionBeginUser;
384c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
385c8564c30SJames Wright 
386c8564c30SJames Wright   // Assemble inner mat if needed
387c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
388c8564c30SJames Wright 
389c8564c30SJames Wright   // Invert PB diagonal
390c8564c30SJames Wright   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
391c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
392c8564c30SJames Wright }
393c8564c30SJames Wright 
394c8564c30SJames Wright // -----------------------------------------------------------------------------
395c8564c30SJames Wright // MatCeed
396c8564c30SJames Wright // -----------------------------------------------------------------------------
397c8564c30SJames Wright 
398c8564c30SJames Wright /**
399c8564c30SJames Wright   @brief Create PETSc `Mat` from libCEED operators.
400c8564c30SJames Wright 
401c8564c30SJames Wright   Collective across MPI processes.
402c8564c30SJames Wright 
403c8564c30SJames Wright   @param[in]   dm_x                      Input `DM`
404c8564c30SJames Wright   @param[in]   dm_y                      Output `DM`
405c8564c30SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
406c8564c30SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
407c8564c30SJames Wright   @param[out]  mat                        New MatCeed
408c8564c30SJames Wright 
409c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
410c8564c30SJames Wright **/
411c8564c30SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
412c8564c30SJames Wright   PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
413c8564c30SJames Wright   VecType        vec_type;
414c8564c30SJames Wright   MatCeedContext ctx;
415c8564c30SJames Wright 
416c8564c30SJames Wright   PetscFunctionBeginUser;
417c8564c30SJames Wright   PetscCall(MatCeedRegisterLogEvents());
418c8564c30SJames Wright 
419c8564c30SJames Wright   // Collect context data
420c8564c30SJames Wright   PetscCall(DMGetVecType(dm_x, &vec_type));
421c8564c30SJames Wright   {
422c8564c30SJames Wright     Vec X;
423c8564c30SJames Wright 
424c8564c30SJames Wright     PetscCall(DMGetGlobalVector(dm_x, &X));
425c8564c30SJames Wright     PetscCall(VecGetSize(X, &X_g_size));
426c8564c30SJames Wright     PetscCall(VecGetLocalSize(X, &X_l_size));
427c8564c30SJames Wright     PetscCall(DMRestoreGlobalVector(dm_x, &X));
428c8564c30SJames Wright   }
429c8564c30SJames Wright   if (dm_y) {
430c8564c30SJames Wright     Vec Y;
431c8564c30SJames Wright 
432c8564c30SJames Wright     PetscCall(DMGetGlobalVector(dm_y, &Y));
433c8564c30SJames Wright     PetscCall(VecGetSize(Y, &Y_g_size));
434c8564c30SJames Wright     PetscCall(VecGetLocalSize(Y, &Y_l_size));
435c8564c30SJames Wright     PetscCall(DMRestoreGlobalVector(dm_y, &Y));
436c8564c30SJames Wright   } else {
437c8564c30SJames Wright     dm_y     = dm_x;
438c8564c30SJames Wright     Y_g_size = X_g_size;
439c8564c30SJames Wright     Y_l_size = X_l_size;
440c8564c30SJames Wright   }
441c8564c30SJames Wright   // Create context
442c8564c30SJames Wright   {
443c8564c30SJames Wright     Vec X_loc, Y_loc_transpose = NULL;
444c8564c30SJames Wright 
445c8564c30SJames Wright     PetscCall(DMCreateLocalVector(dm_x, &X_loc));
446c8564c30SJames Wright     PetscCall(VecZeroEntries(X_loc));
447c8564c30SJames Wright     if (op_mult_transpose) {
448c8564c30SJames Wright       PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
449c8564c30SJames Wright       PetscCall(VecZeroEntries(Y_loc_transpose));
450c8564c30SJames Wright     }
451c8564c30SJames Wright     PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
452c8564c30SJames Wright     PetscCall(VecDestroy(&X_loc));
453c8564c30SJames Wright     PetscCall(VecDestroy(&Y_loc_transpose));
454c8564c30SJames Wright   }
455c8564c30SJames Wright 
456c8564c30SJames Wright   // Create mat
457c8564c30SJames Wright   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
458c8564c30SJames Wright   PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
459c8564c30SJames Wright   // -- Set block and variable block sizes
460c8564c30SJames Wright   if (dm_x == dm_y) {
461c8564c30SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
462c8564c30SJames Wright     Mat     temp_mat;
463c8564c30SJames Wright 
464c8564c30SJames Wright     PetscCall(DMGetMatType(dm_x, &dm_mat_type));
465c8564c30SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
466c8564c30SJames Wright     PetscCall(DMSetMatType(dm_x, MATAIJ));
467c8564c30SJames Wright     PetscCall(DMCreateMatrix(dm_x, &temp_mat));
468c8564c30SJames Wright     PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
469c8564c30SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
470c8564c30SJames Wright 
471c8564c30SJames Wright     {
472c8564c30SJames Wright       PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
473c8564c30SJames Wright       const PetscInt *vblock_sizes;
474c8564c30SJames Wright 
475c8564c30SJames Wright       // -- Get block sizes
476c8564c30SJames Wright       PetscCall(MatGetBlockSize(temp_mat, &block_size));
477c8564c30SJames Wright       PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
478c8564c30SJames Wright       {
479c8564c30SJames Wright         PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};
480c8564c30SJames Wright 
481c8564c30SJames Wright         for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
482c8564c30SJames Wright         PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
483c8564c30SJames Wright         max_vblock_size = global_min_max[1];
484c8564c30SJames Wright       }
485c8564c30SJames Wright 
486c8564c30SJames Wright       // -- Copy block sizes
487c8564c30SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
488c8564c30SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));
489c8564c30SJames Wright 
490c8564c30SJames Wright       // -- Check libCEED compatibility
491c8564c30SJames Wright       {
492c8564c30SJames Wright         bool is_composite;
493c8564c30SJames Wright 
494c8564c30SJames Wright         ctx->is_ceed_pbd_valid  = PETSC_TRUE;
495c8564c30SJames Wright         ctx->is_ceed_vpbd_valid = PETSC_TRUE;
496*54831c5fSJames Wright         PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
497c8564c30SJames Wright         if (is_composite) {
498c8564c30SJames Wright           CeedInt       num_sub_operators;
499c8564c30SJames Wright           CeedOperator *sub_operators;
500c8564c30SJames Wright 
501*54831c5fSJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
502*54831c5fSJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
503c8564c30SJames Wright           for (CeedInt i = 0; i < num_sub_operators; i++) {
504c8564c30SJames Wright             CeedInt                  num_bases, num_comp;
505c8564c30SJames Wright             CeedBasis               *active_bases;
506c8564c30SJames Wright             CeedOperatorAssemblyData assembly_data;
507c8564c30SJames Wright 
508*54831c5fSJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
509*54831c5fSJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
510*54831c5fSJames Wright             PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
511c8564c30SJames Wright             if (num_bases > 1) {
512c8564c30SJames Wright               ctx->is_ceed_pbd_valid  = PETSC_FALSE;
513c8564c30SJames Wright               ctx->is_ceed_vpbd_valid = PETSC_FALSE;
514c8564c30SJames Wright             }
515c8564c30SJames Wright             if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
516c8564c30SJames Wright             if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
517c8564c30SJames Wright           }
518c8564c30SJames Wright         } else {
519c8564c30SJames Wright           // LCOV_EXCL_START
520c8564c30SJames Wright           CeedInt                  num_bases, num_comp;
521c8564c30SJames Wright           CeedBasis               *active_bases;
522c8564c30SJames Wright           CeedOperatorAssemblyData assembly_data;
523c8564c30SJames Wright 
524*54831c5fSJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
525*54831c5fSJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
526*54831c5fSJames Wright           PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
527c8564c30SJames Wright           if (num_bases > 1) {
528c8564c30SJames Wright             ctx->is_ceed_pbd_valid  = PETSC_FALSE;
529c8564c30SJames Wright             ctx->is_ceed_vpbd_valid = PETSC_FALSE;
530c8564c30SJames Wright           }
531c8564c30SJames Wright           if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
532c8564c30SJames Wright           if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
533c8564c30SJames Wright           // LCOV_EXCL_STOP
534c8564c30SJames Wright         }
535c8564c30SJames Wright         {
536c8564c30SJames Wright           PetscInt local_is_valid[2], global_is_valid[2];
537c8564c30SJames Wright 
538c8564c30SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
539c8564c30SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
540c8564c30SJames Wright           ctx->is_ceed_pbd_valid = global_is_valid[0];
541c8564c30SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
542c8564c30SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
543c8564c30SJames Wright           ctx->is_ceed_vpbd_valid = global_is_valid[0];
544c8564c30SJames Wright         }
545c8564c30SJames Wright       }
546c8564c30SJames Wright     }
547c8564c30SJames Wright     PetscCall(MatDestroy(&temp_mat));
548c8564c30SJames Wright   }
549c8564c30SJames Wright   // -- Set internal mat type
550c8564c30SJames Wright   {
551c8564c30SJames Wright     VecType vec_type;
552c8564c30SJames Wright     MatType internal_mat_type = MATAIJ;
553c8564c30SJames Wright 
554c8564c30SJames Wright     PetscCall(VecGetType(ctx->X_loc, &vec_type));
555c8564c30SJames Wright     if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
556c8564c30SJames Wright     else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
557c8564c30SJames Wright     else internal_mat_type = MATAIJ;
558c8564c30SJames Wright     PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
559c8564c30SJames Wright   }
560c8564c30SJames Wright   // -- Set mat operations
561c8564c30SJames Wright   PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
562c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
563c8564c30SJames Wright   if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
564c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
565c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
566c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
567c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
568c8564c30SJames Wright   PetscCall(MatShellSetVecType(*mat, vec_type));
569c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
570c8564c30SJames Wright }
571c8564c30SJames Wright 
572c8564c30SJames Wright /**
573c8564c30SJames Wright   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
574c8564c30SJames Wright 
575c8564c30SJames Wright   Collective across MPI processes.
576c8564c30SJames Wright 
577c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to copy from
578c8564c30SJames Wright   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
579c8564c30SJames Wright 
580c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
581c8564c30SJames Wright **/
582c8564c30SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
583c8564c30SJames Wright   PetscFunctionBeginUser;
584c8564c30SJames Wright   PetscCall(MatCeedRegisterLogEvents());
585c8564c30SJames Wright 
586c8564c30SJames Wright   // Check type compatibility
587c8564c30SJames Wright   {
588c8564c30SJames Wright     MatType mat_type_ceed, mat_type_other;
589c8564c30SJames Wright 
590c8564c30SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
591c8564c30SJames Wright     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
592c8564c30SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_other));
593c8564c30SJames Wright     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
594c8564c30SJames Wright                "mat_other must have type " MATCEED " or " MATSHELL);
595c8564c30SJames Wright   }
596c8564c30SJames Wright 
597c8564c30SJames Wright   // Check dimension compatibility
598c8564c30SJames Wright   {
599c8564c30SJames Wright     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
600c8564c30SJames Wright 
601c8564c30SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
602c8564c30SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
603c8564c30SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
604c8564c30SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
605c8564c30SJames Wright     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
606c8564c30SJames Wright                    (X_l_ceed_size == X_l_other_size),
607c8564c30SJames Wright                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
608c8564c30SJames Wright                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
609c8564c30SJames Wright                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
610c8564c30SJames Wright                ", %" PetscInt_FMT ")",
611c8564c30SJames Wright                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
612c8564c30SJames Wright   }
613c8564c30SJames Wright 
614c8564c30SJames Wright   // Convert
615c8564c30SJames Wright   {
616c8564c30SJames Wright     VecType        vec_type;
617c8564c30SJames Wright     MatCeedContext ctx;
618c8564c30SJames Wright 
619c8564c30SJames Wright     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
620c8564c30SJames Wright     PetscCall(MatShellGetContext(mat_ceed, &ctx));
621c8564c30SJames Wright     PetscCall(MatCeedContextReference(ctx));
622c8564c30SJames Wright     PetscCall(MatShellSetContext(mat_other, ctx));
623c8564c30SJames Wright     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
624c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
625c8564c30SJames Wright     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
626c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
627c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
628c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
629c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
630c8564c30SJames Wright     {
631c8564c30SJames Wright       PetscInt block_size;
632c8564c30SJames Wright 
633c8564c30SJames Wright       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
634c8564c30SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
635c8564c30SJames Wright     }
636c8564c30SJames Wright     {
637c8564c30SJames Wright       PetscInt        num_blocks;
638c8564c30SJames Wright       const PetscInt *block_sizes;
639c8564c30SJames Wright 
640c8564c30SJames Wright       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
641c8564c30SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
642c8564c30SJames Wright     }
643c8564c30SJames Wright     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
644c8564c30SJames Wright     PetscCall(MatShellSetVecType(mat_other, vec_type));
645c8564c30SJames Wright   }
646c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
647c8564c30SJames Wright }
648c8564c30SJames Wright 
649c8564c30SJames Wright /**
650c8564c30SJames Wright   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
651c8564c30SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
652c8564c30SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
653c8564c30SJames Wright 
654c8564c30SJames Wright   Collective across MPI processes.
655c8564c30SJames Wright 
656c8564c30SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
657c8564c30SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
658c8564c30SJames Wright 
659c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
660c8564c30SJames Wright **/
661c8564c30SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
662c8564c30SJames Wright   MatCeedContext ctx;
663c8564c30SJames Wright 
664c8564c30SJames Wright   PetscFunctionBeginUser;
665c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
666c8564c30SJames Wright 
667c8564c30SJames Wright   // Check if COO pattern set
668c8564c30SJames Wright   {
669c8564c30SJames Wright     PetscInt index = -1;
670c8564c30SJames Wright 
671c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
672c8564c30SJames Wright       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
673c8564c30SJames Wright     }
674c8564c30SJames Wright     if (index == -1) {
675c8564c30SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
676c8564c30SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
677c8564c30SJames Wright       PetscCount    num_entries;
678c8564c30SJames Wright       PetscLogStage stage_amg_setup;
679c8564c30SJames Wright 
680c8564c30SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
681c8564c30SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
682c8564c30SJames Wright       if (stage_amg_setup == -1) {
683c8564c30SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
684c8564c30SJames Wright       }
685c8564c30SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
686*54831c5fSJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
687c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
688c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
689c8564c30SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
690c8564c30SJames Wright       free(rows_petsc);
691c8564c30SJames Wright       free(cols_petsc);
692*54831c5fSJames Wright       if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
693c8564c30SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
694c8564c30SJames Wright       ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
695c8564c30SJames Wright       PetscCall(PetscLogStagePop());
696c8564c30SJames Wright     }
697c8564c30SJames Wright   }
698c8564c30SJames Wright 
699c8564c30SJames Wright   // Assemble mat_ceed
700c8564c30SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
701c8564c30SJames Wright   {
702c8564c30SJames Wright     const CeedScalar *values;
703c8564c30SJames Wright     MatType           mat_type;
704c8564c30SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
705c8564c30SJames Wright     PetscBool         is_spd, is_spd_known;
706c8564c30SJames Wright 
707c8564c30SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
708c8564c30SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
709c8564c30SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
710c8564c30SJames Wright     else mem_type = CEED_MEM_HOST;
711c8564c30SJames Wright 
712*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
713*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
714c8564c30SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
715c8564c30SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
716c8564c30SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
717*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
718c8564c30SJames Wright   }
719c8564c30SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
720c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
721c8564c30SJames Wright }
722c8564c30SJames Wright 
723c8564c30SJames Wright /**
724c8564c30SJames Wright   @brief Set user context for a `MATCEED`.
725c8564c30SJames Wright 
726c8564c30SJames Wright   Collective across MPI processes.
727c8564c30SJames Wright 
728c8564c30SJames Wright   @param[in,out]  mat  `MATCEED`
729c8564c30SJames Wright   @param[in]      f    The context destroy function, or NULL
730c8564c30SJames Wright   @param[in]      ctx  User context, or NULL to unset
731c8564c30SJames Wright 
732c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
733c8564c30SJames Wright **/
734c8564c30SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
735c8564c30SJames Wright   PetscContainer user_ctx = NULL;
736c8564c30SJames Wright 
737c8564c30SJames Wright   PetscFunctionBeginUser;
738c8564c30SJames Wright   if (ctx) {
739c8564c30SJames Wright     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
740c8564c30SJames Wright     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
741c8564c30SJames Wright     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
742c8564c30SJames Wright   }
743c8564c30SJames Wright   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
744c8564c30SJames Wright   PetscCall(PetscContainerDestroy(&user_ctx));
745c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
746c8564c30SJames Wright }
747c8564c30SJames Wright 
748c8564c30SJames Wright /**
749c8564c30SJames Wright   @brief Retrieve the user context for a `MATCEED`.
750c8564c30SJames Wright 
751c8564c30SJames Wright   Collective across MPI processes.
752c8564c30SJames Wright 
753c8564c30SJames Wright   @param[in,out]  mat  `MATCEED`
754c8564c30SJames Wright   @param[in]      ctx  User context
755c8564c30SJames Wright 
756c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
757c8564c30SJames Wright **/
758c8564c30SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
759c8564c30SJames Wright   PetscContainer user_ctx;
760c8564c30SJames Wright 
761c8564c30SJames Wright   PetscFunctionBeginUser;
762c8564c30SJames Wright   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
763c8564c30SJames Wright   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
764c8564c30SJames Wright   else *(void **)ctx = NULL;
765c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
766c8564c30SJames Wright }
767c8564c30SJames Wright 
768c8564c30SJames Wright /**
769c8564c30SJames Wright   @brief Sets the inner matrix type as a string from the `MATCEED`.
770c8564c30SJames Wright 
771c8564c30SJames Wright   Collective across MPI processes.
772c8564c30SJames Wright 
773c8564c30SJames Wright   @param[in,out]  mat   `MATCEED`
774c8564c30SJames Wright   @param[in]      type  Inner `MatType` to set
775c8564c30SJames Wright 
776c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
777c8564c30SJames Wright **/
778c8564c30SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
779c8564c30SJames Wright   MatCeedContext ctx;
780c8564c30SJames Wright 
781c8564c30SJames Wright   PetscFunctionBeginUser;
782c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
783c8564c30SJames Wright   // Check if same
784c8564c30SJames Wright   {
785c8564c30SJames Wright     size_t    len_old, len_new;
786c8564c30SJames Wright     PetscBool is_same = PETSC_FALSE;
787c8564c30SJames Wright 
788c8564c30SJames Wright     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
789c8564c30SJames Wright     PetscCall(PetscStrlen(type, &len_new));
790c8564c30SJames Wright     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
791c8564c30SJames Wright     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
792c8564c30SJames Wright   }
793c8564c30SJames Wright   // Clean up old mats in different format
794c8564c30SJames Wright   // LCOV_EXCL_START
795c8564c30SJames Wright   if (ctx->mat_assembled_full_internal) {
796c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
797c8564c30SJames Wright       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
798c8564c30SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
799c8564c30SJames Wright           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
800c8564c30SJames Wright         }
801c8564c30SJames Wright         ctx->num_mats_assembled_full--;
802c8564c30SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
803c8564c30SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
804c8564c30SJames Wright       }
805c8564c30SJames Wright     }
806c8564c30SJames Wright   }
807c8564c30SJames Wright   if (ctx->mat_assembled_pbd_internal) {
808c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
809c8564c30SJames Wright       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
810c8564c30SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
811c8564c30SJames Wright           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
812c8564c30SJames Wright         }
813c8564c30SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
814c8564c30SJames Wright         ctx->num_mats_assembled_pbd--;
815c8564c30SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
816c8564c30SJames Wright       }
817c8564c30SJames Wright     }
818c8564c30SJames Wright   }
819c8564c30SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
820c8564c30SJames Wright   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
821c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
822c8564c30SJames Wright   // LCOV_EXCL_STOP
823c8564c30SJames Wright }
824c8564c30SJames Wright 
825c8564c30SJames Wright /**
826c8564c30SJames Wright   @brief Gets the inner matrix type as a string from the `MATCEED`.
827c8564c30SJames Wright 
828c8564c30SJames Wright   Collective across MPI processes.
829c8564c30SJames Wright 
830c8564c30SJames Wright   @param[in,out]  mat   `MATCEED`
831c8564c30SJames Wright   @param[in]      type  Inner `MatType`
832c8564c30SJames Wright 
833c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
834c8564c30SJames Wright **/
835c8564c30SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
836c8564c30SJames Wright   MatCeedContext ctx;
837c8564c30SJames Wright 
838c8564c30SJames Wright   PetscFunctionBeginUser;
839c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
840c8564c30SJames Wright   *type = ctx->internal_mat_type;
841c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
842c8564c30SJames Wright }
843c8564c30SJames Wright 
844c8564c30SJames Wright /**
845c8564c30SJames Wright   @brief Set a user defined matrix operation for a `MATCEED` matrix.
846c8564c30SJames Wright 
847c8564c30SJames Wright   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
848c8564c30SJames Wright `MatCeedSetContext()`.
849c8564c30SJames Wright 
850c8564c30SJames Wright   Collective across MPI processes.
851c8564c30SJames Wright 
852c8564c30SJames Wright   @param[in,out]  mat  `MATCEED`
853c8564c30SJames Wright   @param[in]      op   Name of the `MatOperation`
854c8564c30SJames Wright   @param[in]      g    Function that provides the operation
855c8564c30SJames Wright 
856c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
857c8564c30SJames Wright **/
858c8564c30SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
859c8564c30SJames Wright   PetscFunctionBeginUser;
860c8564c30SJames Wright   PetscCall(MatShellSetOperation(mat, op, g));
861c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
862c8564c30SJames Wright }
863c8564c30SJames Wright 
864c8564c30SJames Wright /**
865c8564c30SJames Wright   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
866c8564c30SJames Wright 
867c8564c30SJames Wright   Not collective across MPI processes.
868c8564c30SJames Wright 
869c8564c30SJames Wright   @param[in,out]  mat              `MATCEED`
870c8564c30SJames Wright   @param[in]      X_loc            Input PETSc local vector, or NULL
871c8564c30SJames Wright   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
872c8564c30SJames Wright 
873c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
874c8564c30SJames Wright **/
875c8564c30SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
876c8564c30SJames Wright   MatCeedContext ctx;
877c8564c30SJames Wright 
878c8564c30SJames Wright   PetscFunctionBeginUser;
879c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
880c8564c30SJames Wright   if (X_loc) {
881c8564c30SJames Wright     PetscInt len_old, len_new;
882c8564c30SJames Wright 
883c8564c30SJames Wright     PetscCall(VecGetSize(ctx->X_loc, &len_old));
884c8564c30SJames Wright     PetscCall(VecGetSize(X_loc, &len_new));
885c8564c30SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
886c8564c30SJames Wright                len_new, len_old);
887c8564c30SJames Wright     PetscCall(VecDestroy(&ctx->X_loc));
888c8564c30SJames Wright     ctx->X_loc = X_loc;
889c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)X_loc));
890c8564c30SJames Wright   }
891c8564c30SJames Wright   if (Y_loc_transpose) {
892c8564c30SJames Wright     PetscInt len_old, len_new;
893c8564c30SJames Wright 
894c8564c30SJames Wright     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
895c8564c30SJames Wright     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
896c8564c30SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
897c8564c30SJames Wright                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
898c8564c30SJames Wright     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
899c8564c30SJames Wright     ctx->Y_loc_transpose = Y_loc_transpose;
900c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
901c8564c30SJames Wright   }
902c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
903c8564c30SJames Wright }
904c8564c30SJames Wright 
905c8564c30SJames Wright /**
906c8564c30SJames Wright   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
907c8564c30SJames Wright 
908c8564c30SJames Wright   Not collective across MPI processes.
909c8564c30SJames Wright 
910c8564c30SJames Wright   @param[in,out]  mat              `MATCEED`
911c8564c30SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
912c8564c30SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
913c8564c30SJames Wright 
914c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
915c8564c30SJames Wright **/
916c8564c30SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
917c8564c30SJames Wright   MatCeedContext ctx;
918c8564c30SJames Wright 
919c8564c30SJames Wright   PetscFunctionBeginUser;
920c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
921c8564c30SJames Wright   if (X_loc) {
922c8564c30SJames Wright     *X_loc = ctx->X_loc;
923c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)*X_loc));
924c8564c30SJames Wright   }
925c8564c30SJames Wright   if (Y_loc_transpose) {
926c8564c30SJames Wright     *Y_loc_transpose = ctx->Y_loc_transpose;
927c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
928c8564c30SJames Wright   }
929c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
930c8564c30SJames Wright }
931c8564c30SJames Wright 
932c8564c30SJames Wright /**
933c8564c30SJames Wright   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
934c8564c30SJames Wright 
935c8564c30SJames Wright   Not collective across MPI processes.
936c8564c30SJames Wright 
937c8564c30SJames Wright   @param[in,out]  mat              MatCeed
938c8564c30SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
939c8564c30SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
940c8564c30SJames Wright 
941c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
942c8564c30SJames Wright **/
943c8564c30SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
944c8564c30SJames Wright   PetscFunctionBeginUser;
945c8564c30SJames Wright   if (X_loc) PetscCall(VecDestroy(X_loc));
946c8564c30SJames Wright   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
947c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
948c8564c30SJames Wright }
949c8564c30SJames Wright 
950c8564c30SJames Wright /**
951c8564c30SJames Wright   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
952c8564c30SJames Wright 
953c8564c30SJames Wright   Not collective across MPI processes.
954c8564c30SJames Wright 
955c8564c30SJames Wright   @param[in,out]  mat                MatCeed
956c8564c30SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
957c8564c30SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
958c8564c30SJames Wright 
959c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
960c8564c30SJames Wright **/
961c8564c30SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
962c8564c30SJames Wright   MatCeedContext ctx;
963c8564c30SJames Wright 
964c8564c30SJames Wright   PetscFunctionBeginUser;
965c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
966c8564c30SJames Wright   if (op_mult) {
967c8564c30SJames Wright     *op_mult = NULL;
968*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
969c8564c30SJames Wright   }
970c8564c30SJames Wright   if (op_mult_transpose) {
971c8564c30SJames Wright     *op_mult_transpose = NULL;
972*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
973c8564c30SJames Wright   }
974c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
975c8564c30SJames Wright }
976c8564c30SJames Wright 
977c8564c30SJames Wright /**
978c8564c30SJames Wright   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
979c8564c30SJames Wright 
980c8564c30SJames Wright   Not collective across MPI processes.
981c8564c30SJames Wright 
982c8564c30SJames Wright   @param[in,out]  mat                MatCeed
983c8564c30SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
984c8564c30SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
985c8564c30SJames Wright 
986c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
987c8564c30SJames Wright **/
988c8564c30SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
989c8564c30SJames Wright   MatCeedContext ctx;
990c8564c30SJames Wright 
991c8564c30SJames Wright   PetscFunctionBeginUser;
992c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
993*54831c5fSJames Wright   if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult));
994*54831c5fSJames Wright   if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
995c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
996c8564c30SJames Wright }
997c8564c30SJames Wright 
998c8564c30SJames Wright /**
999c8564c30SJames Wright   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1000c8564c30SJames Wright 
1001c8564c30SJames Wright   Not collective across MPI processes.
1002c8564c30SJames Wright 
1003c8564c30SJames Wright   @param[in,out]  mat                       MatCeed
1004c8564c30SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1005c8564c30SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1006c8564c30SJames Wright 
1007c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1008c8564c30SJames Wright **/
1009c8564c30SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
1010c8564c30SJames Wright   MatCeedContext ctx;
1011c8564c30SJames Wright 
1012c8564c30SJames Wright   PetscFunctionBeginUser;
1013c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1014c8564c30SJames Wright   if (log_event_mult) ctx->log_event_mult = log_event_mult;
1015c8564c30SJames Wright   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
1016c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1017c8564c30SJames Wright }
1018c8564c30SJames Wright 
1019c8564c30SJames Wright /**
1020c8564c30SJames Wright   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1021c8564c30SJames Wright 
1022c8564c30SJames Wright   Not collective across MPI processes.
1023c8564c30SJames Wright 
1024c8564c30SJames Wright   @param[in,out]  mat                       MatCeed
1025c8564c30SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1026c8564c30SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1027c8564c30SJames Wright 
1028c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1029c8564c30SJames Wright **/
1030c8564c30SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
1031c8564c30SJames Wright   MatCeedContext ctx;
1032c8564c30SJames Wright 
1033c8564c30SJames Wright   PetscFunctionBeginUser;
1034c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1035c8564c30SJames Wright   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
1036c8564c30SJames Wright   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
1037c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1038c8564c30SJames Wright }
1039c8564c30SJames Wright 
1040c8564c30SJames Wright // -----------------------------------------------------------------------------
1041c8564c30SJames Wright // Operator context data
1042c8564c30SJames Wright // -----------------------------------------------------------------------------
1043c8564c30SJames Wright 
1044c8564c30SJames Wright /**
1045c8564c30SJames Wright   @brief Setup context data for operator application.
1046c8564c30SJames Wright 
1047c8564c30SJames Wright   Collective across MPI processes.
1048c8564c30SJames Wright 
1049c8564c30SJames Wright   @param[in]   dm_x                      Input `DM`
1050c8564c30SJames Wright   @param[in]   dm_y                      Output `DM`
1051c8564c30SJames Wright   @param[in]   X_loc                     Input PETSc local vector, or NULL
1052c8564c30SJames Wright   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
1053c8564c30SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
1054c8564c30SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
1055c8564c30SJames Wright   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
1056c8564c30SJames Wright   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
1057c8564c30SJames Wright   @param[out]  ctx                       Context data for operator evaluation
1058c8564c30SJames Wright 
1059c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1060c8564c30SJames Wright **/
1061c8564c30SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
1062c8564c30SJames Wright                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
1063c8564c30SJames Wright   CeedSize x_loc_len, y_loc_len;
1064c8564c30SJames Wright 
1065c8564c30SJames Wright   PetscFunctionBeginUser;
1066c8564c30SJames Wright 
1067c8564c30SJames Wright   // Allocate
1068c8564c30SJames Wright   PetscCall(PetscNew(ctx));
1069c8564c30SJames Wright   (*ctx)->ref_count = 1;
1070c8564c30SJames Wright 
1071c8564c30SJames Wright   // Logging
1072c8564c30SJames Wright   (*ctx)->log_event_mult           = log_event_mult;
1073c8564c30SJames Wright   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
1074c8564c30SJames Wright 
1075c8564c30SJames Wright   // PETSc objects
1076c8564c30SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_x));
1077c8564c30SJames Wright   (*ctx)->dm_x = dm_x;
1078c8564c30SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_y));
1079c8564c30SJames Wright   (*ctx)->dm_y = dm_y;
1080c8564c30SJames Wright   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
1081c8564c30SJames Wright   (*ctx)->X_loc = X_loc;
1082c8564c30SJames Wright   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
1083c8564c30SJames Wright   (*ctx)->Y_loc_transpose = Y_loc_transpose;
1084c8564c30SJames Wright 
1085c8564c30SJames Wright   // Memtype
1086c8564c30SJames Wright   {
1087c8564c30SJames Wright     const PetscScalar *x;
1088c8564c30SJames Wright     Vec                X;
1089c8564c30SJames Wright 
1090c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_x, &X));
1091c8564c30SJames Wright     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
1092c8564c30SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
1093c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &X));
1094c8564c30SJames Wright   }
1095c8564c30SJames Wright 
1096c8564c30SJames Wright   // libCEED objects
1097c8564c30SJames Wright   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
1098c8564c30SJames Wright              "retrieving Ceed context object failed");
1099*54831c5fSJames Wright   PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed));
1100*54831c5fSJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
1101*54831c5fSJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
1102*54831c5fSJames Wright   if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
1103*54831c5fSJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
1104*54831c5fSJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
1105c8564c30SJames Wright 
1106c8564c30SJames Wright   // Flop counting
1107c8564c30SJames Wright   {
1108c8564c30SJames Wright     CeedSize ceed_flops_estimate = 0;
1109c8564c30SJames Wright 
1110*54831c5fSJames Wright     PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
1111c8564c30SJames Wright     (*ctx)->flops_mult = ceed_flops_estimate;
1112c8564c30SJames Wright     if (op_mult_transpose) {
1113*54831c5fSJames Wright       PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
1114c8564c30SJames Wright       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
1115c8564c30SJames Wright     }
1116c8564c30SJames Wright   }
1117c8564c30SJames Wright 
1118c8564c30SJames Wright   // Check sizes
1119c8564c30SJames Wright   if (x_loc_len > 0 || y_loc_len > 0) {
1120c8564c30SJames Wright     CeedSize ctx_x_loc_len, ctx_y_loc_len;
1121c8564c30SJames Wright     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
1122c8564c30SJames Wright     Vec      dm_X_loc, dm_Y_loc;
1123c8564c30SJames Wright 
1124c8564c30SJames Wright     // -- Input
1125c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
1126c8564c30SJames Wright     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
1127c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
1128c8564c30SJames Wright     if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
1129*54831c5fSJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
1130c8564c30SJames Wright     if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions");
1131c8564c30SJames Wright     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions");
1132c8564c30SJames Wright     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions");
1133c8564c30SJames Wright 
1134c8564c30SJames Wright     // -- Output
1135c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
1136c8564c30SJames Wright     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
1137c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
1138*54831c5fSJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
1139c8564c30SJames Wright     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions");
1140c8564c30SJames Wright 
1141c8564c30SJames Wright     // -- Transpose
1142c8564c30SJames Wright     if (Y_loc_transpose) {
1143c8564c30SJames Wright       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
1144c8564c30SJames Wright       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions");
1145c8564c30SJames Wright     }
1146c8564c30SJames Wright   }
1147c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1148c8564c30SJames Wright }
1149c8564c30SJames Wright 
1150c8564c30SJames Wright /**
1151c8564c30SJames Wright   @brief Increment reference counter for `MATCEED` context.
1152c8564c30SJames Wright 
1153c8564c30SJames Wright   Not collective across MPI processes.
1154c8564c30SJames Wright 
1155c8564c30SJames Wright   @param[in,out]  ctx  Context data
1156c8564c30SJames Wright 
1157c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1158c8564c30SJames Wright **/
1159c8564c30SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
1160c8564c30SJames Wright   PetscFunctionBeginUser;
1161c8564c30SJames Wright   ctx->ref_count++;
1162c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1163c8564c30SJames Wright }
1164c8564c30SJames Wright 
1165c8564c30SJames Wright /**
1166c8564c30SJames Wright   @brief Copy reference for `MATCEED`.
1167c8564c30SJames Wright          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1168c8564c30SJames Wright 
1169c8564c30SJames Wright   Not collective across MPI processes.
1170c8564c30SJames Wright 
1171c8564c30SJames Wright   @param[in]   ctx       Context data
1172c8564c30SJames Wright   @param[out]  ctx_copy  Copy of pointer to context data
1173c8564c30SJames Wright 
1174c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1175c8564c30SJames Wright **/
1176c8564c30SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
1177c8564c30SJames Wright   PetscFunctionBeginUser;
1178c8564c30SJames Wright   PetscCall(MatCeedContextReference(ctx));
1179c8564c30SJames Wright   PetscCall(MatCeedContextDestroy(*ctx_copy));
1180c8564c30SJames Wright   *ctx_copy = ctx;
1181c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1182c8564c30SJames Wright }
1183c8564c30SJames Wright 
1184c8564c30SJames Wright /**
1185c8564c30SJames Wright   @brief Destroy context data for operator application.
1186c8564c30SJames Wright 
1187c8564c30SJames Wright   Collective across MPI processes.
1188c8564c30SJames Wright 
1189c8564c30SJames Wright   @param[in,out]  ctx  Context data for operator evaluation
1190c8564c30SJames Wright 
1191c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1192c8564c30SJames Wright **/
1193c8564c30SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
1194c8564c30SJames Wright   PetscFunctionBeginUser;
1195c8564c30SJames Wright   if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);
1196c8564c30SJames Wright 
1197c8564c30SJames Wright   // PETSc objects
1198c8564c30SJames Wright   PetscCall(DMDestroy(&ctx->dm_x));
1199c8564c30SJames Wright   PetscCall(DMDestroy(&ctx->dm_y));
1200c8564c30SJames Wright   PetscCall(VecDestroy(&ctx->X_loc));
1201c8564c30SJames Wright   PetscCall(VecDestroy(&ctx->Y_loc_transpose));
1202c8564c30SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
1203c8564c30SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
1204c8564c30SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
1205c8564c30SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_full));
1206c8564c30SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_pbd));
1207c8564c30SJames Wright 
1208c8564c30SJames Wright   // libCEED objects
1209*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
1210*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
1211*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
1212*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
1213*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
1214*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
1215c8564c30SJames Wright   PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");
1216c8564c30SJames Wright 
1217c8564c30SJames Wright   // Deallocate
1218c8564c30SJames Wright   ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
1219c8564c30SJames Wright   PetscCall(PetscFree(ctx));
1220c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1221c8564c30SJames Wright }
1222c8564c30SJames Wright 
1223c8564c30SJames Wright /**
1224c8564c30SJames Wright   @brief Compute the diagonal of an operator via libCEED.
1225c8564c30SJames Wright 
1226c8564c30SJames Wright   Collective across MPI processes.
1227c8564c30SJames Wright 
1228c8564c30SJames Wright   @param[in]   A  `MATCEED`
1229c8564c30SJames Wright   @param[out]  D  Vector holding operator diagonal
1230c8564c30SJames Wright 
1231c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1232c8564c30SJames Wright **/
1233c8564c30SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
1234c8564c30SJames Wright   PetscMemType   mem_type;
1235c8564c30SJames Wright   Vec            D_loc;
1236c8564c30SJames Wright   MatCeedContext ctx;
1237c8564c30SJames Wright 
1238c8564c30SJames Wright   PetscFunctionBeginUser;
1239c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1240c8564c30SJames Wright 
1241c8564c30SJames Wright   // Place PETSc vector in libCEED vector
1242c8564c30SJames Wright   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
1243c8564c30SJames Wright   PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc));
1244c8564c30SJames Wright 
1245c8564c30SJames Wright   // Compute Diagonal
1246*54831c5fSJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1247c8564c30SJames Wright 
1248c8564c30SJames Wright   // Restore PETSc vector
1249c8564c30SJames Wright   PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc));
1250c8564c30SJames Wright 
1251c8564c30SJames Wright   // Local-to-Global
1252c8564c30SJames Wright   PetscCall(VecZeroEntries(D));
1253c8564c30SJames Wright   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
1254c8564c30SJames Wright   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
1255c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1256c8564c30SJames Wright }
1257c8564c30SJames Wright 
1258c8564c30SJames Wright /**
1259c8564c30SJames Wright   @brief Compute `A X = Y` for a `MATCEED`.
1260c8564c30SJames Wright 
1261c8564c30SJames Wright   Collective across MPI processes.
1262c8564c30SJames Wright 
1263c8564c30SJames Wright   @param[in]   A  `MATCEED`
1264c8564c30SJames Wright   @param[in]   X  Input PETSc vector
1265c8564c30SJames Wright   @param[out]  Y  Output PETSc vector
1266c8564c30SJames Wright 
1267c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1268c8564c30SJames Wright **/
1269c8564c30SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
1270c8564c30SJames Wright   MatCeedContext ctx;
1271c8564c30SJames Wright 
1272c8564c30SJames Wright   PetscFunctionBeginUser;
1273c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1274c8564c30SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
1275c8564c30SJames Wright 
1276c8564c30SJames Wright   {
1277c8564c30SJames Wright     PetscMemType x_mem_type, y_mem_type;
1278c8564c30SJames Wright     Vec          X_loc = ctx->X_loc, Y_loc;
1279c8564c30SJames Wright 
1280c8564c30SJames Wright     // Get local vectors
1281c8564c30SJames Wright     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1282c8564c30SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1283c8564c30SJames Wright 
1284c8564c30SJames Wright     // Global-to-local
1285c8564c30SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
1286c8564c30SJames Wright 
1287c8564c30SJames Wright     // Setup libCEED vectors
1288c8564c30SJames Wright     PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1289c8564c30SJames Wright     PetscCall(VecZeroEntries(Y_loc));
1290c8564c30SJames Wright     PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1291c8564c30SJames Wright 
1292c8564c30SJames Wright     // Apply libCEED operator
1293c8564c30SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1294*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
1295c8564c30SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1296c8564c30SJames Wright 
1297c8564c30SJames Wright     // Restore PETSc vectors
1298c8564c30SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1299c8564c30SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1300c8564c30SJames Wright 
1301c8564c30SJames Wright     // Local-to-global
1302c8564c30SJames Wright     PetscCall(VecZeroEntries(Y));
1303c8564c30SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
1304c8564c30SJames Wright 
1305c8564c30SJames Wright     // Restore local vectors, as needed
1306c8564c30SJames Wright     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1307c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1308c8564c30SJames Wright   }
1309c8564c30SJames Wright 
1310c8564c30SJames Wright   // Log flops
1311c8564c30SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
1312c8564c30SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult));
1313c8564c30SJames Wright 
1314c8564c30SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
1315c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1316c8564c30SJames Wright }
1317c8564c30SJames Wright 
1318c8564c30SJames Wright /**
1319c8564c30SJames Wright   @brief Compute `A^T Y = X` for a `MATCEED`.
1320c8564c30SJames Wright 
1321c8564c30SJames Wright   Collective across MPI processes.
1322c8564c30SJames Wright 
1323c8564c30SJames Wright   @param[in]   A  `MATCEED`
1324c8564c30SJames Wright   @param[in]   Y  Input PETSc vector
1325c8564c30SJames Wright   @param[out]  X  Output PETSc vector
1326c8564c30SJames Wright 
1327c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1328c8564c30SJames Wright **/
1329c8564c30SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
1330c8564c30SJames Wright   MatCeedContext ctx;
1331c8564c30SJames Wright 
1332c8564c30SJames Wright   PetscFunctionBeginUser;
1333c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1334c8564c30SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
1335c8564c30SJames Wright 
1336c8564c30SJames Wright   {
1337c8564c30SJames Wright     PetscMemType x_mem_type, y_mem_type;
1338c8564c30SJames Wright     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
1339c8564c30SJames Wright 
1340c8564c30SJames Wright     // Get local vectors
1341c8564c30SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1342c8564c30SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1343c8564c30SJames Wright 
1344c8564c30SJames Wright     // Global-to-local
1345c8564c30SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
1346c8564c30SJames Wright 
1347c8564c30SJames Wright     // Setup libCEED vectors
1348c8564c30SJames Wright     PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1349c8564c30SJames Wright     PetscCall(VecZeroEntries(X_loc));
1350c8564c30SJames Wright     PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1351c8564c30SJames Wright 
1352c8564c30SJames Wright     // Apply libCEED operator
1353c8564c30SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1354*54831c5fSJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1355c8564c30SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1356c8564c30SJames Wright 
1357c8564c30SJames Wright     // Restore PETSc vectors
1358c8564c30SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1359c8564c30SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1360c8564c30SJames Wright 
1361c8564c30SJames Wright     // Local-to-global
1362c8564c30SJames Wright     PetscCall(VecZeroEntries(X));
1363c8564c30SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
1364c8564c30SJames Wright 
1365c8564c30SJames Wright     // Restore local vectors, as needed
1366c8564c30SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1367c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1368c8564c30SJames Wright   }
1369c8564c30SJames Wright 
1370c8564c30SJames Wright   // Log flops
1371c8564c30SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
1372c8564c30SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
1373c8564c30SJames Wright 
1374c8564c30SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
1375c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1376c8564c30SJames Wright }
1377