xref: /petsc/src/dm/interface/dmceed.c (revision 80f88c66aac8f9d1abe2df64bcd926bcfcc2a09a)
1 #include <petsc/private/dmimpl.h> /*I      "petscdm.h"          I*/
2 #include <petscdmceed.h>
3 
4 #ifdef PETSC_HAVE_LIBCEED
5   #include <petsc/private/dmpleximpl.h>
6   #include <petscdmplexceed.h>
7   #include <petscfeceed.h>
8 
9 /*@C
10   DMGetCeed - Get the LibCEED context associated with this `DM`
11 
12   Not Collective
13 
14   Input Parameter:
15 . DM   - The `DM`
16 
17   Output Parameter:
18 . ceed - The LibCEED context
19 
20   Level: intermediate
21 
22 .seealso: `DM`, `DMCreate()`
23 @*/
24 PetscErrorCode DMGetCeed(DM dm, Ceed *ceed)
25 {
26   PetscFunctionBegin;
27   PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
28   PetscAssertPointer(ceed, 2);
29   if (!dm->ceed) {
30     char        ceedresource[PETSC_MAX_PATH_LEN]; /* libCEED resource specifier */
31     const char *prefix;
32 
33     PetscCall(PetscStrncpy(ceedresource, "/cpu/self", sizeof(ceedresource)));
34     PetscCall(PetscObjectGetOptionsPrefix((PetscObject)dm, &prefix));
35     PetscCall(PetscOptionsGetString(NULL, prefix, "-dm_ceed", ceedresource, sizeof(ceedresource), NULL));
36     PetscCallCEED(CeedInit(ceedresource, &dm->ceed));
37   }
38   *ceed = dm->ceed;
39   PetscFunctionReturn(PETSC_SUCCESS);
40 }
41 
42 static CeedMemType PetscMemType2Ceed(PetscMemType mem_type)
43 {
44   return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST;
45 }
46 
47 PetscErrorCode VecGetCeedVector(Vec X, Ceed ceed, CeedVector *cx)
48 {
49   PetscMemType memtype;
50   PetscScalar *x;
51   PetscInt     n;
52 
53   PetscFunctionBegin;
54   PetscCall(VecGetLocalSize(X, &n));
55   PetscCall(VecGetArrayAndMemType(X, &x, &memtype));
56   PetscCallCEED(CeedVectorCreate(ceed, n, cx));
57   PetscCallCEED(CeedVectorSetArray(*cx, PetscMemType2Ceed(memtype), CEED_USE_POINTER, x));
58   PetscFunctionReturn(PETSC_SUCCESS);
59 }
60 
61 PetscErrorCode VecRestoreCeedVector(Vec X, CeedVector *cx)
62 {
63   PetscFunctionBegin;
64   PetscCall(VecRestoreArrayAndMemType(X, NULL));
65   PetscCallCEED(CeedVectorDestroy(cx));
66   PetscFunctionReturn(PETSC_SUCCESS);
67 }
68 
69 PetscErrorCode VecGetCeedVectorRead(Vec X, Ceed ceed, CeedVector *cx)
70 {
71   PetscMemType       memtype;
72   const PetscScalar *x;
73   PetscInt           n;
74   PetscFunctionBegin;
75   PetscCall(VecGetLocalSize(X, &n));
76   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype));
77   PetscCallCEED(CeedVectorCreate(ceed, n, cx));
78   PetscCallCEED(CeedVectorSetArray(*cx, PetscMemType2Ceed(memtype), CEED_USE_POINTER, (PetscScalar *)x));
79   PetscFunctionReturn(PETSC_SUCCESS);
80 }
81 
82 PetscErrorCode VecRestoreCeedVectorRead(Vec X, CeedVector *cx)
83 {
84   PetscFunctionBegin;
85   PetscCall(VecRestoreArrayReadAndMemType(X, NULL));
86   PetscCallCEED(CeedVectorDestroy(cx));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 CEED_QFUNCTION(Geometry2D)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out)
91 {
92   const CeedScalar *x = in[0], *Jac = in[1], *w = in[2];
93   CeedScalar       *qdata = out[0];
94 
95   CeedPragmaSIMD for (CeedInt i = 0; i < Q; ++i)
96   {
97     const CeedScalar J[2][2] = {
98       {Jac[i + Q * 0], Jac[i + Q * 2]},
99       {Jac[i + Q * 1], Jac[i + Q * 3]}
100     };
101     const CeedScalar det = J[0][0] * J[1][1] - J[0][1] * J[1][0];
102 
103     qdata[i + Q * 0] = det * w[i];
104     qdata[i + Q * 1] = x[i + Q * 0];
105     qdata[i + Q * 2] = x[i + Q * 1];
106     qdata[i + Q * 3] = J[1][1] / det;
107     qdata[i + Q * 4] = -J[1][0] / det;
108     qdata[i + Q * 5] = -J[0][1] / det;
109     qdata[i + Q * 6] = J[0][0] / det;
110   }
111   return CEED_ERROR_SUCCESS;
112 }
113 
114 CEED_QFUNCTION(Geometry3D)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out)
115 {
116   const CeedScalar *Jac = in[1], *w = in[2];
117   CeedScalar       *qdata = out[0];
118 
119   CeedPragmaSIMD for (CeedInt i = 0; i < Q; ++i)
120   {
121     const CeedScalar J[3][3] = {
122       {Jac[i + Q * 0], Jac[i + Q * 3], Jac[i + Q * 6]},
123       {Jac[i + Q * 1], Jac[i + Q * 4], Jac[i + Q * 7]},
124       {Jac[i + Q * 2], Jac[i + Q * 5], Jac[i + Q * 8]}
125     };
126     const CeedScalar det = J[0][0] * (J[1][1] * J[2][2] - J[1][2] * J[2][1]) + J[0][1] * (J[1][2] * J[2][0] - J[1][0] * J[2][2]) + J[0][2] * (J[1][0] * J[2][1] - J[1][1] * J[2][0]);
127 
128     qdata[i + Q * 0] = det * w[i]; /* det J * weight */
129   }
130   return CEED_ERROR_SUCCESS;
131 }
132 
133 static PetscErrorCode DMCeedCreateGeometry(DM dm, IS cellIS, PetscInt *Nqdata, CeedElemRestriction *erq, CeedVector *qd, DMCeed *soldata)
134 {
135   Ceed              ceed;
136   DMCeed            sd;
137   PetscDS           ds;
138   PetscFE           fe;
139   CeedQFunctionUser geom     = NULL;
140   const char       *geomName = NULL;
141   const PetscInt   *cells;
142   PetscInt          dim, cdim, cStart, cEnd, Ncell;
143   CeedInt           Nq;
144 
145   PetscFunctionBegin;
146   PetscCall(PetscCalloc1(1, &sd));
147   PetscCall(DMGetDimension(dm, &dim));
148   PetscCall(DMGetCoordinateDim(dm, &cdim));
149   PetscCall(DMGetCeed(dm, &ceed));
150   PetscCall(ISGetPointRange(cellIS, &cStart, &cEnd, &cells));
151   Ncell = cEnd - cStart;
152 
153   PetscCall(DMGetCellDS(dm, cells ? cells[cStart] : cStart, &ds, NULL));
154   PetscCall(PetscDSGetDiscretization(ds, 0, (PetscObject *)&fe));
155   PetscCall(PetscFEGetCeedBasis(fe, &sd->basis));
156   PetscCall(CeedBasisGetNumQuadraturePoints(sd->basis, &Nq));
157   PetscCall(DMPlexGetCeedRestriction(dm, NULL, 0, 0, 0, &sd->er));
158 
159   *Nqdata = 1 + cdim + cdim * dim; // |J| * w_q, x, J^{-1}
160   PetscCallCEED(CeedElemRestrictionCreateStrided(ceed, Ncell, Nq, *Nqdata, Ncell * Nq * (*Nqdata), CEED_STRIDES_BACKEND, erq));
161 
162   switch (dim) {
163   case 2:
164     geom     = Geometry2D;
165     geomName = Geometry2D_loc;
166     break;
167   case 3:
168     geom     = Geometry3D;
169     geomName = Geometry3D_loc;
170     break;
171   }
172   PetscCallCEED(CeedQFunctionCreateInterior(ceed, 1, geom, geomName, &sd->qf));
173   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "x", cdim, CEED_EVAL_INTERP));
174   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "dx", cdim * dim, CEED_EVAL_GRAD));
175   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "weight", 1, CEED_EVAL_WEIGHT));
176   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "qdata", *Nqdata, CEED_EVAL_NONE));
177 
178   PetscCallCEED(CeedOperatorCreate(ceed, sd->qf, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &sd->op));
179   PetscCallCEED(CeedOperatorSetField(sd->op, "x", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
180   PetscCallCEED(CeedOperatorSetField(sd->op, "dx", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
181   PetscCallCEED(CeedOperatorSetField(sd->op, "weight", CEED_ELEMRESTRICTION_NONE, sd->basis, CEED_VECTOR_NONE));
182   PetscCallCEED(CeedOperatorSetField(sd->op, "qdata", *erq, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE));
183 
184   PetscCallCEED(CeedElemRestrictionCreateVector(*erq, qd, NULL));
185   *soldata = sd;
186   PetscFunctionReturn(PETSC_SUCCESS);
187 }
188 
189 PetscErrorCode DMRefineHook_Ceed(DM coarse, DM fine, void *ctx)
190 {
191   PetscFunctionBegin;
192   if (coarse->dmceed) PetscCall(DMCeedCreate(fine, coarse->dmceed->geom ? PETSC_TRUE : PETSC_FALSE, coarse->dmceed->func, coarse->dmceed->funcSource));
193   PetscFunctionReturn(PETSC_SUCCESS);
194 }
195 
196 PetscErrorCode DMCeedCreate_Internal(DM dm, IS cellIS, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source, DMCeed *soldata)
197 {
198   PetscDS  ds;
199   PetscFE  fe;
200   DMCeed   sd;
201   Ceed     ceed;
202   PetscInt dim, Nc, Nqdata = 0;
203   CeedInt  Nq;
204 
205   PetscFunctionBegin;
206   PetscCall(PetscCalloc1(1, &sd));
207   PetscCall(DMGetDimension(dm, &dim));
208   PetscCall(DMGetCeed(dm, &ceed));
209   PetscCall(DMGetDS(dm, &ds));
210   PetscCall(PetscDSGetDiscretization(ds, 0, (PetscObject *)&fe));
211   PetscCall(PetscFEGetCeedBasis(fe, &sd->basis));
212   PetscCall(PetscFEGetNumComponents(fe, &Nc));
213   PetscCall(CeedBasisGetNumQuadraturePoints(sd->basis, &Nq));
214   PetscCall(DMPlexGetCeedRestriction(dm, NULL, 0, 0, 0, &sd->er));
215 
216   if (createGeometry) {
217     DM cdm;
218 
219     PetscCall(DMGetCoordinateDM(dm, &cdm));
220     PetscCall(DMCeedCreateGeometry(cdm, cellIS, &Nqdata, &sd->erq, &sd->qd, &sd->geom));
221   }
222 
223   if (sd->geom) {
224     PetscInt cdim;
225     CeedInt  Nqx;
226 
227     PetscCallCEED(CeedBasisGetNumQuadraturePoints(sd->geom->basis, &Nqx));
228     PetscCheck(Nqx == Nq, PetscObjectComm((PetscObject)dm), PETSC_ERR_ARG_INCOMP, "Number of qpoints for solution %" CeedInt_FMT " != %" CeedInt_FMT " Number of qpoints for coordinates", Nq, Nqx);
229     /* TODO Remove this limitation */
230     PetscCall(DMGetCoordinateDim(dm, &cdim));
231     PetscCheck(dim == cdim, PetscObjectComm((PetscObject)dm), PETSC_ERR_ARG_INCOMP, "Topological dimension %" PetscInt_FMT " != %" PetscInt_FMT " embedding dimension", dim, cdim);
232   }
233 
234   PetscCallCEED(CeedQFunctionCreateInterior(ceed, 1, func, func_source, &sd->qf));
235   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "u", Nc, CEED_EVAL_INTERP));
236   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "du", Nc * dim, CEED_EVAL_GRAD));
237   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "qdata", Nqdata, CEED_EVAL_NONE));
238   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "v", Nc, CEED_EVAL_INTERP));
239   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "dv", Nc * dim, CEED_EVAL_GRAD));
240 
241   PetscCallCEED(CeedOperatorCreate(ceed, sd->qf, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &sd->op));
242   PetscCallCEED(CeedOperatorSetField(sd->op, "u", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
243   PetscCallCEED(CeedOperatorSetField(sd->op, "du", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
244   PetscCallCEED(CeedOperatorSetField(sd->op, "qdata", sd->erq, CEED_BASIS_NONE, sd->qd));
245   PetscCallCEED(CeedOperatorSetField(sd->op, "v", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
246   PetscCallCEED(CeedOperatorSetField(sd->op, "dv", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
247 
248   // Handle refinement
249   sd->func = func;
250   PetscCall(PetscStrallocpy(func_source, &sd->funcSource));
251   PetscCall(DMRefineHookAdd(dm, DMRefineHook_Ceed, NULL, NULL));
252 
253   *soldata = sd;
254   PetscFunctionReturn(PETSC_SUCCESS);
255 }
256 
257 PetscErrorCode DMCeedCreate(DM dm, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source)
258 {
259   DM plex;
260   IS cellIS;
261 
262   PetscFunctionBegin;
263   PetscCall(DMConvert(dm, DMPLEX, &plex));
264   PetscCall(DMPlexGetAllCells_Internal(plex, &cellIS));
265   #ifdef PETSC_HAVE_LIBCEED
266   PetscCall(DMCeedCreate_Internal(dm, cellIS, createGeometry, func, func_source, &dm->dmceed));
267   #endif
268   PetscCall(ISDestroy(&cellIS));
269   PetscCall(DMDestroy(&plex));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
273 static PetscErrorCode DMCeedCreateGeometryFVM(DM dm, IS faceIS, PetscInt *Nqdata, CeedElemRestriction *erq, CeedVector *qd, DMCeed *soldata)
274 {
275   Ceed            ceed;
276   DMCeed          sd;
277   const PetscInt *faces;
278   PetscInt        dim, cdim, fStart, fEnd, Nface, Nq = 1;
279 
280   PetscFunctionBegin;
281   PetscCall(PetscCalloc1(1, &sd));
282   PetscCall(DMGetDimension(dm, &dim));
283   PetscCall(DMGetCoordinateDim(dm, &cdim));
284   PetscCall(DMGetCeed(dm, &ceed));
285   PetscCall(ISGetPointRange(faceIS, &fStart, &fEnd, &faces));
286   Nface = fEnd - fStart;
287 
288   *Nqdata = cdim + 2; // face normal and support cell volumes
289   PetscCallCEED(CeedElemRestrictionCreateStrided(ceed, Nface, Nq, *Nqdata, Nface * Nq * (*Nqdata), CEED_STRIDES_BACKEND, erq));
290   PetscCallCEED(CeedElemRestrictionCreateVector(*erq, qd, NULL));
291   *soldata = sd;
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
295 PetscErrorCode DMCeedCreateFVM_Internal(DM dm, IS faceIS, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source, DMCeed *soldata, CeedQFunctionContext qfCtx)
296 {
297   PetscDS  ds;
298   PetscFV  fv;
299   DMCeed   sd;
300   Ceed     ceed;
301   PetscInt dim, Nc, Nqdata = 0;
302 
303   PetscFunctionBegin;
304   PetscCall(PetscCalloc1(1, &sd));
305   PetscCall(DMGetDimension(dm, &dim));
306   PetscCall(DMGetCeed(dm, &ceed));
307   PetscCall(DMGetDS(dm, &ds));
308   PetscCall(PetscDSGetDiscretization(ds, 0, (PetscObject *)&fv));
309   PetscCall(PetscFVGetNumComponents(fv, &Nc));
310   PetscCall(DMPlexCreateCeedRestrictionFVM(dm, &sd->erL, &sd->erR));
311 
312   if (createGeometry) {
313     DM cdm;
314 
315     PetscCall(DMGetCoordinateDM(dm, &cdm));
316     PetscCall(DMCeedCreateGeometryFVM(cdm, faceIS, &Nqdata, &sd->erq, &sd->qd, &sd->geom));
317   }
318 
319   PetscCallCEED(CeedQFunctionCreateInterior(ceed, 1, func, func_source, &sd->qf));
320   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "uL", Nc, CEED_EVAL_NONE));
321   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "uR", Nc, CEED_EVAL_NONE));
322   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "geom", Nqdata, CEED_EVAL_NONE));
323   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "cL", Nc, CEED_EVAL_NONE));
324   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "cR", Nc, CEED_EVAL_NONE));
325 
326   PetscCallCEED(CeedQFunctionSetContext(sd->qf, qfCtx));
327 
328   PetscCallCEED(CeedOperatorCreate(ceed, sd->qf, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &sd->op));
329   PetscCallCEED(CeedOperatorSetField(sd->op, "uL", sd->erL, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE));
330   PetscCallCEED(CeedOperatorSetField(sd->op, "uR", sd->erR, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE));
331   PetscCallCEED(CeedOperatorSetField(sd->op, "geom", sd->erq, CEED_BASIS_NONE, sd->qd));
332   PetscCallCEED(CeedOperatorSetField(sd->op, "cL", sd->erL, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE));
333   PetscCallCEED(CeedOperatorSetField(sd->op, "cR", sd->erR, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE));
334 
335   // Handle refinement
336   sd->func = func;
337   PetscCall(PetscStrallocpy(func_source, &sd->funcSource));
338   PetscCall(DMRefineHookAdd(dm, DMRefineHook_Ceed, NULL, NULL));
339 
340   *soldata = sd;
341   PetscFunctionReturn(PETSC_SUCCESS);
342 }
343 
344 PetscErrorCode DMCeedCreateFVM(DM dm, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source, CeedQFunctionContext qfCtx)
345 {
346   DM plex;
347   IS faceIS;
348 
349   PetscFunctionBegin;
350   PetscCall(DMConvert(dm, DMPLEX, &plex));
351   PetscCall(DMPlexGetAllFaces_Internal(plex, &faceIS));
352   #ifdef PETSC_HAVE_LIBCEED
353   PetscCall(DMCeedCreateFVM_Internal(dm, faceIS, createGeometry, func, func_source, &dm->dmceed, qfCtx));
354   #endif
355   PetscCall(ISDestroy(&faceIS));
356   PetscCall(DMDestroy(&plex));
357   PetscFunctionReturn(PETSC_SUCCESS);
358 }
359 
360 #endif
361 
362 PetscErrorCode DMCeedDestroy(DMCeed *pceed)
363 {
364   DMCeed p = *pceed;
365 
366   PetscFunctionBegin;
367   if (!p) PetscFunctionReturn(PETSC_SUCCESS);
368 #ifdef PETSC_HAVE_LIBCEED
369   PetscCall(PetscFree(p->funcSource));
370   if (p->qd) PetscCallCEED(CeedVectorDestroy(&p->qd));
371   if (p->op) PetscCallCEED(CeedOperatorDestroy(&p->op));
372   if (p->qf) PetscCallCEED(CeedQFunctionDestroy(&p->qf));
373   if (p->erL) PetscCallCEED(CeedElemRestrictionDestroy(&p->erL));
374   if (p->erR) PetscCallCEED(CeedElemRestrictionDestroy(&p->erR));
375   if (p->erq) PetscCallCEED(CeedElemRestrictionDestroy(&p->erq));
376   PetscCall(DMCeedDestroy(&p->geom));
377 #endif
378   PetscCall(PetscFree(p));
379   *pceed = NULL;
380   PetscFunctionReturn(PETSC_SUCCESS);
381 }
382 
383 PetscErrorCode DMCeedComputeGeometry(DM dm, DMCeed sd)
384 {
385 #ifdef PETSC_HAVE_LIBCEED
386   Ceed       ceed;
387   Vec        coords;
388   CeedVector ccoords;
389 #endif
390 
391   PetscFunctionBegin;
392 #ifdef PETSC_HAVE_LIBCEED
393   PetscCall(DMGetCeed(dm, &ceed));
394   PetscCall(DMGetCoordinatesLocal(dm, &coords));
395   PetscCall(VecGetCeedVectorRead(coords, ceed, &ccoords));
396   if (sd->geom->op) PetscCallCEED(CeedOperatorApply(sd->geom->op, ccoords, sd->qd, CEED_REQUEST_IMMEDIATE));
397   else PetscCall(DMPlexCeedComputeGeometryFVM(dm, sd->qd));
398   //PetscCallCEED(CeedVectorView(sd->qd, "%g", stdout));
399   PetscCall(VecRestoreCeedVectorRead(coords, &ccoords));
400 #endif
401   PetscFunctionReturn(PETSC_SUCCESS);
402 }
403