xref: /petsc/src/snes/impls/patch/snespatch.c (revision d52a580b706c59ca78066c1e38754e45b6b56e2b)
1 /*
2       Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/vecimpl.h>     /* For vec->map */
5 #include <petsc/private/snesimpl.h>    /*I "petscsnes.h" I*/
6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
7 #include <petscsf.h>
8 #include <petscsection.h>
9 
10 typedef struct {
11   PC pc; /* The linear patch preconditioner */
12 } SNES_Patch;
13 
14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, PetscCtx ctx)
15 {
16   PC                 pc      = (PC)ctx;
17   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
18   PetscInt           pt, size, i;
19   const PetscInt    *indices;
20   const PetscScalar *X;
21   PetscScalar       *XWithAll;
22 
23   PetscFunctionBegin;
24   /* scatter from x to patch->patchStateWithAll[pt] */
25   pt = pcpatch->currentPatch;
26   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
27 
28   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
29   PetscCall(VecGetArrayRead(x, &X));
30   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
31 
32   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
33 
34   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
35   PetscCall(VecRestoreArrayRead(x, &X));
36   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
37 
38   PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
39   PetscFunctionReturn(PETSC_SUCCESS);
40 }
41 
42 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, PetscCtx ctx)
43 {
44   PC                 pc      = (PC)ctx;
45   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
46   PetscInt           pt, size, i;
47   const PetscInt    *indices;
48   const PetscScalar *X;
49   PetscScalar       *XWithAll;
50 
51   PetscFunctionBegin;
52   /* scatter from x to patch->patchStateWithAll[pt] */
53   pt = pcpatch->currentPatch;
54   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
55 
56   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
57   PetscCall(VecGetArrayRead(x, &X));
58   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
59 
60   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
61 
62   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
63   PetscCall(VecRestoreArrayRead(x, &X));
64   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
65 
66   PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
67   PetscFunctionReturn(PETSC_SUCCESS);
68 }
69 
70 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
71 {
72   PC_PATCH   *patch = (PC_PATCH *)pc->data;
73   const char *prefix;
74   PetscInt    i, pStart, dof, maxDof = -1;
75 
76   PetscFunctionBegin;
77   if (!pc->setupcalled) {
78     PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
79     PetscCall(PCGetOptionsPrefix(pc, &prefix));
80     PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
81     for (i = 0; i < patch->npatch; ++i) {
82       SNES snes;
83 
84       PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
85       PetscCall(SNESSetOptionsPrefix(snes, prefix));
86       PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
87       PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
88       patch->solver[i] = (PetscObject)snes;
89 
90       PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
91       maxDof = PetscMax(maxDof, dof);
92     }
93     PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
94     PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
95     PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
96 
97     PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
98     PetscCall(VecSetUp(patch->patchStateWithAll));
99   }
100   for (i = 0; i < patch->npatch; ++i) {
101     SNES snes = (SNES)patch->solver[i];
102 
103     PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
104     PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
105   }
106   if (!pc->setupcalled && patch->optionsSet)
107     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
108   PetscFunctionReturn(PETSC_SUCCESS);
109 }
110 
111 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
112 {
113   PC_PATCH *patch = (PC_PATCH *)pc->data;
114   PetscInt  pStart, n;
115 
116   PetscFunctionBegin;
117   patch->currentPatch = i;
118   PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
119 
120   /* Scatter the overlapped global state to our patch state vector */
121   PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
122   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
123   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
124 
125   PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
126   patch->patchState->map->n = n;
127   patch->patchState->map->N = n;
128   patchUpdate->map->n       = n;
129   patchUpdate->map->N       = n;
130   patchRHS->map->n          = n;
131   patchRHS->map->N          = n;
132   /* Set initial guess to be current state*/
133   PetscCall(VecCopy(patch->patchState, patchUpdate));
134   /* Solve for new state */
135   PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
136   /* To compute update, subtract off previous state */
137   PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
138 
139   PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
143 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
144 {
145   PC_PATCH *patch = (PC_PATCH *)pc->data;
146   PetscInt  i;
147 
148   PetscFunctionBegin;
149   if (patch->solver) {
150     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
151   }
152 
153   PetscCall(VecDestroy(&patch->patchResidual));
154   PetscCall(VecDestroy(&patch->patchState));
155   PetscCall(VecDestroy(&patch->patchStateWithAll));
156 
157   PetscCall(VecDestroy(&patch->localState));
158   PetscFunctionReturn(PETSC_SUCCESS);
159 }
160 
161 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
162 {
163   PC_PATCH *patch = (PC_PATCH *)pc->data;
164   PetscInt  i;
165 
166   PetscFunctionBegin;
167   if (patch->solver) {
168     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
169     PetscCall(PetscFree(patch->solver));
170   }
171   PetscFunctionReturn(PETSC_SUCCESS);
172 }
173 
174 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
175 {
176   PC_PATCH *patch = (PC_PATCH *)pc->data;
177 
178   PetscFunctionBegin;
179   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
180   PetscFunctionReturn(PETSC_SUCCESS);
181 }
182 
183 static PetscErrorCode SNESSetUp_Patch(SNES snes)
184 {
185   SNES_Patch *patch = (SNES_Patch *)snes->data;
186   DM          dm;
187   Mat         dummy;
188   Vec         F;
189   PetscInt    n, N;
190 
191   PetscFunctionBegin;
192   PetscCall(SNESGetDM(snes, &dm));
193   PetscCall(PCSetDM(patch->pc, dm));
194   PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
195   PetscCall(VecGetLocalSize(F, &n));
196   PetscCall(VecGetSize(F, &N));
197   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
198   PetscCall(PCSetOperators(patch->pc, dummy, dummy));
199   PetscCall(MatDestroy(&dummy));
200   PetscCall(PCSetUp(patch->pc));
201   /* allocate workspace */
202   PetscFunctionReturn(PETSC_SUCCESS);
203 }
204 
205 static PetscErrorCode SNESReset_Patch(SNES snes)
206 {
207   SNES_Patch *patch = (SNES_Patch *)snes->data;
208 
209   PetscFunctionBegin;
210   PetscCall(PCReset(patch->pc));
211   PetscFunctionReturn(PETSC_SUCCESS);
212 }
213 
214 static PetscErrorCode SNESDestroy_Patch(SNES snes)
215 {
216   SNES_Patch *patch = (SNES_Patch *)snes->data;
217 
218   PetscFunctionBegin;
219   PetscCall(SNESReset_Patch(snes));
220   PetscCall(PCDestroy(&patch->pc));
221   PetscCall(PetscFree(snes->data));
222   PetscFunctionReturn(PETSC_SUCCESS);
223 }
224 
225 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems PetscOptionsObject)
226 {
227   SNES_Patch *patch = (SNES_Patch *)snes->data;
228   const char *prefix;
229 
230   PetscFunctionBegin;
231   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
232   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
233   PetscCall(PCSetFromOptions(patch->pc));
234   PetscFunctionReturn(PETSC_SUCCESS);
235 }
236 
237 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
238 {
239   SNES_Patch *patch = (SNES_Patch *)snes->data;
240   PetscBool   isascii;
241 
242   PetscFunctionBegin;
243   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
244   if (isascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
245   PetscCall(PetscViewerASCIIPushTab(viewer));
246   PetscCall(PCView(patch->pc, viewer));
247   PetscCall(PetscViewerASCIIPopTab(viewer));
248   PetscFunctionReturn(PETSC_SUCCESS);
249 }
250 
251 static PetscErrorCode SNESSolve_Patch(SNES snes)
252 {
253   SNES_Patch        *patch   = (SNES_Patch *)snes->data;
254   PC_PATCH          *pcpatch = (PC_PATCH *)patch->pc->data;
255   SNESLineSearch     ls;
256   Vec                rhs, update, state, residual;
257   const PetscScalar *globalState = NULL;
258   PetscScalar       *localState  = NULL;
259   PetscInt           its         = 0;
260   PetscReal          xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
261 
262   PetscFunctionBegin;
263   PetscCall(SNESGetSolution(snes, &state));
264   PetscCall(SNESGetSolutionUpdate(snes, &update));
265   PetscCall(SNESGetRhs(snes, &rhs));
266 
267   PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
268   PetscCall(SNESGetLineSearch(snes, &ls));
269 
270   PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
271   PetscCall(VecSet(update, 0.0));
272   PetscCall(SNESComputeFunction(snes, state, residual));
273 
274   PetscCall(VecNorm(state, NORM_2, &xnorm));
275   PetscCall(VecNorm(residual, NORM_2, &fnorm));
276   SNESCheckFunctionDomainError(snes, fnorm);
277   snes->ttol = fnorm * snes->rtol;
278 
279   if (snes->ops->converged) {
280     PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
281   } else {
282     PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
283   }
284   PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
285   PetscCall(SNESMonitor(snes, its, fnorm));
286 
287   /* The main solver loop */
288   for (its = 0; its < snes->max_its; its++) {
289     PetscCall(SNESSetIterationNumber(snes, its));
290 
291     /* Scatter state vector to overlapped vector on all patches.
292        The vector pcpatch->localState is scattered to each patch
293        in PCApply_PATCH_Nonlinear. */
294     PetscCall(VecGetArrayRead(state, &globalState));
295     PetscCall(VecGetArray(pcpatch->localState, &localState));
296     PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
297     PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
298     PetscCall(VecRestoreArray(pcpatch->localState, &localState));
299     PetscCall(VecRestoreArrayRead(state, &globalState));
300 
301     /* The looping over patches happens here */
302     PetscCall(PCApply(patch->pc, rhs, update));
303 
304     /* Apply a line search. This will often be basic with
305        damping = 1/(max number of patches a dof can be in),
306        but not always */
307     PetscCall(VecScale(update, -1.0));
308     PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
309 
310     PetscCall(VecNorm(state, NORM_2, &xnorm));
311     PetscCall(VecNorm(update, NORM_2, &ynorm));
312 
313     if (snes->ops->converged) {
314       PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
315     } else {
316       PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
317     }
318     PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
319     PetscCall(SNESMonitor(snes, its, fnorm));
320   }
321 
322   if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
323   PetscFunctionReturn(PETSC_SUCCESS);
324 }
325 
326 /*MC
327   SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches {cite}`bruneknepleysmithtu15`
328 
329   Level: intermediate
330 
331 .seealso: [](ch_snes), `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
332 M*/
333 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
334 {
335   SNES_Patch    *patch;
336   PC_PATCH      *patchpc;
337   SNESLineSearch linesearch;
338 
339   PetscFunctionBegin;
340   PetscCall(PetscNew(&patch));
341 
342   snes->ops->solve          = SNESSolve_Patch;
343   snes->ops->setup          = SNESSetUp_Patch;
344   snes->ops->reset          = SNESReset_Patch;
345   snes->ops->destroy        = SNESDestroy_Patch;
346   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
347   snes->ops->view           = SNESView_Patch;
348 
349   PetscCall(SNESGetLineSearch(snes, &linesearch));
350   if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
351   snes->usesksp = PETSC_FALSE;
352 
353   snes->alwayscomputesfinalresidual = PETSC_FALSE;
354 
355   PetscCall(SNESParametersInitialize(snes));
356 
357   snes->data = (void *)patch;
358   PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
359   PetscCall(PCSetType(patch->pc, PCPATCH));
360 
361   patchpc              = (PC_PATCH *)patch->pc->data;
362   patchpc->classname   = "snes";
363   patchpc->isNonlinear = PETSC_TRUE;
364 
365   patchpc->setupsolver          = PCSetUp_PATCH_Nonlinear;
366   patchpc->applysolver          = PCApply_PATCH_Nonlinear;
367   patchpc->resetsolver          = PCReset_PATCH_Nonlinear;
368   patchpc->destroysolver        = PCDestroy_PATCH_Nonlinear;
369   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
370   PetscFunctionReturn(PETSC_SUCCESS);
371 }
372 
373 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
374 {
375   SNES_Patch *patch = (SNES_Patch *)snes->data;
376   DM          dm;
377 
378   PetscFunctionBegin;
379   PetscCall(SNESGetDM(snes, &dm));
380   PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
381   PetscCall(PCSetDM(patch->pc, dm));
382   PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
383   PetscFunctionReturn(PETSC_SUCCESS);
384 }
385 
386 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
387 {
388   SNES_Patch *patch = (SNES_Patch *)snes->data;
389 
390   PetscFunctionBegin;
391   PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
392   PetscFunctionReturn(PETSC_SUCCESS);
393 }
394 
395 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
396 {
397   SNES_Patch *patch = (SNES_Patch *)snes->data;
398 
399   PetscFunctionBegin;
400   PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
401   PetscFunctionReturn(PETSC_SUCCESS);
402 }
403 
404 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), PetscCtx ctx)
405 {
406   SNES_Patch *patch = (SNES_Patch *)snes->data;
407 
408   PetscFunctionBegin;
409   PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
410   PetscFunctionReturn(PETSC_SUCCESS);
411 }
412 
413 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
414 {
415   SNES_Patch *patch = (SNES_Patch *)snes->data;
416 
417   PetscFunctionBegin;
418   PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
419   PetscFunctionReturn(PETSC_SUCCESS);
420 }
421