1 /*
2 Defines a SNES that can consist of a collection of SNESes on patches of the domain
3 */
4 #include <petsc/private/vecimpl.h> /* For vec->map */
5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/
6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
7 #include <petscsf.h>
8 #include <petscsection.h>
9
10 typedef struct {
11 PC pc; /* The linear patch preconditioner */
12 } SNES_Patch;
13
SNESPatchComputeResidual_Private(SNES snes,Vec x,Vec F,PetscCtx ctx)14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, PetscCtx ctx)
15 {
16 PC pc = (PC)ctx;
17 PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
18 PetscInt pt, size, i;
19 const PetscInt *indices;
20 const PetscScalar *X;
21 PetscScalar *XWithAll;
22
23 PetscFunctionBegin;
24 /* scatter from x to patch->patchStateWithAll[pt] */
25 pt = pcpatch->currentPatch;
26 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
27
28 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
29 PetscCall(VecGetArrayRead(x, &X));
30 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
31
32 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
33
34 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
35 PetscCall(VecRestoreArrayRead(x, &X));
36 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
37
38 PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
39 PetscFunctionReturn(PETSC_SUCCESS);
40 }
41
SNESPatchComputeJacobian_Private(SNES snes,Vec x,Mat J,Mat M,PetscCtx ctx)42 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, PetscCtx ctx)
43 {
44 PC pc = (PC)ctx;
45 PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
46 PetscInt pt, size, i;
47 const PetscInt *indices;
48 const PetscScalar *X;
49 PetscScalar *XWithAll;
50
51 PetscFunctionBegin;
52 /* scatter from x to patch->patchStateWithAll[pt] */
53 pt = pcpatch->currentPatch;
54 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
55
56 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
57 PetscCall(VecGetArrayRead(x, &X));
58 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
59
60 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
61
62 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
63 PetscCall(VecRestoreArrayRead(x, &X));
64 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
65
66 PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
67 PetscFunctionReturn(PETSC_SUCCESS);
68 }
69
PCSetUp_PATCH_Nonlinear(PC pc)70 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
71 {
72 PC_PATCH *patch = (PC_PATCH *)pc->data;
73 const char *prefix;
74 PetscInt i, pStart, dof, maxDof = -1;
75
76 PetscFunctionBegin;
77 if (!pc->setupcalled) {
78 PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
79 PetscCall(PCGetOptionsPrefix(pc, &prefix));
80 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
81 for (i = 0; i < patch->npatch; ++i) {
82 SNES snes;
83
84 PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
85 PetscCall(SNESSetOptionsPrefix(snes, prefix));
86 PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
87 PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
88 patch->solver[i] = (PetscObject)snes;
89
90 PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
91 maxDof = PetscMax(maxDof, dof);
92 }
93 PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
94 PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
95 PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
96
97 PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
98 PetscCall(VecSetUp(patch->patchStateWithAll));
99 }
100 for (i = 0; i < patch->npatch; ++i) {
101 SNES snes = (SNES)patch->solver[i];
102
103 PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
104 PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
105 }
106 if (!pc->setupcalled && patch->optionsSet)
107 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
108 PetscFunctionReturn(PETSC_SUCCESS);
109 }
110
PCApply_PATCH_Nonlinear(PC pc,PetscInt i,Vec patchRHS,Vec patchUpdate)111 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
112 {
113 PC_PATCH *patch = (PC_PATCH *)pc->data;
114 PetscInt pStart, n;
115
116 PetscFunctionBegin;
117 patch->currentPatch = i;
118 PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
119
120 /* Scatter the overlapped global state to our patch state vector */
121 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
122 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
123 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
124
125 PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
126 patch->patchState->map->n = n;
127 patch->patchState->map->N = n;
128 patchUpdate->map->n = n;
129 patchUpdate->map->N = n;
130 patchRHS->map->n = n;
131 patchRHS->map->N = n;
132 /* Set initial guess to be current state*/
133 PetscCall(VecCopy(patch->patchState, patchUpdate));
134 /* Solve for new state */
135 PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
136 /* To compute update, subtract off previous state */
137 PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
138
139 PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
140 PetscFunctionReturn(PETSC_SUCCESS);
141 }
142
PCReset_PATCH_Nonlinear(PC pc)143 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
144 {
145 PC_PATCH *patch = (PC_PATCH *)pc->data;
146 PetscInt i;
147
148 PetscFunctionBegin;
149 if (patch->solver) {
150 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
151 }
152
153 PetscCall(VecDestroy(&patch->patchResidual));
154 PetscCall(VecDestroy(&patch->patchState));
155 PetscCall(VecDestroy(&patch->patchStateWithAll));
156
157 PetscCall(VecDestroy(&patch->localState));
158 PetscFunctionReturn(PETSC_SUCCESS);
159 }
160
PCDestroy_PATCH_Nonlinear(PC pc)161 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
162 {
163 PC_PATCH *patch = (PC_PATCH *)pc->data;
164 PetscInt i;
165
166 PetscFunctionBegin;
167 if (patch->solver) {
168 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
169 PetscCall(PetscFree(patch->solver));
170 }
171 PetscFunctionReturn(PETSC_SUCCESS);
172 }
173
PCUpdateMultiplicative_PATCH_Nonlinear(PC pc,PetscInt i,PetscInt pStart)174 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
175 {
176 PC_PATCH *patch = (PC_PATCH *)pc->data;
177
178 PetscFunctionBegin;
179 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
180 PetscFunctionReturn(PETSC_SUCCESS);
181 }
182
SNESSetUp_Patch(SNES snes)183 static PetscErrorCode SNESSetUp_Patch(SNES snes)
184 {
185 SNES_Patch *patch = (SNES_Patch *)snes->data;
186 DM dm;
187 Mat dummy;
188 Vec F;
189 PetscInt n, N;
190
191 PetscFunctionBegin;
192 PetscCall(SNESGetDM(snes, &dm));
193 PetscCall(PCSetDM(patch->pc, dm));
194 PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
195 PetscCall(VecGetLocalSize(F, &n));
196 PetscCall(VecGetSize(F, &N));
197 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
198 PetscCall(PCSetOperators(patch->pc, dummy, dummy));
199 PetscCall(MatDestroy(&dummy));
200 PetscCall(PCSetUp(patch->pc));
201 /* allocate workspace */
202 PetscFunctionReturn(PETSC_SUCCESS);
203 }
204
SNESReset_Patch(SNES snes)205 static PetscErrorCode SNESReset_Patch(SNES snes)
206 {
207 SNES_Patch *patch = (SNES_Patch *)snes->data;
208
209 PetscFunctionBegin;
210 PetscCall(PCReset(patch->pc));
211 PetscFunctionReturn(PETSC_SUCCESS);
212 }
213
SNESDestroy_Patch(SNES snes)214 static PetscErrorCode SNESDestroy_Patch(SNES snes)
215 {
216 SNES_Patch *patch = (SNES_Patch *)snes->data;
217
218 PetscFunctionBegin;
219 PetscCall(SNESReset_Patch(snes));
220 PetscCall(PCDestroy(&patch->pc));
221 PetscCall(PetscFree(snes->data));
222 PetscFunctionReturn(PETSC_SUCCESS);
223 }
224
SNESSetFromOptions_Patch(SNES snes,PetscOptionItems PetscOptionsObject)225 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems PetscOptionsObject)
226 {
227 SNES_Patch *patch = (SNES_Patch *)snes->data;
228 const char *prefix;
229
230 PetscFunctionBegin;
231 PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
232 PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
233 PetscCall(PCSetFromOptions(patch->pc));
234 PetscFunctionReturn(PETSC_SUCCESS);
235 }
236
SNESView_Patch(SNES snes,PetscViewer viewer)237 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
238 {
239 SNES_Patch *patch = (SNES_Patch *)snes->data;
240 PetscBool isascii;
241
242 PetscFunctionBegin;
243 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
244 if (isascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
245 PetscCall(PetscViewerASCIIPushTab(viewer));
246 PetscCall(PCView(patch->pc, viewer));
247 PetscCall(PetscViewerASCIIPopTab(viewer));
248 PetscFunctionReturn(PETSC_SUCCESS);
249 }
250
SNESSolve_Patch(SNES snes)251 static PetscErrorCode SNESSolve_Patch(SNES snes)
252 {
253 SNES_Patch *patch = (SNES_Patch *)snes->data;
254 PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data;
255 SNESLineSearch ls;
256 Vec rhs, update, state, residual;
257 const PetscScalar *globalState = NULL;
258 PetscScalar *localState = NULL;
259 PetscInt its = 0;
260 PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
261
262 PetscFunctionBegin;
263 PetscCall(SNESGetSolution(snes, &state));
264 PetscCall(SNESGetSolutionUpdate(snes, &update));
265 PetscCall(SNESGetRhs(snes, &rhs));
266
267 PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
268 PetscCall(SNESGetLineSearch(snes, &ls));
269
270 PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
271 PetscCall(VecSet(update, 0.0));
272 PetscCall(SNESComputeFunction(snes, state, residual));
273
274 PetscCall(VecNorm(state, NORM_2, &xnorm));
275 PetscCall(VecNorm(residual, NORM_2, &fnorm));
276 SNESCheckFunctionDomainError(snes, fnorm);
277 snes->ttol = fnorm * snes->rtol;
278
279 if (snes->ops->converged) {
280 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
281 } else {
282 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
283 }
284 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
285 PetscCall(SNESMonitor(snes, its, fnorm));
286
287 /* The main solver loop */
288 for (its = 0; its < snes->max_its; its++) {
289 PetscCall(SNESSetIterationNumber(snes, its));
290
291 /* Scatter state vector to overlapped vector on all patches.
292 The vector pcpatch->localState is scattered to each patch
293 in PCApply_PATCH_Nonlinear. */
294 PetscCall(VecGetArrayRead(state, &globalState));
295 PetscCall(VecGetArray(pcpatch->localState, &localState));
296 PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
297 PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
298 PetscCall(VecRestoreArray(pcpatch->localState, &localState));
299 PetscCall(VecRestoreArrayRead(state, &globalState));
300
301 /* The looping over patches happens here */
302 PetscCall(PCApply(patch->pc, rhs, update));
303
304 /* Apply a line search. This will often be basic with
305 damping = 1/(max number of patches a dof can be in),
306 but not always */
307 PetscCall(VecScale(update, -1.0));
308 PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
309
310 PetscCall(VecNorm(state, NORM_2, &xnorm));
311 PetscCall(VecNorm(update, NORM_2, &ynorm));
312
313 if (snes->ops->converged) {
314 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
315 } else {
316 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
317 }
318 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
319 PetscCall(SNESMonitor(snes, its, fnorm));
320 }
321
322 if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
323 PetscFunctionReturn(PETSC_SUCCESS);
324 }
325
326 /*MC
327 SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches {cite}`bruneknepleysmithtu15`
328
329 Level: intermediate
330
331 .seealso: [](ch_snes), `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
332 M*/
SNESCreate_Patch(SNES snes)333 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
334 {
335 SNES_Patch *patch;
336 PC_PATCH *patchpc;
337 SNESLineSearch linesearch;
338
339 PetscFunctionBegin;
340 PetscCall(PetscNew(&patch));
341
342 snes->ops->solve = SNESSolve_Patch;
343 snes->ops->setup = SNESSetUp_Patch;
344 snes->ops->reset = SNESReset_Patch;
345 snes->ops->destroy = SNESDestroy_Patch;
346 snes->ops->setfromoptions = SNESSetFromOptions_Patch;
347 snes->ops->view = SNESView_Patch;
348
349 PetscCall(SNESGetLineSearch(snes, &linesearch));
350 if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
351 snes->usesksp = PETSC_FALSE;
352
353 snes->alwayscomputesfinalresidual = PETSC_FALSE;
354
355 PetscCall(SNESParametersInitialize(snes));
356
357 snes->data = (void *)patch;
358 PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
359 PetscCall(PCSetType(patch->pc, PCPATCH));
360
361 patchpc = (PC_PATCH *)patch->pc->data;
362 patchpc->classname = "snes";
363 patchpc->isNonlinear = PETSC_TRUE;
364
365 patchpc->setupsolver = PCSetUp_PATCH_Nonlinear;
366 patchpc->applysolver = PCApply_PATCH_Nonlinear;
367 patchpc->resetsolver = PCReset_PATCH_Nonlinear;
368 patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
369 patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
370 PetscFunctionReturn(PETSC_SUCCESS);
371 }
372
SNESPatchSetDiscretisationInfo(SNES snes,PetscInt nsubspaces,DM * dms,PetscInt * bs,PetscInt * nodesPerCell,const PetscInt ** cellNodeMap,const PetscInt * subspaceOffsets,PetscInt numGhostBcs,const PetscInt * ghostBcNodes,PetscInt numGlobalBcs,const PetscInt * globalBcNodes)373 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
374 {
375 SNES_Patch *patch = (SNES_Patch *)snes->data;
376 DM dm;
377
378 PetscFunctionBegin;
379 PetscCall(SNESGetDM(snes, &dm));
380 PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
381 PetscCall(PCSetDM(patch->pc, dm));
382 PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
383 PetscFunctionReturn(PETSC_SUCCESS);
384 }
385
/*
  SNESPatchSetComputeOperator - Forwards the operator callback and its context to the inner
  PC of a SNESPATCH via PCPatchSetComputeOperator().
*/
PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
{
  PC inner = ((SNES_Patch *)snes->data)->pc;

  PetscFunctionBegin;
  PetscCall(PCPatchSetComputeOperator(inner, func, ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
394
/*
  SNESPatchSetComputeFunction - Forwards the function callback and its context to the inner
  PC of a SNESPATCH via PCPatchSetComputeFunction().
*/
PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
{
  PC inner = ((SNES_Patch *)snes->data)->pc;

  PetscFunctionBegin;
  PetscCall(PCPatchSetComputeFunction(inner, func, ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
403
/*
  SNESPatchSetConstructType - Forwards the patch construction type, the optional callback
  and its context to the inner PC of a SNESPATCH via PCPatchSetConstructType().
*/
PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), PetscCtx ctx)
{
  PC inner = ((SNES_Patch *)snes->data)->pc;

  PetscFunctionBegin;
  PetscCall(PCPatchSetConstructType(inner, ctype, func, ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
412
SNESPatchSetCellNumbering(SNES snes,PetscSection cellNumbering)413 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
414 {
415 SNES_Patch *patch = (SNES_Patch *)snes->data;
416
417 PetscFunctionBegin;
418 PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
419 PetscFunctionReturn(PETSC_SUCCESS);
420 }
421