xref: /petsc/src/snes/impls/patch/snespatch.c (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 /*
2       Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/vecimpl.h>     /* For vec->map */
5 #include <petsc/private/snesimpl.h>    /*I "petscsnes.h" I*/
6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
7 #include <petscsf.h>
8 #include <petscsection.h>
9 
10 typedef struct {
11   PC pc; /* The linear patch preconditioner */
12 } SNES_Patch;
13 
14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
15 {
16   PC                 pc      = (PC)ctx;
17   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
18   PetscInt           pt, size, i;
19   const PetscInt    *indices;
20   const PetscScalar *X;
21   PetscScalar       *XWithAll;
22 
23   PetscFunctionBegin;
24 
25   /* scatter from x to patch->patchStateWithAll[pt] */
26   pt = pcpatch->currentPatch;
27   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
28 
29   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
30   PetscCall(VecGetArrayRead(x, &X));
31   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
32 
33   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
34 
35   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
36   PetscCall(VecRestoreArrayRead(x, &X));
37   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
38 
39   PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
40   PetscFunctionReturn(PETSC_SUCCESS);
41 }
42 
43 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
44 {
45   PC                 pc      = (PC)ctx;
46   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
47   PetscInt           pt, size, i;
48   const PetscInt    *indices;
49   const PetscScalar *X;
50   PetscScalar       *XWithAll;
51 
52   PetscFunctionBegin;
53   /* scatter from x to patch->patchStateWithAll[pt] */
54   pt = pcpatch->currentPatch;
55   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
56 
57   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
58   PetscCall(VecGetArrayRead(x, &X));
59   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
60 
61   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
62 
63   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
64   PetscCall(VecRestoreArrayRead(x, &X));
65   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
66 
67   PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
68   PetscFunctionReturn(PETSC_SUCCESS);
69 }
70 
71 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
72 {
73   PC_PATCH   *patch = (PC_PATCH *)pc->data;
74   const char *prefix;
75   PetscInt    i, pStart, dof, maxDof = -1;
76 
77   PetscFunctionBegin;
78   if (!pc->setupcalled) {
79     PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
80     PetscCall(PCGetOptionsPrefix(pc, &prefix));
81     PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
82     for (i = 0; i < patch->npatch; ++i) {
83       SNES snes;
84 
85       PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
86       PetscCall(SNESSetOptionsPrefix(snes, prefix));
87       PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
88       PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
89       patch->solver[i] = (PetscObject)snes;
90 
91       PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
92       maxDof = PetscMax(maxDof, dof);
93     }
94     PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
95     PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
96     PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
97 
98     PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
99     PetscCall(VecSetUp(patch->patchStateWithAll));
100   }
101   for (i = 0; i < patch->npatch; ++i) {
102     SNES snes = (SNES)patch->solver[i];
103 
104     PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
105     PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
106   }
107   if (!pc->setupcalled && patch->optionsSet)
108     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
109   PetscFunctionReturn(PETSC_SUCCESS);
110 }
111 
112 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
113 {
114   PC_PATCH *patch = (PC_PATCH *)pc->data;
115   PetscInt  pStart, n;
116 
117   PetscFunctionBegin;
118   patch->currentPatch = i;
119   PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
120 
121   /* Scatter the overlapped global state to our patch state vector */
122   PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
123   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
124   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
125 
126   PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
127   patch->patchState->map->n = n;
128   patch->patchState->map->N = n;
129   patchUpdate->map->n       = n;
130   patchUpdate->map->N       = n;
131   patchRHS->map->n          = n;
132   patchRHS->map->N          = n;
133   /* Set initial guess to be current state*/
134   PetscCall(VecCopy(patch->patchState, patchUpdate));
135   /* Solve for new state */
136   PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
137   /* To compute update, subtract off previous state */
138   PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
139 
140   PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
141   PetscFunctionReturn(PETSC_SUCCESS);
142 }
143 
144 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
145 {
146   PC_PATCH *patch = (PC_PATCH *)pc->data;
147   PetscInt  i;
148 
149   PetscFunctionBegin;
150   if (patch->solver) {
151     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
152   }
153 
154   PetscCall(VecDestroy(&patch->patchResidual));
155   PetscCall(VecDestroy(&patch->patchState));
156   PetscCall(VecDestroy(&patch->patchStateWithAll));
157 
158   PetscCall(VecDestroy(&patch->localState));
159   PetscFunctionReturn(PETSC_SUCCESS);
160 }
161 
162 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
163 {
164   PC_PATCH *patch = (PC_PATCH *)pc->data;
165   PetscInt  i;
166 
167   PetscFunctionBegin;
168   if (patch->solver) {
169     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
170     PetscCall(PetscFree(patch->solver));
171   }
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174 
175 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
176 {
177   PC_PATCH *patch = (PC_PATCH *)pc->data;
178 
179   PetscFunctionBegin;
180   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
181   PetscFunctionReturn(PETSC_SUCCESS);
182 }
183 
184 static PetscErrorCode SNESSetUp_Patch(SNES snes)
185 {
186   SNES_Patch *patch = (SNES_Patch *)snes->data;
187   DM          dm;
188   Mat         dummy;
189   Vec         F;
190   PetscInt    n, N;
191 
192   PetscFunctionBegin;
193   PetscCall(SNESGetDM(snes, &dm));
194   PetscCall(PCSetDM(patch->pc, dm));
195   PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
196   PetscCall(VecGetLocalSize(F, &n));
197   PetscCall(VecGetSize(F, &N));
198   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
199   PetscCall(PCSetOperators(patch->pc, dummy, dummy));
200   PetscCall(MatDestroy(&dummy));
201   PetscCall(PCSetUp(patch->pc));
202   /* allocate workspace */
203   PetscFunctionReturn(PETSC_SUCCESS);
204 }
205 
206 static PetscErrorCode SNESReset_Patch(SNES snes)
207 {
208   SNES_Patch *patch = (SNES_Patch *)snes->data;
209 
210   PetscFunctionBegin;
211   PetscCall(PCReset(patch->pc));
212   PetscFunctionReturn(PETSC_SUCCESS);
213 }
214 
215 static PetscErrorCode SNESDestroy_Patch(SNES snes)
216 {
217   SNES_Patch *patch = (SNES_Patch *)snes->data;
218 
219   PetscFunctionBegin;
220   PetscCall(SNESReset_Patch(snes));
221   PetscCall(PCDestroy(&patch->pc));
222   PetscCall(PetscFree(snes->data));
223   PetscFunctionReturn(PETSC_SUCCESS);
224 }
225 
226 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject)
227 {
228   SNES_Patch *patch = (SNES_Patch *)snes->data;
229   const char *prefix;
230 
231   PetscFunctionBegin;
232   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
233   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
234   PetscCall(PCSetFromOptions(patch->pc));
235   PetscFunctionReturn(PETSC_SUCCESS);
236 }
237 
238 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
239 {
240   SNES_Patch *patch = (SNES_Patch *)snes->data;
241   PetscBool   iascii;
242 
243   PetscFunctionBegin;
244   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
245   if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
246   PetscCall(PetscViewerASCIIPushTab(viewer));
247   PetscCall(PCView(patch->pc, viewer));
248   PetscCall(PetscViewerASCIIPopTab(viewer));
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
252 static PetscErrorCode SNESSolve_Patch(SNES snes)
253 {
254   SNES_Patch        *patch   = (SNES_Patch *)snes->data;
255   PC_PATCH          *pcpatch = (PC_PATCH *)patch->pc->data;
256   SNESLineSearch     ls;
257   Vec                rhs, update, state, residual;
258   const PetscScalar *globalState = NULL;
259   PetscScalar       *localState  = NULL;
260   PetscInt           its         = 0;
261   PetscReal          xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
262 
263   PetscFunctionBegin;
264   PetscCall(SNESGetSolution(snes, &state));
265   PetscCall(SNESGetSolutionUpdate(snes, &update));
266   PetscCall(SNESGetRhs(snes, &rhs));
267 
268   PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
269   PetscCall(SNESGetLineSearch(snes, &ls));
270 
271   PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
272   PetscCall(VecSet(update, 0.0));
273   PetscCall(SNESComputeFunction(snes, state, residual));
274 
275   PetscCall(VecNorm(state, NORM_2, &xnorm));
276   PetscCall(VecNorm(residual, NORM_2, &fnorm));
277   snes->ttol = fnorm * snes->rtol;
278 
279   if (snes->ops->converged) {
280     PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
281   } else {
282     PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
283   }
284   PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
285   PetscCall(SNESMonitor(snes, its, fnorm));
286 
287   /* The main solver loop */
288   for (its = 0; its < snes->max_its; its++) {
289     PetscCall(SNESSetIterationNumber(snes, its));
290 
291     /* Scatter state vector to overlapped vector on all patches.
292        The vector pcpatch->localState is scattered to each patch
293        in PCApply_PATCH_Nonlinear. */
294     PetscCall(VecGetArrayRead(state, &globalState));
295     PetscCall(VecGetArray(pcpatch->localState, &localState));
296     PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
297     PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
298     PetscCall(VecRestoreArray(pcpatch->localState, &localState));
299     PetscCall(VecRestoreArrayRead(state, &globalState));
300 
301     /* The looping over patches happens here */
302     PetscCall(PCApply(patch->pc, rhs, update));
303 
304     /* Apply a line search. This will often be basic with
305        damping = 1/(max number of patches a dof can be in),
306        but not always */
307     PetscCall(VecScale(update, -1.0));
308     PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
309 
310     PetscCall(VecNorm(state, NORM_2, &xnorm));
311     PetscCall(VecNorm(update, NORM_2, &ynorm));
312 
313     if (snes->ops->converged) {
314       PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
315     } else {
316       PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
317     }
318     PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
319     PetscCall(SNESMonitor(snes, its, fnorm));
320   }
321 
322   if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
323   PetscFunctionReturn(PETSC_SUCCESS);
324 }
325 
326 /*MC
327   SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches
328 
329   Level: intermediate
330 
331    References:
332 .  * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
333 
334 .seealso: `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
335 M*/
336 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
337 {
338   SNES_Patch    *patch;
339   PC_PATCH      *patchpc;
340   SNESLineSearch linesearch;
341 
342   PetscFunctionBegin;
343   PetscCall(PetscNew(&patch));
344 
345   snes->ops->solve          = SNESSolve_Patch;
346   snes->ops->setup          = SNESSetUp_Patch;
347   snes->ops->reset          = SNESReset_Patch;
348   snes->ops->destroy        = SNESDestroy_Patch;
349   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
350   snes->ops->view           = SNESView_Patch;
351 
352   PetscCall(SNESGetLineSearch(snes, &linesearch));
353   if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
354   snes->usesksp = PETSC_FALSE;
355 
356   snes->alwayscomputesfinalresidual = PETSC_FALSE;
357 
358   snes->data = (void *)patch;
359   PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
360   PetscCall(PCSetType(patch->pc, PCPATCH));
361 
362   patchpc              = (PC_PATCH *)patch->pc->data;
363   patchpc->classname   = "snes";
364   patchpc->isNonlinear = PETSC_TRUE;
365 
366   patchpc->setupsolver          = PCSetUp_PATCH_Nonlinear;
367   patchpc->applysolver          = PCApply_PATCH_Nonlinear;
368   patchpc->resetsolver          = PCReset_PATCH_Nonlinear;
369   patchpc->destroysolver        = PCDestroy_PATCH_Nonlinear;
370   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
371 
372   PetscFunctionReturn(PETSC_SUCCESS);
373 }
374 
375 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
376 {
377   SNES_Patch *patch = (SNES_Patch *)snes->data;
378   DM          dm;
379 
380   PetscFunctionBegin;
381   PetscCall(SNESGetDM(snes, &dm));
382   PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
383   PetscCall(PCSetDM(patch->pc, dm));
384   PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
385   PetscFunctionReturn(PETSC_SUCCESS);
386 }
387 
388 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
389 {
390   SNES_Patch *patch = (SNES_Patch *)snes->data;
391 
392   PetscFunctionBegin;
393   PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
394   PetscFunctionReturn(PETSC_SUCCESS);
395 }
396 
397 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
398 {
399   SNES_Patch *patch = (SNES_Patch *)snes->data;
400 
401   PetscFunctionBegin;
402   PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
403   PetscFunctionReturn(PETSC_SUCCESS);
404 }
405 
406 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
407 {
408   SNES_Patch *patch = (SNES_Patch *)snes->data;
409 
410   PetscFunctionBegin;
411   PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
412   PetscFunctionReturn(PETSC_SUCCESS);
413 }
414 
415 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
416 {
417   SNES_Patch *patch = (SNES_Patch *)snes->data;
418 
419   PetscFunctionBegin;
420   PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
421   PetscFunctionReturn(PETSC_SUCCESS);
422 }
423