xref: /petsc/src/snes/impls/patch/snespatch.c (revision 6797ed33bfc5684cca29e5d9b9bbfc6c8aac93e0)
1 /*
2       Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/vecimpl.h>     /* For vec->map */
5 #include <petsc/private/snesimpl.h>    /*I "petscsnes.h" I*/
6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
7 #include <petscsf.h>
8 #include <petscsection.h>
9 
10 typedef struct {
11   PC pc; /* The linear patch preconditioner */
12 } SNES_Patch;
13 
14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx) {
15   PC                 pc      = (PC)ctx;
16   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
17   PetscInt           pt, size, i;
18   const PetscInt    *indices;
19   const PetscScalar *X;
20   PetscScalar       *XWithAll;
21 
22   PetscFunctionBegin;
23 
24   /* scatter from x to patch->patchStateWithAll[pt] */
25   pt = pcpatch->currentPatch;
26   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
27 
28   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
29   PetscCall(VecGetArrayRead(x, &X));
30   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
31 
32   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
33 
34   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
35   PetscCall(VecRestoreArrayRead(x, &X));
36   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
37 
38   PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
39   PetscFunctionReturn(0);
40 }
41 
42 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx) {
43   PC                 pc      = (PC)ctx;
44   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
45   PetscInt           pt, size, i;
46   const PetscInt    *indices;
47   const PetscScalar *X;
48   PetscScalar       *XWithAll;
49 
50   PetscFunctionBegin;
51   /* scatter from x to patch->patchStateWithAll[pt] */
52   pt = pcpatch->currentPatch;
53   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
54 
55   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
56   PetscCall(VecGetArrayRead(x, &X));
57   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
58 
59   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
60 
61   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
62   PetscCall(VecRestoreArrayRead(x, &X));
63   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
64 
65   PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
66   PetscFunctionReturn(0);
67 }
68 
69 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc) {
70   PC_PATCH   *patch = (PC_PATCH *)pc->data;
71   const char *prefix;
72   PetscInt    i, pStart, dof, maxDof = -1;
73 
74   PetscFunctionBegin;
75   if (!pc->setupcalled) {
76     PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
77     PetscCall(PCGetOptionsPrefix(pc, &prefix));
78     PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
79     for (i = 0; i < patch->npatch; ++i) {
80       SNES snes;
81 
82       PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
83       PetscCall(SNESSetOptionsPrefix(snes, prefix));
84       PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
85       PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
86       PetscCall(PetscLogObjectParent((PetscObject)pc, (PetscObject)snes));
87       patch->solver[i] = (PetscObject)snes;
88 
89       PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
90       maxDof = PetscMax(maxDof, dof);
91     }
92     PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
93     PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
94     PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
95 
96     PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
97     PetscCall(VecSetUp(patch->patchStateWithAll));
98   }
99   for (i = 0; i < patch->npatch; ++i) {
100     SNES snes = (SNES)patch->solver[i];
101 
102     PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
103     PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
104   }
105   if (!pc->setupcalled && patch->optionsSet)
106     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
107   PetscFunctionReturn(0);
108 }
109 
110 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate) {
111   PC_PATCH *patch = (PC_PATCH *)pc->data;
112   PetscInt  pStart, n;
113 
114   PetscFunctionBegin;
115   patch->currentPatch = i;
116   PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
117 
118   /* Scatter the overlapped global state to our patch state vector */
119   PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
120   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
121   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
122 
123   PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
124   patch->patchState->map->n = n;
125   patch->patchState->map->N = n;
126   patchUpdate->map->n       = n;
127   patchUpdate->map->N       = n;
128   patchRHS->map->n          = n;
129   patchRHS->map->N          = n;
130   /* Set initial guess to be current state*/
131   PetscCall(VecCopy(patch->patchState, patchUpdate));
132   /* Solve for new state */
133   PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
134   /* To compute update, subtract off previous state */
135   PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
136 
137   PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
138   PetscFunctionReturn(0);
139 }
140 
141 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc) {
142   PC_PATCH *patch = (PC_PATCH *)pc->data;
143   PetscInt  i;
144 
145   PetscFunctionBegin;
146   if (patch->solver) {
147     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
148   }
149 
150   PetscCall(VecDestroy(&patch->patchResidual));
151   PetscCall(VecDestroy(&patch->patchState));
152   PetscCall(VecDestroy(&patch->patchStateWithAll));
153 
154   PetscCall(VecDestroy(&patch->localState));
155   PetscFunctionReturn(0);
156 }
157 
158 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc) {
159   PC_PATCH *patch = (PC_PATCH *)pc->data;
160   PetscInt  i;
161 
162   PetscFunctionBegin;
163   if (patch->solver) {
164     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
165     PetscCall(PetscFree(patch->solver));
166   }
167   PetscFunctionReturn(0);
168 }
169 
170 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart) {
171   PC_PATCH *patch = (PC_PATCH *)pc->data;
172 
173   PetscFunctionBegin;
174   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
175   PetscFunctionReturn(0);
176 }
177 
178 static PetscErrorCode SNESSetUp_Patch(SNES snes) {
179   SNES_Patch *patch = (SNES_Patch *)snes->data;
180   DM          dm;
181   Mat         dummy;
182   Vec         F;
183   PetscInt    n, N;
184 
185   PetscFunctionBegin;
186   PetscCall(SNESGetDM(snes, &dm));
187   PetscCall(PCSetDM(patch->pc, dm));
188   PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
189   PetscCall(VecGetLocalSize(F, &n));
190   PetscCall(VecGetSize(F, &N));
191   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
192   PetscCall(PCSetOperators(patch->pc, dummy, dummy));
193   PetscCall(MatDestroy(&dummy));
194   PetscCall(PCSetUp(patch->pc));
195   /* allocate workspace */
196   PetscFunctionReturn(0);
197 }
198 
199 static PetscErrorCode SNESReset_Patch(SNES snes) {
200   SNES_Patch *patch = (SNES_Patch *)snes->data;
201 
202   PetscFunctionBegin;
203   PetscCall(PCReset(patch->pc));
204   PetscFunctionReturn(0);
205 }
206 
207 static PetscErrorCode SNESDestroy_Patch(SNES snes) {
208   SNES_Patch *patch = (SNES_Patch *)snes->data;
209 
210   PetscFunctionBegin;
211   PetscCall(SNESReset_Patch(snes));
212   PetscCall(PCDestroy(&patch->pc));
213   PetscCall(PetscFree(snes->data));
214   PetscFunctionReturn(0);
215 }
216 
217 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject) {
218   SNES_Patch *patch = (SNES_Patch *)snes->data;
219   const char *prefix;
220 
221   PetscFunctionBegin;
222   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
223   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
224   PetscCall(PCSetFromOptions(patch->pc));
225   PetscFunctionReturn(0);
226 }
227 
228 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer) {
229   SNES_Patch *patch = (SNES_Patch *)snes->data;
230   PetscBool   iascii;
231 
232   PetscFunctionBegin;
233   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
234   if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
235   PetscCall(PetscViewerASCIIPushTab(viewer));
236   PetscCall(PCView(patch->pc, viewer));
237   PetscCall(PetscViewerASCIIPopTab(viewer));
238   PetscFunctionReturn(0);
239 }
240 
241 static PetscErrorCode SNESSolve_Patch(SNES snes) {
242   SNES_Patch        *patch   = (SNES_Patch *)snes->data;
243   PC_PATCH          *pcpatch = (PC_PATCH *)patch->pc->data;
244   SNESLineSearch     ls;
245   Vec                rhs, update, state, residual;
246   const PetscScalar *globalState = NULL;
247   PetscScalar       *localState  = NULL;
248   PetscInt           its         = 0;
249   PetscReal          xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
250 
251   PetscFunctionBegin;
252   PetscCall(SNESGetSolution(snes, &state));
253   PetscCall(SNESGetSolutionUpdate(snes, &update));
254   PetscCall(SNESGetRhs(snes, &rhs));
255 
256   PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
257   PetscCall(SNESGetLineSearch(snes, &ls));
258 
259   PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
260   PetscCall(VecSet(update, 0.0));
261   PetscCall(SNESComputeFunction(snes, state, residual));
262 
263   PetscCall(VecNorm(state, NORM_2, &xnorm));
264   PetscCall(VecNorm(residual, NORM_2, &fnorm));
265   snes->ttol = fnorm * snes->rtol;
266 
267   if (snes->ops->converged) {
268     PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
269   } else {
270     PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
271   }
272   PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
273   PetscCall(SNESMonitor(snes, its, fnorm));
274 
275   /* The main solver loop */
276   for (its = 0; its < snes->max_its; its++) {
277     PetscCall(SNESSetIterationNumber(snes, its));
278 
279     /* Scatter state vector to overlapped vector on all patches.
280        The vector pcpatch->localState is scattered to each patch
281        in PCApply_PATCH_Nonlinear. */
282     PetscCall(VecGetArrayRead(state, &globalState));
283     PetscCall(VecGetArray(pcpatch->localState, &localState));
284     PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
285     PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
286     PetscCall(VecRestoreArray(pcpatch->localState, &localState));
287     PetscCall(VecRestoreArrayRead(state, &globalState));
288 
289     /* The looping over patches happens here */
290     PetscCall(PCApply(patch->pc, rhs, update));
291 
292     /* Apply a line search. This will often be basic with
293        damping = 1/(max number of patches a dof can be in),
294        but not always */
295     PetscCall(VecScale(update, -1.0));
296     PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
297 
298     PetscCall(VecNorm(state, NORM_2, &xnorm));
299     PetscCall(VecNorm(update, NORM_2, &ynorm));
300 
301     if (snes->ops->converged) {
302       PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
303     } else {
304       PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
305     }
306     PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
307     PetscCall(SNESMonitor(snes, its, fnorm));
308   }
309 
310   if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
311   PetscFunctionReturn(0);
312 }
313 
314 /*MC
315   SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches
316 
317   Level: intermediate
318 
319    References:
320 .  * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
321 
322 .seealso: `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
323 M*/
324 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes) {
325   SNES_Patch    *patch;
326   PC_PATCH      *patchpc;
327   SNESLineSearch linesearch;
328 
329   PetscFunctionBegin;
330   PetscCall(PetscNewLog(snes, &patch));
331 
332   snes->ops->solve          = SNESSolve_Patch;
333   snes->ops->setup          = SNESSetUp_Patch;
334   snes->ops->reset          = SNESReset_Patch;
335   snes->ops->destroy        = SNESDestroy_Patch;
336   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
337   snes->ops->view           = SNESView_Patch;
338 
339   PetscCall(SNESGetLineSearch(snes, &linesearch));
340   if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
341   snes->usesksp = PETSC_FALSE;
342 
343   snes->alwayscomputesfinalresidual = PETSC_FALSE;
344 
345   snes->data = (void *)patch;
346   PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
347   PetscCall(PCSetType(patch->pc, PCPATCH));
348 
349   patchpc              = (PC_PATCH *)patch->pc->data;
350   patchpc->classname   = "snes";
351   patchpc->isNonlinear = PETSC_TRUE;
352 
353   patchpc->setupsolver          = PCSetUp_PATCH_Nonlinear;
354   patchpc->applysolver          = PCApply_PATCH_Nonlinear;
355   patchpc->resetsolver          = PCReset_PATCH_Nonlinear;
356   patchpc->destroysolver        = PCDestroy_PATCH_Nonlinear;
357   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
358 
359   PetscFunctionReturn(0);
360 }
361 
362 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes) {
363   SNES_Patch *patch = (SNES_Patch *)snes->data;
364   DM          dm;
365 
366   PetscFunctionBegin;
367   PetscCall(SNESGetDM(snes, &dm));
368   PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
369   PetscCall(PCSetDM(patch->pc, dm));
370   PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
371   PetscFunctionReturn(0);
372 }
373 
374 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) {
375   SNES_Patch *patch = (SNES_Patch *)snes->data;
376 
377   PetscFunctionBegin;
378   PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
379   PetscFunctionReturn(0);
380 }
381 
382 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) {
383   SNES_Patch *patch = (SNES_Patch *)snes->data;
384 
385   PetscFunctionBegin;
386   PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
387   PetscFunctionReturn(0);
388 }
389 
390 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx) {
391   SNES_Patch *patch = (SNES_Patch *)snes->data;
392 
393   PetscFunctionBegin;
394   PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
395   PetscFunctionReturn(0);
396 }
397 
398 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering) {
399   SNES_Patch *patch = (SNES_Patch *)snes->data;
400 
401   PetscFunctionBegin;
402   PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
403   PetscFunctionReturn(0);
404 }
405