xref: /petsc/src/snes/impls/patch/snespatch.c (revision 636c62a66e3732393be38223299603b2c6478548)
1 /*
2       Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/
5 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
6 #include <petscsf.h>
7 
8 typedef struct {
9   PC pc; /* The linear patch preconditioner */
10 } SNES_Patch;
11 
12 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
13 {
14   PC             pc      = (PC) ctx;
15   PC_PATCH      *pcpatch = (PC_PATCH *) pc->data;
16   PetscErrorCode ierr;
17 
18   PetscFunctionBegin;
19   ierr = PCPatchComputeFunction_Internal(pc, x, F, pcpatch->currentPatch);CHKERRQ(ierr);
20   PetscFunctionReturn(0);
21 }
22 
23 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
24 {
25   PC             pc      = (PC) ctx;
26   PC_PATCH      *pcpatch = (PC_PATCH *) pc->data;
27   PetscErrorCode ierr;
28 
29   PetscFunctionBegin;
30   ierr = PCPatchComputeOperator_Internal(pc, x, M, pcpatch->currentPatch, PETSC_FALSE);CHKERRQ(ierr);
31   PetscFunctionReturn(0);
32 }
33 
34 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
35 {
36   PC_PATCH      *patch = (PC_PATCH *) pc->data;
37   const char    *prefix;
38   PetscInt       i;
39   PetscErrorCode ierr;
40 
41   PetscFunctionBegin;
42   if (!pc->setupcalled) {
43     ierr = PetscMalloc1(patch->npatch, &patch->solver);CHKERRQ(ierr);
44     ierr = PCGetOptionsPrefix(pc, &prefix);CHKERRQ(ierr);
45     for (i = 0; i < patch->npatch; ++i) {
46       SNES snes;
47       KSP  subksp;
48 
49       ierr = SNESCreate(PETSC_COMM_SELF, &snes);CHKERRQ(ierr);
50       ierr = SNESSetOptionsPrefix(snes, prefix);CHKERRQ(ierr);
51       ierr = SNESAppendOptionsPrefix(snes, "sub_");CHKERRQ(ierr);
52       ierr = PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);CHKERRQ(ierr);
53       ierr = SNESGetKSP(snes, &subksp);CHKERRQ(ierr);
54       ierr = PetscObjectIncrementTabLevel((PetscObject) subksp, (PetscObject) pc, 2);CHKERRQ(ierr);
55       ierr = PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);CHKERRQ(ierr);
56       patch->solver[i] = (PetscObject) snes;
57     }
58 
59     ierr = PetscMalloc1(patch->npatch, &patch->patchResidual);CHKERRQ(ierr);
60     ierr = PetscMalloc1(patch->npatch, &patch->patchState);CHKERRQ(ierr);
61     for (i = 0; i < patch->npatch; ++i) {
62       ierr = VecDuplicate(patch->patchRHS[i], &patch->patchResidual[i]);CHKERRQ(ierr);
63       ierr = VecDuplicate(patch->patchUpdate[i], &patch->patchState[i]);CHKERRQ(ierr);
64     }
65     ierr = VecDuplicate(patch->localUpdate, &patch->localState);CHKERRQ(ierr);
66   }
67   for (i = 0; i < patch->npatch; ++i) {
68     SNES snes = (SNES) patch->solver[i];
69 
70     ierr = SNESSetFunction(snes, patch->patchResidual[i], SNESPatchComputeResidual_Private, pc);CHKERRQ(ierr);
71     ierr = SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);CHKERRQ(ierr);
72   }
73   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) {ierr = SNESSetFromOptions((SNES) patch->solver[i]);CHKERRQ(ierr);}
74   PetscFunctionReturn(0);
75 }
76 
77 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
78 {
79   PC_PATCH      *patch = (PC_PATCH *) pc->data;
80   PetscInt       pStart;
81   PetscErrorCode ierr;
82 
83   PetscFunctionBegin;
84   patch->currentPatch = i;
85   ierr = PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);CHKERRQ(ierr);
86 
87   /* Scatter the overlapped global state to our patch state vector */
88   ierr = PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);CHKERRQ(ierr);
89   ierr = PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState[i], INSERT_VALUES, SCATTER_FORWARD, PETSC_FALSE);CHKERRQ(ierr);
90 
91   /* Set initial guess to be current state*/
92   ierr = VecCopy(patch->patchState[i], patchUpdate);CHKERRQ(ierr);
93   /* Solve for new state */
94   ierr = SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);CHKERRQ(ierr);
95   /* To compute update, subtract off previous state */
96   ierr = VecAXPY(patchUpdate, -1.0, patch->patchState[i]);CHKERRQ(ierr);
97 
98   ierr = PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);CHKERRQ(ierr);
99   PetscFunctionReturn(0);
100 }
101 
102 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
103 {
104   PC_PATCH      *patch = (PC_PATCH *) pc->data;
105   PetscInt       i;
106   PetscErrorCode ierr;
107 
108   PetscFunctionBegin;
109 
110   if (patch->solver) {
111     for (i = 0; i < patch->npatch; ++i) {ierr = SNESReset((SNES) patch->solver[i]);CHKERRQ(ierr);}
112   }
113 
114   if (patch->patchResidual) {
115     for (i = 0; i < patch->npatch; ++i) {ierr = VecDestroy(&patch->patchResidual[i]);CHKERRQ(ierr);}
116     ierr = PetscFree(patch->patchResidual);CHKERRQ(ierr);
117   }
118 
119   if (patch->patchState) {
120     for (i = 0; i < patch->npatch; ++i) {ierr = VecDestroy(&patch->patchState[i]);CHKERRQ(ierr);}
121     ierr = PetscFree(patch->patchState);CHKERRQ(ierr);
122   }
123 
124   ierr = VecDestroy(&patch->localState);CHKERRQ(ierr);
125 
126   PetscFunctionReturn(0);
127 }
128 
129 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
130 {
131   PC_PATCH      *patch = (PC_PATCH *) pc->data;
132   PetscInt       i;
133   PetscErrorCode ierr;
134 
135   PetscFunctionBegin;
136   if (patch->solver) {
137     for (i = 0; i < patch->npatch; ++i) {ierr = SNESDestroy((SNES *) &patch->solver[i]);CHKERRQ(ierr);}
138     ierr = PetscFree(patch->solver);CHKERRQ(ierr);
139   }
140   PetscFunctionReturn(0);
141 }
142 
143 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
144 {
145   PC_PATCH      *patch = (PC_PATCH *) pc->data;
146   PetscErrorCode ierr;
147 
148   ierr = PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate[i], patch->localState, ADD_VALUES, SCATTER_REVERSE, PETSC_FALSE);CHKERRQ(ierr);
149 }
150 
151 static PetscErrorCode SNESSetUp_Patch(SNES snes)
152 {
153   SNES_Patch    *patch = (SNES_Patch *) snes->data;
154   DM             dm;
155   Mat            dummy;
156   Vec            F;
157   PetscInt       n, N;
158   PetscErrorCode ierr;
159 
160   PetscFunctionBegin;
161   ierr = SNESGetDM(snes, &dm);CHKERRQ(ierr);
162   ierr = PCSetDM(patch->pc, dm);CHKERRQ(ierr);
163   ierr = SNESGetFunction(snes, &F, NULL, NULL);CHKERRQ(ierr);
164   ierr = VecGetLocalSize(F, &n);CHKERRQ(ierr);
165   ierr = VecGetSize(F, &N);CHKERRQ(ierr);
166   ierr = MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);CHKERRQ(ierr);
167   ierr = PCSetOperators(patch->pc, dummy, dummy);CHKERRQ(ierr);
168   ierr = MatDestroy(&dummy);CHKERRQ(ierr);
169   ierr = PCSetUp(patch->pc);CHKERRQ(ierr);
170   /* allocate workspace */
171   PetscFunctionReturn(0);
172 }
173 
174 static PetscErrorCode SNESReset_Patch(SNES snes)
175 {
176   SNES_Patch    *patch = (SNES_Patch *) snes->data;
177   PetscErrorCode ierr;
178 
179   PetscFunctionBegin;
180   ierr = PCReset(patch->pc);CHKERRQ(ierr);
181   PetscFunctionReturn(0);
182 }
183 
184 static PetscErrorCode SNESDestroy_Patch(SNES snes)
185 {
186   SNES_Patch    *patch = (SNES_Patch *) snes->data;
187   PetscErrorCode ierr;
188 
189   PetscFunctionBegin;
190   ierr = SNESReset_Patch(snes);CHKERRQ(ierr);
191   ierr = PCDestroy(&patch->pc);CHKERRQ(ierr);
192   ierr = PetscFree(snes->data);CHKERRQ(ierr);
193   PetscFunctionReturn(0);
194 }
195 
196 static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
197 {
198   SNES_Patch    *patch = (SNES_Patch *) snes->data;
199   PetscBool      flg;
200   const char    *prefix;
201   PetscErrorCode ierr;
202 
203   PetscFunctionBegin;
204   ierr = PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);CHKERRQ(ierr);
205   ierr = PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);CHKERRQ(ierr);
206   ierr = PCSetFromOptions(patch->pc);CHKERRQ(ierr);
207   PetscFunctionReturn(0);
208 }
209 
210 static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
211 {
212   SNES_Patch    *patch = (SNES_Patch *) snes->data;
213   PetscBool      iascii;
214   PetscErrorCode ierr;
215 
216   PetscFunctionBegin;
217   ierr = PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);CHKERRQ(ierr);
218   if (iascii) {
219     ierr = PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");CHKERRQ(ierr);
220   }
221   ierr = PetscViewerASCIIPushTab(viewer);CHKERRQ(ierr);
222   ierr = PCView(patch->pc, viewer);CHKERRQ(ierr);
223   ierr = PetscViewerASCIIPopTab(viewer);CHKERRQ(ierr);
224   PetscFunctionReturn(0);
225 }
226 
227 static PetscErrorCode SNESSolve_Patch(SNES snes)
228 {
229   SNES_Patch *patch = (SNES_Patch *) snes->data;
230   PC_PATCH   *pcpatch = (PC_PATCH *) patch->pc->data;
231   SNESLineSearch ls;
232   Vec rhs, update, state, residual;
233   const PetscScalar *globalState  = NULL;
234   PetscScalar       *localState   = NULL;
235   PetscInt its = 0;
236   PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
237   PetscErrorCode ierr;
238 
239   PetscFunctionBegin;
240 
241   ierr = SNESGetSolution(snes, &state);CHKERRQ(ierr);
242   ierr = SNESGetSolutionUpdate(snes, &update);CHKERRQ(ierr);
243   ierr = SNESGetRhs(snes, &rhs);CHKERRQ(ierr);
244 
245   ierr = SNESGetFunction(snes, &residual, NULL, NULL);CHKERRQ(ierr);
246   ierr = SNESGetLineSearch(snes, &ls);CHKERRQ(ierr);
247 
248   ierr = SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);CHKERRQ(ierr);
249   ierr = VecSet(update, 0.0);CHKERRQ(ierr);
250   ierr = SNESComputeFunction(snes, state, residual);CHKERRQ(ierr);
251 
252   ierr = VecNorm(state, NORM_2, &xnorm);CHKERRQ(ierr);
253   ierr = VecNorm(residual, NORM_2, &fnorm);CHKERRQ(ierr);
254   snes->ttol = fnorm*snes->rtol;
255 
256   if (snes->ops->converged) {
257     ierr = (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);CHKERRQ(ierr);
258   } else {
259     ierr = SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);CHKERRQ(ierr);
260   }
261   ierr = SNESLogConvergenceHistory(snes, fnorm, 0);CHKERRQ(ierr); /* should we count lits from the patches? */
262   ierr = SNESMonitor(snes, its, fnorm);CHKERRQ(ierr);
263 
264   /* The main solver loop */
265   for (its = 0; its < snes->max_its; its++) {
266 
267     ierr = SNESSetIterationNumber(snes, its);CHKERRQ(ierr);
268 
269     /* Scatter state vector to overlapped vector on all patches.
270        The vector pcpatch->localState is scattered to each patch
271        in PCApply_PATCH_Nonlinear. */
272     ierr = VecGetArrayRead(state, &globalState);CHKERRQ(ierr);
273     ierr = VecGetArray(pcpatch->localState, &localState);CHKERRQ(ierr);
274     ierr = PetscSFBcastBegin(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);CHKERRQ(ierr);
275     ierr = PetscSFBcastEnd(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);CHKERRQ(ierr);
276     ierr = VecRestoreArray(pcpatch->localState, &localState);CHKERRQ(ierr);
277     ierr = VecRestoreArrayRead(state, &globalState);CHKERRQ(ierr);
278 
279     /* The looping over patches happens here */
280     ierr = PCApply(patch->pc, rhs, update);
281 
282     /* Apply a line search. This will often be basic with
283        damping = 1/(max number of patches a dof can be in),
284        but not always */
285     ierr = VecScale(update, -1.0);CHKERRQ(ierr);
286     ierr = SNESLineSearchApply(ls, state, residual, &fnorm, update);CHKERRQ(ierr);
287 
288     ierr = VecNorm(state, NORM_2, &xnorm);CHKERRQ(ierr);
289     ierr = VecNorm(update, NORM_2, &ynorm);CHKERRQ(ierr);
290 
291     if (snes->ops->converged) {
292       ierr = (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);CHKERRQ(ierr);
293     } else {
294       ierr = SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);CHKERRQ(ierr);
295     }
296     ierr = SNESLogConvergenceHistory(snes, fnorm, 0);CHKERRQ(ierr); /* FIXME: should we count lits? */
297     ierr = SNESMonitor(snes, its, fnorm);CHKERRQ(ierr);
298   }
299 
300   if (its == snes->max_its) { ierr = SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT);CHKERRQ(ierr); }
301   PetscFunctionReturn(0);
302 }
303 
304 /*MC
305   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches
306 
307   Level: intermediate
308 
309   Concepts: composing solvers
310 
311 .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
312            PCPATCH
313 
314    References:
315 .  1. - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
316 
317 M*/
318 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
319 {
320   PetscErrorCode ierr;
321   SNES_Patch    *patch;
322   PC_PATCH      *patchpc;
323 
324   PetscFunctionBegin;
325   ierr = PetscNewLog(snes, &patch);CHKERRQ(ierr);
326 
327   snes->ops->solve          = SNESSolve_Patch;
328   snes->ops->setup          = SNESSetUp_Patch;
329   snes->ops->reset          = SNESReset_Patch;
330   snes->ops->destroy        = SNESDestroy_Patch;
331   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
332   snes->ops->view           = SNESView_Patch;
333 
334   snes->alwayscomputesfinalresidual = PETSC_FALSE;
335 
336   snes->data = (void *) patch;
337   ierr = PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);CHKERRQ(ierr);
338   ierr = PCSetType(patch->pc, PCPATCH);CHKERRQ(ierr);
339 
340   patchpc = (PC_PATCH*) patch->pc->data;
341   patchpc->classname = "snes";
342 
343   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
344   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
345   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
346   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
347   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
348 
349   PetscFunctionReturn(0);
350 }
351 
352 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
353                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
354 {
355   SNES_Patch    *patch = (SNES_Patch *) snes->data;
356   PetscErrorCode ierr;
357   DM dm;
358 
359   PetscFunctionBegin;
360   ierr = SNESGetDM(snes, &dm);CHKERRQ(ierr);
361   if (!dm) SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES\n");
362   ierr = PCSetDM(patch->pc, dm);CHKERRQ(ierr);
363   ierr = PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);CHKERRQ(ierr);
364   PetscFunctionReturn(0);
365 }
366 
367 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, void *), void *ctx)
368 {
369   SNES_Patch    *patch = (SNES_Patch *) snes->data;
370   PetscErrorCode ierr;
371 
372   PetscFunctionBegin;
373   ierr = PCPatchSetComputeOperator(patch->pc, func, ctx);CHKERRQ(ierr);
374   PetscFunctionReturn(0);
375 }
376 
377 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, void *), void *ctx)
378 {
379   SNES_Patch    *patch = (SNES_Patch *) snes->data;
380   PetscErrorCode ierr;
381 
382   PetscFunctionBegin;
383   ierr = PCPatchSetComputeFunction(patch->pc, func, ctx);CHKERRQ(ierr);
384   PetscFunctionReturn(0);
385 }
386 
387 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
388 {
389   SNES_Patch    *patch = (SNES_Patch *) snes->data;
390   PetscErrorCode ierr;
391 
392   PetscFunctionBegin;
393   ierr = PCPatchSetConstructType(patch->pc, ctype, func, ctx);CHKERRQ(ierr);
394   PetscFunctionReturn(0);
395 }
396 
397 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
398 {
399   SNES_Patch    *patch = (SNES_Patch *) snes->data;
400   PetscErrorCode ierr;
401 
402   PetscFunctionBegin;
403   ierr = PCPatchSetCellNumbering(patch->pc, cellNumbering);CHKERRQ(ierr);
404   PetscFunctionReturn(0);
405 }
406