xref: /petsc/src/snes/impls/patch/snespatch.c (revision e84e3fd21fa5912dca3017339ab4b3699e3a9c51)
1 /*
2       Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/vecimpl.h>         /* For vec->map */
5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/
6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
7 #include <petscsf.h>
8 #include <petscsection.h>
9 
10 typedef struct {
11   PC pc; /* The linear patch preconditioner */
12 } SNES_Patch;
13 
14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
15 {
16   PC                pc      = (PC) ctx;
17   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
18   PetscInt          pt, size, i;
19   const PetscInt    *indices;
20   const PetscScalar *X;
21   PetscScalar       *XWithAll;
22 
23   PetscFunctionBegin;
24 
25   /* scatter from x to patch->patchStateWithAll[pt] */
26   pt = pcpatch->currentPatch;
27   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
28 
29   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
30   PetscCall(VecGetArrayRead(x, &X));
31   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
32 
33   for (i = 0; i < size; ++i) {
34     XWithAll[indices[i]] = X[i];
35   }
36 
37   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
38   PetscCall(VecRestoreArrayRead(x, &X));
39   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
40 
41   PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
42   PetscFunctionReturn(0);
43 }
44 
45 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
46 {
47   PC                pc      = (PC) ctx;
48   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
49   PetscInt          pt, size, i;
50   const PetscInt    *indices;
51   const PetscScalar *X;
52   PetscScalar       *XWithAll;
53 
54   PetscFunctionBegin;
55   /* scatter from x to patch->patchStateWithAll[pt] */
56   pt = pcpatch->currentPatch;
57   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
58 
59   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
60   PetscCall(VecGetArrayRead(x, &X));
61   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
62 
63   for (i = 0; i < size; ++i) {
64     XWithAll[indices[i]] = X[i];
65   }
66 
67   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
68   PetscCall(VecRestoreArrayRead(x, &X));
69   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
70 
71   PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
72   PetscFunctionReturn(0);
73 }
74 
75 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
76 {
77   PC_PATCH       *patch = (PC_PATCH *) pc->data;
78   const char     *prefix;
79   PetscInt       i, pStart, dof, maxDof = -1;
80 
81   PetscFunctionBegin;
82   if (!pc->setupcalled) {
83     PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
84     PetscCall(PCGetOptionsPrefix(pc, &prefix));
85     PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
86     for (i = 0; i < patch->npatch; ++i) {
87       SNES snes;
88 
89       PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
90       PetscCall(SNESSetOptionsPrefix(snes, prefix));
91       PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
92       PetscCall(PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2));
93       PetscCall(PetscLogObjectParent((PetscObject) pc, (PetscObject) snes));
94       patch->solver[i] = (PetscObject) snes;
95 
96       PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof));
97       maxDof = PetscMax(maxDof, dof);
98     }
99     PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
100     PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
101     PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
102 
103     PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
104     PetscCall(VecSetUp(patch->patchStateWithAll));
105   }
106   for (i = 0; i < patch->npatch; ++i) {
107     SNES snes = (SNES) patch->solver[i];
108 
109     PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
110     PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
111   }
112   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES) patch->solver[i]));
113   PetscFunctionReturn(0);
114 }
115 
116 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
117 {
118   PC_PATCH      *patch = (PC_PATCH *) pc->data;
119   PetscInt       pStart, n;
120 
121   PetscFunctionBegin;
122   patch->currentPatch = i;
123   PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
124 
125   /* Scatter the overlapped global state to our patch state vector */
126   PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
127   PetscCall(PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
128   PetscCall(PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
129 
130   PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
131   patch->patchState->map->n = n;
132   patch->patchState->map->N = n;
133   patchUpdate->map->n = n;
134   patchUpdate->map->N = n;
135   patchRHS->map->n = n;
136   patchRHS->map->N = n;
137   /* Set initial guess to be current state*/
138   PetscCall(VecCopy(patch->patchState, patchUpdate));
139   /* Solve for new state */
140   PetscCall(SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate));
141   /* To compute update, subtract off previous state */
142   PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
143 
144   PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
145   PetscFunctionReturn(0);
146 }
147 
148 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
149 {
150   PC_PATCH      *patch = (PC_PATCH *) pc->data;
151   PetscInt       i;
152 
153   PetscFunctionBegin;
154   if (patch->solver) {
155     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES) patch->solver[i]));
156   }
157 
158   PetscCall(VecDestroy(&patch->patchResidual));
159   PetscCall(VecDestroy(&patch->patchState));
160   PetscCall(VecDestroy(&patch->patchStateWithAll));
161 
162   PetscCall(VecDestroy(&patch->localState));
163   PetscFunctionReturn(0);
164 }
165 
166 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
167 {
168   PC_PATCH      *patch = (PC_PATCH *) pc->data;
169   PetscInt       i;
170 
171   PetscFunctionBegin;
172   if (patch->solver) {
173     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *) &patch->solver[i]));
174     PetscCall(PetscFree(patch->solver));
175   }
176   PetscFunctionReturn(0);
177 }
178 
179 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
180 {
181   PC_PATCH      *patch = (PC_PATCH *) pc->data;
182 
183   PetscFunctionBegin;
184   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
185   PetscFunctionReturn(0);
186 }
187 
188 static PetscErrorCode SNESSetUp_Patch(SNES snes)
189 {
190   SNES_Patch    *patch = (SNES_Patch *) snes->data;
191   DM             dm;
192   Mat            dummy;
193   Vec            F;
194   PetscInt       n, N;
195 
196   PetscFunctionBegin;
197   PetscCall(SNESGetDM(snes, &dm));
198   PetscCall(PCSetDM(patch->pc, dm));
199   PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
200   PetscCall(VecGetLocalSize(F, &n));
201   PetscCall(VecGetSize(F, &N));
202   PetscCall(MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy));
203   PetscCall(PCSetOperators(patch->pc, dummy, dummy));
204   PetscCall(MatDestroy(&dummy));
205   PetscCall(PCSetUp(patch->pc));
206   /* allocate workspace */
207   PetscFunctionReturn(0);
208 }
209 
210 static PetscErrorCode SNESReset_Patch(SNES snes)
211 {
212   SNES_Patch    *patch = (SNES_Patch *) snes->data;
213 
214   PetscFunctionBegin;
215   PetscCall(PCReset(patch->pc));
216   PetscFunctionReturn(0);
217 }
218 
219 static PetscErrorCode SNESDestroy_Patch(SNES snes)
220 {
221   SNES_Patch    *patch = (SNES_Patch *) snes->data;
222 
223   PetscFunctionBegin;
224   PetscCall(SNESReset_Patch(snes));
225   PetscCall(PCDestroy(&patch->pc));
226   PetscCall(PetscFree(snes->data));
227   PetscFunctionReturn(0);
228 }
229 
230 static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
231 {
232   SNES_Patch    *patch = (SNES_Patch *) snes->data;
233   const char    *prefix;
234 
235   PetscFunctionBegin;
236   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
237   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
238   PetscCall(PCSetFromOptions(patch->pc));
239   PetscFunctionReturn(0);
240 }
241 
242 static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
243 {
244   SNES_Patch    *patch = (SNES_Patch *) snes->data;
245   PetscBool      iascii;
246 
247   PetscFunctionBegin;
248   PetscCall(PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii));
249   if (iascii) {
250     PetscCall(PetscViewerASCIIPrintf(viewer,"SNESPATCH\n"));
251   }
252   PetscCall(PetscViewerASCIIPushTab(viewer));
253   PetscCall(PCView(patch->pc, viewer));
254   PetscCall(PetscViewerASCIIPopTab(viewer));
255   PetscFunctionReturn(0);
256 }
257 
258 static PetscErrorCode SNESSolve_Patch(SNES snes)
259 {
260   SNES_Patch        *patch = (SNES_Patch *) snes->data;
261   PC_PATCH          *pcpatch = (PC_PATCH *) patch->pc->data;
262   SNESLineSearch    ls;
263   Vec               rhs, update, state, residual;
264   const PetscScalar *globalState  = NULL;
265   PetscScalar       *localState   = NULL;
266   PetscInt          its = 0;
267   PetscReal         xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
268 
269   PetscFunctionBegin;
270   PetscCall(SNESGetSolution(snes, &state));
271   PetscCall(SNESGetSolutionUpdate(snes, &update));
272   PetscCall(SNESGetRhs(snes, &rhs));
273 
274   PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
275   PetscCall(SNESGetLineSearch(snes, &ls));
276 
277   PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
278   PetscCall(VecSet(update, 0.0));
279   PetscCall(SNESComputeFunction(snes, state, residual));
280 
281   PetscCall(VecNorm(state, NORM_2, &xnorm));
282   PetscCall(VecNorm(residual, NORM_2, &fnorm));
283   snes->ttol = fnorm*snes->rtol;
284 
285   if (snes->ops->converged) {
286     PetscCall((*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP));
287   } else {
288     PetscCall(SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL));
289   }
290   PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
291   PetscCall(SNESMonitor(snes, its, fnorm));
292 
293   /* The main solver loop */
294   for (its = 0; its < snes->max_its; its++) {
295 
296     PetscCall(SNESSetIterationNumber(snes, its));
297 
298     /* Scatter state vector to overlapped vector on all patches.
299        The vector pcpatch->localState is scattered to each patch
300        in PCApply_PATCH_Nonlinear. */
301     PetscCall(VecGetArrayRead(state, &globalState));
302     PetscCall(VecGetArray(pcpatch->localState, &localState));
303     PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE));
304     PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE));
305     PetscCall(VecRestoreArray(pcpatch->localState, &localState));
306     PetscCall(VecRestoreArrayRead(state, &globalState));
307 
308     /* The looping over patches happens here */
309     PetscCall(PCApply(patch->pc, rhs, update));
310 
311     /* Apply a line search. This will often be basic with
312        damping = 1/(max number of patches a dof can be in),
313        but not always */
314     PetscCall(VecScale(update, -1.0));
315     PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
316 
317     PetscCall(VecNorm(state, NORM_2, &xnorm));
318     PetscCall(VecNorm(update, NORM_2, &ynorm));
319 
320     if (snes->ops->converged) {
321       PetscCall((*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP));
322     } else {
323       PetscCall(SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL));
324     }
325     PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
326     PetscCall(SNESMonitor(snes, its, fnorm));
327   }
328 
329   if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
330   PetscFunctionReturn(0);
331 }
332 
333 /*MC
334   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches
335 
336   Level: intermediate
337 
338 .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
339            PCPATCH
340 
341    References:
342 .  * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
343 
344 M*/
345 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
346 {
347   SNES_Patch     *patch;
348   PC_PATCH       *patchpc;
349   SNESLineSearch linesearch;
350 
351   PetscFunctionBegin;
352   PetscCall(PetscNewLog(snes, &patch));
353 
354   snes->ops->solve          = SNESSolve_Patch;
355   snes->ops->setup          = SNESSetUp_Patch;
356   snes->ops->reset          = SNESReset_Patch;
357   snes->ops->destroy        = SNESDestroy_Patch;
358   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
359   snes->ops->view           = SNESView_Patch;
360 
361   PetscCall(SNESGetLineSearch(snes,&linesearch));
362   if (!((PetscObject)linesearch)->type_name) {
363     PetscCall(SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC));
364   }
365   snes->usesksp        = PETSC_FALSE;
366 
367   snes->alwayscomputesfinalresidual = PETSC_FALSE;
368 
369   snes->data = (void *) patch;
370   PetscCall(PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc));
371   PetscCall(PCSetType(patch->pc, PCPATCH));
372 
373   patchpc = (PC_PATCH*) patch->pc->data;
374   patchpc->classname = "snes";
375   patchpc->isNonlinear = PETSC_TRUE;
376 
377   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
378   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
379   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
380   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
381   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
382 
383   PetscFunctionReturn(0);
384 }
385 
386 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
387                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
388 {
389   SNES_Patch     *patch = (SNES_Patch *) snes->data;
390   DM             dm;
391 
392   PetscFunctionBegin;
393   PetscCall(SNESGetDM(snes, &dm));
394   PetscCheck(dm,PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
395   PetscCall(PCSetDM(patch->pc, dm));
396   PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
397   PetscFunctionReturn(0);
398 }
399 
400 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
401 {
402   SNES_Patch    *patch = (SNES_Patch *) snes->data;
403 
404   PetscFunctionBegin;
405   PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
406   PetscFunctionReturn(0);
407 }
408 
409 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
410 {
411   SNES_Patch    *patch = (SNES_Patch *) snes->data;
412 
413   PetscFunctionBegin;
414   PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
415   PetscFunctionReturn(0);
416 }
417 
418 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
419 {
420   SNES_Patch    *patch = (SNES_Patch *) snes->data;
421 
422   PetscFunctionBegin;
423   PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
424   PetscFunctionReturn(0);
425 }
426 
427 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
428 {
429   SNES_Patch    *patch = (SNES_Patch *) snes->data;
430 
431   PetscFunctionBegin;
432   PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
433   PetscFunctionReturn(0);
434 }
435