xref: /libCEED/examples/petsc/src/petscutils.c (revision c9d5affad74485f8d1e55e6be07e3d9f76bd4cae)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include "../include/petscutils.h"
9 
10 // -----------------------------------------------------------------------------
11 // Convert PETSc MemType to libCEED MemType
12 // -----------------------------------------------------------------------------
13 CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
14 
15 // ------------------------------------------------------------------------------------------------
16 // PETSc-libCEED memory space utilities
17 // ------------------------------------------------------------------------------------------------
18 PetscErrorCode VecP2C(Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
19   PetscScalar *x;
20 
21   PetscFunctionBeginUser;
22   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
23   CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x);
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
27 PetscErrorCode VecC2P(CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
28   PetscScalar *x;
29 
30   PetscFunctionBeginUser;
31   CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x);
32   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
33   PetscFunctionReturn(PETSC_SUCCESS);
34 }
35 
36 PetscErrorCode VecReadP2C(Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
37   PetscScalar *x;
38 
39   PetscFunctionBeginUser;
40   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
41   CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x);
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 PetscErrorCode VecReadC2P(CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
46   PetscScalar *x;
47 
48   PetscFunctionBeginUser;
49   CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x);
50   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
51   PetscFunctionReturn(PETSC_SUCCESS);
52 }
53 
54 // -----------------------------------------------------------------------------
55 // Apply 3D Kershaw mesh transformation
56 // -----------------------------------------------------------------------------
// Clamped linear ramp from "a" to "b": yields a for x <= 0, b for x >= 1, and the
// linear interpolation a + (b - a) * x for 0 < x < 1.
static double step(const double a, const double b, double x) {
  return (x <= 0) ? a : (x >= 1) ? b : a + (b - a) * x;
}
64 
// 1D Kershaw map used near the right boundary: slope (2 - eps) on [.., 0.5], then slope eps
// on the remaining part so that x = 1 maps to 1. With eps = 1 it is the identity.
static double right(const double eps, const double x) {
  if (x <= 0.5) return (2 - eps) * x;
  return 1 + eps * (x - 1);
}
67 
// 1D Kershaw map used near the left boundary: the point reflection of right() through
// (0.5, 0.5), i.e. left(x) = 1 - right(1 - x).
static double left(const double eps, const double x) {
  const double mirrored = 1 - x;
  return 1 - right(eps, mirrored);
}
70 
71 // Apply 3D Kershaw mesh transformation
72 // The eps parameters are in (0, 1]
73 // Uniform mesh is recovered for eps=1
74 PetscErrorCode Kershaw(DM dm_orig, PetscScalar eps) {
75   Vec          coord;
76   PetscInt     ncoord;
77   PetscScalar *c;
78 
79   PetscFunctionBeginUser;
80   PetscCall(DMGetCoordinatesLocal(dm_orig, &coord));
81   PetscCall(VecGetLocalSize(coord, &ncoord));
82   PetscCall(VecGetArray(coord, &c));
83 
84   for (PetscInt i = 0; i < ncoord; i += 3) {
85     PetscScalar x = c[i], y = c[i + 1], z = c[i + 2];
86     PetscInt    layer  = x * 6;
87     PetscScalar lambda = (x - layer / 6.0) * 6;
88     c[i]               = x;
89 
90     switch (layer) {
91       case 0:
92         c[i + 1] = left(eps, y);
93         c[i + 2] = left(eps, z);
94         break;
95       case 1:
96       case 4:
97         c[i + 1] = step(left(eps, y), right(eps, y), lambda);
98         c[i + 2] = step(left(eps, z), right(eps, z), lambda);
99         break;
100       case 2:
101         c[i + 1] = step(right(eps, y), left(eps, y), lambda / 2);
102         c[i + 2] = step(right(eps, z), left(eps, z), lambda / 2);
103         break;
104       case 3:
105         c[i + 1] = step(right(eps, y), left(eps, y), (1 + lambda) / 2);
106         c[i + 2] = step(right(eps, z), left(eps, z), (1 + lambda) / 2);
107         break;
108       default:
109         c[i + 1] = right(eps, y);
110         c[i + 2] = right(eps, z);
111     }
112   }
113   PetscCall(VecRestoreArray(coord, &c));
114   PetscFunctionReturn(PETSC_SUCCESS);
115 }
116 
117 // -----------------------------------------------------------------------------
118 // Create BC label
119 // -----------------------------------------------------------------------------
120 static PetscErrorCode CreateBCLabel(DM dm, const char name[]) {
121   DMLabel label;
122 
123   PetscFunctionBeginUser;
124   PetscCall(DMCreateLabel(dm, name));
125   PetscCall(DMGetLabel(dm, name, &label));
126   PetscCall(DMPlexMarkBoundaryFaces(dm, PETSC_DETERMINE, label));
127   PetscCall(DMPlexLabelComplete(dm, label));
128   PetscFunctionReturn(PETSC_SUCCESS);
129 }
130 
131 // -----------------------------------------------------------------------------
132 // This function sets up a DM for a given degree
133 // -----------------------------------------------------------------------------
134 PetscErrorCode SetupDMByDegree(DM dm, PetscInt p_degree, PetscInt q_extra, PetscInt num_comp_u, PetscInt dim, bool enforce_bc) {
135   PetscInt  marker_ids[1] = {1};
136   PetscInt  q_degree      = p_degree + q_extra;
137   PetscFE   fe;
138   MPI_Comm  comm;
139   PetscBool is_simplex = PETSC_TRUE;
140 
141   PetscFunctionBeginUser;
142   // Check if simplex or tensor-product mesh
143   PetscCall(DMPlexIsSimplex(dm, &is_simplex));
144   // Setup FE
145   PetscCall(PetscObjectGetComm((PetscObject)dm, &comm));
146   PetscCall(PetscFECreateLagrange(comm, dim, num_comp_u, is_simplex, p_degree, q_degree, &fe));
147   PetscCall(DMAddField(dm, NULL, (PetscObject)fe));
148   PetscCall(DMCreateDS(dm));
149 
150   {
151     // create FE field for coordinates
152     PetscFE  fe_coords;
153     PetscInt num_comp_coord;
154     PetscCall(DMGetCoordinateDim(dm, &num_comp_coord));
155     PetscCall(PetscFECreateLagrange(comm, dim, num_comp_coord, is_simplex, 1, q_degree, &fe_coords));
156     PetscCall(DMSetCoordinateDisc(dm, fe_coords, PETSC_TRUE));
157     PetscCall(PetscFEDestroy(&fe_coords));
158   }
159 
160   // Setup Dirichlet BC
161   // Note bp1, bp2 are projection and we don't need to apply BC
162   // For bp3,bp4, the target function is zero on the boundaries
163   // So we pass bcFunc = NULL in DMAddBoundary function
164   if (enforce_bc) {
165     PetscBool has_label;
166     PetscCall(DMHasLabel(dm, "marker", &has_label));
167     if (!has_label) {
168       PetscCall(CreateBCLabel(dm, "marker"));
169     }
170     DMLabel label;
171     PetscCall(DMGetLabel(dm, "marker", &label));
172     PetscCall(DMAddBoundary(dm, DM_BC_ESSENTIAL, "wall", label, 1, marker_ids, 0, 0, NULL, NULL, NULL, NULL, NULL));
173     PetscCall(DMSetOptionsPrefix(dm, "final_"));
174     PetscCall(DMViewFromOptions(dm, NULL, "-dm_view"));
175   }
176 
177   if (!is_simplex) {
178     DM dm_coord;
179     PetscCall(DMGetCoordinateDM(dm, &dm_coord));
180     PetscCall(DMPlexSetClosurePermutationTensor(dm, PETSC_DETERMINE, NULL));
181     PetscCall(DMPlexSetClosurePermutationTensor(dm_coord, PETSC_DETERMINE, NULL));
182   }
183   PetscCall(PetscFEDestroy(&fe));
184   PetscFunctionReturn(PETSC_SUCCESS);
185 }
186 
187 // -----------------------------------------------------------------------------
188 // Get CEED restriction data from DMPlex
189 // -----------------------------------------------------------------------------
190 PetscErrorCode CreateRestrictionFromPlex(Ceed ceed, DM dm, CeedInt height, DMLabel domain_label, CeedInt value, CeedElemRestriction *elem_restr) {
191   PetscInt num_elem, elem_size, num_dof, num_comp, *elem_restr_offsets_petsc;
192   CeedInt *elem_restr_offsets_ceed;
193 
194   PetscFunctionBeginUser;
195   PetscCall(DMPlexGetLocalOffsets(dm, domain_label, value, height, 0, &num_elem, &elem_size, &num_comp, &num_dof, &elem_restr_offsets_petsc));
196 
197   PetscCall(IntArrayPetscToCeed(num_elem * elem_size, &elem_restr_offsets_petsc, &elem_restr_offsets_ceed));
198   CeedElemRestrictionCreate(ceed, num_elem, elem_size, num_comp, 1, num_dof, CEED_MEM_HOST, CEED_COPY_VALUES, elem_restr_offsets_ceed, elem_restr);
199   PetscCall(PetscFree(elem_restr_offsets_ceed));
200   PetscFunctionReturn(PETSC_SUCCESS);
201 }
202 
203 // -----------------------------------------------------------------------------
204 // Utility function - convert from DMPolytopeType to CeedElemTopology
205 // -----------------------------------------------------------------------------
206 CeedElemTopology ElemTopologyP2C(DMPolytopeType cell_type) {
207   switch (cell_type) {
208     case DM_POLYTOPE_TRIANGLE:
209       return CEED_TOPOLOGY_TRIANGLE;
210     case DM_POLYTOPE_QUADRILATERAL:
211       return CEED_TOPOLOGY_QUAD;
212     case DM_POLYTOPE_TETRAHEDRON:
213       return CEED_TOPOLOGY_TET;
214     case DM_POLYTOPE_HEXAHEDRON:
215       return CEED_TOPOLOGY_HEX;
216     default:
217       return 0;
218   }
219 }
220 
221 // -----------------------------------------------------------------------------
222 // Convert DM field to DS field
223 // -----------------------------------------------------------------------------
224 PetscErrorCode DMFieldToDSField(DM dm, DMLabel domain_label, PetscInt dm_field, PetscInt *ds_field) {
225   PetscDS         ds;
226   IS              field_is;
227   const PetscInt *fields;
228   PetscInt        num_fields;
229 
230   PetscFunctionBeginUser;
231   // Translate dm_field to ds_field
232   PetscCall(DMGetRegionDS(dm, domain_label, &field_is, &ds, NULL));
233   PetscCall(ISGetIndices(field_is, &fields));
234   PetscCall(ISGetSize(field_is, &num_fields));
235   for (PetscInt i = 0; i < num_fields; i++) {
236     if (dm_field == fields[i]) {
237       *ds_field = i;
238       break;
239     }
240   }
241   PetscCall(ISRestoreIndices(field_is, &fields));
242 
243   if (*ds_field == -1) SETERRQ(PetscObjectComm((PetscObject)dm), PETSC_ERR_SUP, "Could not find dm_field %" PetscInt_FMT " in DS", dm_field);
244   PetscFunctionReturn(PETSC_SUCCESS);
245 }
246 
247 // -----------------------------------------------------------------------------
248 // Create libCEED Basis from PetscTabulation
249 // -----------------------------------------------------------------------------
250 PetscErrorCode BasisCreateFromTabulation(Ceed ceed, DM dm, DMLabel domain_label, PetscInt label_value, PetscInt height, PetscInt face, PetscFE fe,
251                                          PetscTabulation basis_tabulation, PetscQuadrature quadrature, CeedBasis *basis) {
252   PetscInt           first_point;
253   PetscInt           ids[1] = {label_value};
254   DMLabel            depth_label;
255   DMPolytopeType     cell_type;
256   CeedElemTopology   elem_topo;
257   PetscScalar       *q_points, *interp, *grad;
258   const PetscScalar *q_weights;
259   PetscDualSpace     dual_space;
260   PetscInt           num_dual_basis_vectors;
261   PetscInt           dim, num_comp, P, Q;
262 
263   PetscFunctionBeginUser;
264   // General basis information
265   PetscCall(PetscFEGetSpatialDimension(fe, &dim));
266   PetscCall(PetscFEGetNumComponents(fe, &num_comp));
267   PetscCall(PetscFEGetDualSpace(fe, &dual_space));
268   PetscCall(PetscDualSpaceGetDimension(dual_space, &num_dual_basis_vectors));
269   P = num_dual_basis_vectors / num_comp;
270 
271   // Use depth label if no domain label present
272   if (!domain_label) {
273     PetscInt depth;
274 
275     PetscCall(DMPlexGetDepth(dm, &depth));
276     PetscCall(DMPlexGetDepthLabel(dm, &depth_label));
277     ids[0] = depth - height;
278   }
279 
280   // Get cell interp, grad, and quadrature data
281   PetscCall(DMGetFirstLabeledPoint(dm, dm, domain_label ? domain_label : depth_label, 1, ids, height, &first_point, NULL));
282   PetscCall(DMPlexGetCellType(dm, first_point, &cell_type));
283   elem_topo = ElemTopologyP2C(cell_type);
284   if (!elem_topo) SETERRQ(PetscObjectComm((PetscObject)dm), PETSC_ERR_SUP, "DMPlex topology not supported");
285   {
286     size_t             q_points_size;
287     const PetscScalar *q_points_petsc;
288     PetscInt           q_dim;
289 
290     PetscCall(PetscQuadratureGetData(quadrature, &q_dim, NULL, &Q, &q_points_petsc, &q_weights));
291     q_points_size = Q * dim * sizeof(CeedScalar);
292     PetscCall(PetscCalloc(q_points_size, &q_points));
293     for (PetscInt q = 0; q < Q; q++) {
294       for (PetscInt d = 0; d < q_dim; d++) q_points[q * dim + d] = q_points_petsc[q * q_dim + d];
295     }
296   }
297 
298   // Convert to libCEED orientation
299   {
300     PetscBool       is_simplex  = PETSC_FALSE;
301     IS              permutation = NULL;
302     const PetscInt *permutation_indices;
303 
304     PetscCall(DMPlexIsSimplex(dm, &is_simplex));
305     if (!is_simplex) {
306       PetscSection section;
307 
308       // -- Get permutation
309       PetscCall(DMGetLocalSection(dm, &section));
310       PetscCall(PetscSectionGetClosurePermutation(section, (PetscObject)dm, dim, num_comp * P, &permutation));
311       PetscCall(ISGetIndices(permutation, &permutation_indices));
312     }
313 
314     // -- Copy interp, grad matrices
315     PetscCall(PetscCalloc(P * Q * sizeof(CeedScalar), &interp));
316     PetscCall(PetscCalloc(P * Q * dim * sizeof(CeedScalar), &grad));
317     const CeedInt c = 0;
318     for (CeedInt q = 0; q < Q; q++) {
319       for (CeedInt p_ceed = 0; p_ceed < P; p_ceed++) {
320         CeedInt p_petsc = is_simplex ? (p_ceed * num_comp) : permutation_indices[p_ceed * num_comp];
321 
322         interp[q * P + p_ceed] = basis_tabulation->T[0][((face * Q + q) * P * num_comp + p_petsc) * num_comp + c];
323         for (CeedInt d = 0; d < dim; d++) {
324           grad[(d * Q + q) * P + p_ceed] = basis_tabulation->T[1][(((face * Q + q) * P * num_comp + p_petsc) * num_comp + c) * dim + d];
325         }
326       }
327     }
328 
329     // -- Cleanup
330     if (permutation) PetscCall(ISRestoreIndices(permutation, &permutation_indices));
331     PetscCall(ISDestroy(&permutation));
332   }
333 
334   // Finally, create libCEED basis
335   CeedBasisCreateH1(ceed, elem_topo, num_comp, P, Q, interp, grad, q_points, q_weights, basis);
336   PetscCall(PetscFree(q_points));
337   PetscCall(PetscFree(interp));
338   PetscCall(PetscFree(grad));
339   PetscFunctionReturn(PETSC_SUCCESS);
340 }
341 
342 // -----------------------------------------------------------------------------
343 // Get CEED Basis from DMPlex
344 // -----------------------------------------------------------------------------
345 PetscErrorCode CreateBasisFromPlex(Ceed ceed, DM dm, DMLabel domain_label, CeedInt label_value, CeedInt height, CeedInt dm_field, BPData bp_data,
346                                    CeedBasis *basis) {
347   PetscDS         ds;
348   PetscFE         fe;
349   PetscQuadrature quadrature;
350   PetscBool       is_simplex = PETSC_TRUE;
351   PetscInt        ds_field   = -1;
352 
353   PetscFunctionBeginUser;
354   // Get element information
355   PetscCall(DMGetRegionDS(dm, domain_label, NULL, &ds, NULL));
356   PetscCall(DMFieldToDSField(dm, domain_label, dm_field, &ds_field));
357   PetscCall(PetscDSGetDiscretization(ds, ds_field, (PetscObject *)&fe));
358   PetscCall(PetscFEGetHeightSubspace(fe, height, &fe));
359   PetscCall(PetscFEGetQuadrature(fe, &quadrature));
360 
361   // Check if simplex or tensor-product mesh
362   PetscCall(DMPlexIsSimplex(dm, &is_simplex));
363 
364   // Build libCEED basis
365   if (is_simplex) {
366     PetscTabulation basis_tabulation;
367     PetscInt        num_derivatives = 1, face = 0;
368 
369     PetscCall(PetscFEGetCellTabulation(fe, num_derivatives, &basis_tabulation));
370     PetscCall(BasisCreateFromTabulation(ceed, dm, domain_label, label_value, height, face, fe, basis_tabulation, quadrature, basis));
371   } else {
372     PetscDualSpace dual_space;
373     PetscInt       num_dual_basis_vectors;
374     PetscInt       dim, num_comp, P, Q;
375 
376     PetscCall(PetscFEGetSpatialDimension(fe, &dim));
377     PetscCall(PetscFEGetNumComponents(fe, &num_comp));
378     PetscCall(PetscFEGetDualSpace(fe, &dual_space));
379     PetscCall(PetscDualSpaceGetDimension(dual_space, &num_dual_basis_vectors));
380     P = num_dual_basis_vectors / num_comp;
381     PetscCall(PetscQuadratureGetData(quadrature, NULL, NULL, &Q, NULL, NULL));
382 
383     CeedInt P_1d = (CeedInt)round(pow(P, 1.0 / dim));
384     CeedInt Q_1d = (CeedInt)round(pow(Q, 1.0 / dim));
385 
386     CeedBasisCreateTensorH1Lagrange(ceed, dim, num_comp, P_1d, Q_1d, bp_data.q_mode, basis);
387   }
388   PetscFunctionReturn(PETSC_SUCCESS);
389 }
390 
391 // -----------------------------------------------------------------------------
392 // Utilities
393 // -----------------------------------------------------------------------------
394 
395 // Utility function, compute three factors of an integer
396 static void Split3(PetscInt size, PetscInt m[3], bool reverse) {
397   for (PetscInt d = 0, size_left = size; d < 3; d++) {
398     PetscInt try = (PetscInt)PetscCeilReal(PetscPowReal(size_left, 1. / (3 - d)));
399     while (try * (size_left / try) != size_left) try++;
400     m[reverse ? 2 - d : d] = try;
401     size_left /= try;
402   }
403 }
404 
405 static int Max3(const PetscInt a[3]) { return PetscMax(a[0], PetscMax(a[1], a[2])); }
406 
407 static int Min3(const PetscInt a[3]) { return PetscMin(a[0], PetscMin(a[1], a[2])); }
408 
409 // -----------------------------------------------------------------------------
410 // Create distribute dm
411 // -----------------------------------------------------------------------------
412 PetscErrorCode CreateDistributedDM(RunParams rp, DM *dm) {
413   PetscFunctionBeginUser;
414   // Setup DM
415   if (rp->read_mesh) {
416     PetscCall(DMPlexCreateFromFile(PETSC_COMM_WORLD, rp->filename, NULL, PETSC_TRUE, dm));
417   } else {
418     if (rp->user_l_nodes) {
419       // Find a nicely composite number of elements no less than global nodes
420       PetscMPIInt size;
421       PetscCall(MPI_Comm_size(rp->comm, &size));
422       for (PetscInt g_elem = PetscMax(1, size * rp->local_nodes / PetscPowInt(rp->degree, rp->dim));; g_elem++) {
423         Split3(g_elem, rp->mesh_elem, true);
424         if (Max3(rp->mesh_elem) / Min3(rp->mesh_elem) <= 2) break;
425       }
426     }
427 
428     PetscCall(DMPlexCreateBoxMesh(PETSC_COMM_WORLD, rp->dim, rp->simplex, rp->mesh_elem, NULL, NULL, NULL, PETSC_TRUE, 0, PETSC_FALSE, dm));
429   }
430 
431   PetscCall(DMSetFromOptions(*dm));
432   PetscCall(DMViewFromOptions(*dm, NULL, "-dm_view"));
433   PetscFunctionReturn(PETSC_SUCCESS);
434 }
435 
436 // -----------------------------------------------------------------------------
437