1 /* 2 Defines a SNES that can consist of a collection of SNESes on patches of the domain 3 */ 4 #include <petsc/private/vecimpl.h> /* For vec->map */ 5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/ 6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */ 7 #include <petscsf.h> 8 #include <petscsection.h> 9 10 typedef struct { 11 PC pc; /* The linear patch preconditioner */ 12 } SNES_Patch; 13 14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx) { 15 PC pc = (PC)ctx; 16 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 17 PetscInt pt, size, i; 18 const PetscInt *indices; 19 const PetscScalar *X; 20 PetscScalar *XWithAll; 21 22 PetscFunctionBegin; 23 24 /* scatter from x to patch->patchStateWithAll[pt] */ 25 pt = pcpatch->currentPatch; 26 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 27 28 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 29 PetscCall(VecGetArrayRead(x, &X)); 30 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 31 32 for (i = 0; i < size; ++i) { XWithAll[indices[i]] = X[i]; } 33 34 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 35 PetscCall(VecRestoreArrayRead(x, &X)); 36 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 37 38 PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt)); 39 PetscFunctionReturn(0); 40 } 41 42 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx) { 43 PC pc = (PC)ctx; 44 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 45 PetscInt pt, size, i; 46 const PetscInt *indices; 47 const PetscScalar *X; 48 PetscScalar *XWithAll; 49 50 PetscFunctionBegin; 51 /* scatter from x to patch->patchStateWithAll[pt] */ 52 pt = pcpatch->currentPatch; 53 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 54 55 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 56 PetscCall(VecGetArrayRead(x, &X)); 57 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 58 59 for (i = 0; i < size; ++i) { XWithAll[indices[i]] = X[i]; } 60 61 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 62 PetscCall(VecRestoreArrayRead(x, &X)); 63 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 64 65 PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE)); 66 PetscFunctionReturn(0); 67 } 68 69 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc) { 70 PC_PATCH *patch = (PC_PATCH *)pc->data; 71 const char *prefix; 72 PetscInt i, pStart, dof, maxDof = -1; 73 74 PetscFunctionBegin; 75 if (!pc->setupcalled) { 76 PetscCall(PetscMalloc1(patch->npatch, &patch->solver)); 77 PetscCall(PCGetOptionsPrefix(pc, &prefix)); 78 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 79 for (i = 0; i < patch->npatch; ++i) { 80 SNES snes; 81 82 PetscCall(SNESCreate(PETSC_COMM_SELF, &snes)); 83 PetscCall(SNESSetOptionsPrefix(snes, prefix)); 84 PetscCall(SNESAppendOptionsPrefix(snes, "sub_")); 85 PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2)); 86 PetscCall(PetscLogObjectParent((PetscObject)pc, (PetscObject)snes)); 87 patch->solver[i] = (PetscObject)snes; 88 89 PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof)); 90 maxDof = PetscMax(maxDof, dof); 91 } 92 PetscCall(VecDuplicate(patch->localUpdate, &patch->localState)); 93 PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual)); 94 PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState)); 95 96 PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll)); 97 PetscCall(VecSetUp(patch->patchStateWithAll)); 98 } 99 for (i = 0; i < patch->npatch; ++i) { 100 SNES snes = (SNES)patch->solver[i]; 101 102 PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc)); 103 PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc)); 104 } 105 if (!pc->setupcalled && patch->optionsSet) 106 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i])); 107 PetscFunctionReturn(0); 108 } 109 110 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate) { 111 PC_PATCH *patch = (PC_PATCH *)pc->data; 112 PetscInt pStart, n; 113 114 PetscFunctionBegin; 115 patch->currentPatch = i; 116 PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0)); 117 118 /* Scatter the overlapped global state to our patch state vector */ 119 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 120 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR)); 121 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL)); 122 123 PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n)); 124 patch->patchState->map->n = n; 125 patch->patchState->map->N = n; 126 patchUpdate->map->n = n; 127 patchUpdate->map->N = n; 128 patchRHS->map->n = n; 129 patchRHS->map->N = n; 130 /* Set initial guess to be current state*/ 131 PetscCall(VecCopy(patch->patchState, patchUpdate)); 132 /* Solve for new state */ 133 PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate)); 134 /* To compute update, subtract off previous state */ 135 PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState)); 136 137 PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0)); 138 PetscFunctionReturn(0); 139 } 140 141 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc) { 142 PC_PATCH *patch = (PC_PATCH *)pc->data; 143 PetscInt i; 144 145 PetscFunctionBegin; 146 if (patch->solver) { 147 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i])); 148 } 149 150 PetscCall(VecDestroy(&patch->patchResidual)); 151 PetscCall(VecDestroy(&patch->patchState)); 152 PetscCall(VecDestroy(&patch->patchStateWithAll)); 153 154 PetscCall(VecDestroy(&patch->localState)); 155 PetscFunctionReturn(0); 156 } 157 158 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc) { 159 PC_PATCH *patch = (PC_PATCH *)pc->data; 160 PetscInt i; 161 162 PetscFunctionBegin; 163 if (patch->solver) { 164 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i])); 165 PetscCall(PetscFree(patch->solver)); 166 } 167 PetscFunctionReturn(0); 168 } 169 170 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart) { 171 PC_PATCH *patch = (PC_PATCH *)pc->data; 172 173 PetscFunctionBegin; 174 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR)); 175 PetscFunctionReturn(0); 176 } 177 178 static PetscErrorCode SNESSetUp_Patch(SNES snes) { 179 SNES_Patch *patch = (SNES_Patch *)snes->data; 180 DM dm; 181 Mat dummy; 182 Vec F; 183 PetscInt n, N; 184 185 PetscFunctionBegin; 186 PetscCall(SNESGetDM(snes, &dm)); 187 PetscCall(PCSetDM(patch->pc, dm)); 188 PetscCall(SNESGetFunction(snes, &F, NULL, NULL)); 189 PetscCall(VecGetLocalSize(F, &n)); 190 PetscCall(VecGetSize(F, &N)); 191 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy)); 192 PetscCall(PCSetOperators(patch->pc, dummy, dummy)); 193 PetscCall(MatDestroy(&dummy)); 194 PetscCall(PCSetUp(patch->pc)); 195 /* allocate workspace */ 196 PetscFunctionReturn(0); 197 } 198 199 static PetscErrorCode SNESReset_Patch(SNES snes) { 200 SNES_Patch *patch = (SNES_Patch *)snes->data; 201 202 PetscFunctionBegin; 203 PetscCall(PCReset(patch->pc)); 204 PetscFunctionReturn(0); 205 } 206 207 static PetscErrorCode SNESDestroy_Patch(SNES snes) { 208 SNES_Patch *patch = (SNES_Patch *)snes->data; 209 210 PetscFunctionBegin; 211 PetscCall(SNESReset_Patch(snes)); 212 PetscCall(PCDestroy(&patch->pc)); 213 PetscCall(PetscFree(snes->data)); 214 PetscFunctionReturn(0); 215 } 216 217 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject) { 218 SNES_Patch *patch = (SNES_Patch *)snes->data; 219 const char *prefix; 220 221 PetscFunctionBegin; 222 PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix)); 223 PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix)); 224 PetscCall(PCSetFromOptions(patch->pc)); 225 PetscFunctionReturn(0); 226 } 227 228 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer) { 229 SNES_Patch *patch = (SNES_Patch *)snes->data; 230 PetscBool iascii; 231 232 PetscFunctionBegin; 233 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii)); 234 if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n")); 235 PetscCall(PetscViewerASCIIPushTab(viewer)); 236 PetscCall(PCView(patch->pc, viewer)); 237 PetscCall(PetscViewerASCIIPopTab(viewer)); 238 PetscFunctionReturn(0); 239 } 240 241 static PetscErrorCode SNESSolve_Patch(SNES snes) { 242 SNES_Patch *patch = (SNES_Patch *)snes->data; 243 PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data; 244 SNESLineSearch ls; 245 Vec rhs, update, state, residual; 246 const PetscScalar *globalState = NULL; 247 PetscScalar *localState = NULL; 248 PetscInt its = 0; 249 PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0; 250 251 PetscFunctionBegin; 252 PetscCall(SNESGetSolution(snes, &state)); 253 PetscCall(SNESGetSolutionUpdate(snes, &update)); 254 PetscCall(SNESGetRhs(snes, &rhs)); 255 256 PetscCall(SNESGetFunction(snes, &residual, NULL, NULL)); 257 PetscCall(SNESGetLineSearch(snes, &ls)); 258 259 PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING)); 260 PetscCall(VecSet(update, 0.0)); 261 PetscCall(SNESComputeFunction(snes, state, residual)); 262 263 PetscCall(VecNorm(state, NORM_2, &xnorm)); 264 PetscCall(VecNorm(residual, NORM_2, &fnorm)); 265 snes->ttol = fnorm * snes->rtol; 266 267 if (snes->ops->converged) { 268 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 269 } else { 270 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 271 } 272 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */ 273 PetscCall(SNESMonitor(snes, its, fnorm)); 274 275 /* The main solver loop */ 276 for (its = 0; its < snes->max_its; its++) { 277 PetscCall(SNESSetIterationNumber(snes, its)); 278 279 /* Scatter state vector to overlapped vector on all patches. 280 The vector pcpatch->localState is scattered to each patch 281 in PCApply_PATCH_Nonlinear. */ 282 PetscCall(VecGetArrayRead(state, &globalState)); 283 PetscCall(VecGetArray(pcpatch->localState, &localState)); 284 PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 285 PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 286 PetscCall(VecRestoreArray(pcpatch->localState, &localState)); 287 PetscCall(VecRestoreArrayRead(state, &globalState)); 288 289 /* The looping over patches happens here */ 290 PetscCall(PCApply(patch->pc, rhs, update)); 291 292 /* Apply a line search. This will often be basic with 293 damping = 1/(max number of patches a dof can be in), 294 but not always */ 295 PetscCall(VecScale(update, -1.0)); 296 PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update)); 297 298 PetscCall(VecNorm(state, NORM_2, &xnorm)); 299 PetscCall(VecNorm(update, NORM_2, &ynorm)); 300 301 if (snes->ops->converged) { 302 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 303 } else { 304 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 305 } 306 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */ 307 PetscCall(SNESMonitor(snes, its, fnorm)); 308 } 309 310 if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT)); 311 PetscFunctionReturn(0); 312 } 313 314 /*MC 315 SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches 316 317 Level: intermediate 318 319 .seealso: `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, 320 `PCPATCH` 321 322 References: 323 . * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015 324 325 M*/ 326 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes) { 327 SNES_Patch *patch; 328 PC_PATCH *patchpc; 329 SNESLineSearch linesearch; 330 331 PetscFunctionBegin; 332 PetscCall(PetscNewLog(snes, &patch)); 333 334 snes->ops->solve = SNESSolve_Patch; 335 snes->ops->setup = SNESSetUp_Patch; 336 snes->ops->reset = SNESReset_Patch; 337 snes->ops->destroy = SNESDestroy_Patch; 338 snes->ops->setfromoptions = SNESSetFromOptions_Patch; 339 snes->ops->view = SNESView_Patch; 340 341 PetscCall(SNESGetLineSearch(snes, &linesearch)); 342 if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC)); 343 snes->usesksp = PETSC_FALSE; 344 345 snes->alwayscomputesfinalresidual = PETSC_FALSE; 346 347 snes->data = (void *)patch; 348 PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc)); 349 PetscCall(PCSetType(patch->pc, PCPATCH)); 350 351 patchpc = (PC_PATCH *)patch->pc->data; 352 patchpc->classname = "snes"; 353 patchpc->isNonlinear = PETSC_TRUE; 354 355 patchpc->setupsolver = PCSetUp_PATCH_Nonlinear; 356 patchpc->applysolver = PCApply_PATCH_Nonlinear; 357 patchpc->resetsolver = PCReset_PATCH_Nonlinear; 358 patchpc->destroysolver = PCDestroy_PATCH_Nonlinear; 359 patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear; 360 361 PetscFunctionReturn(0); 362 } 363 364 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes) { 365 SNES_Patch *patch = (SNES_Patch *)snes->data; 366 DM dm; 367 368 PetscFunctionBegin; 369 PetscCall(SNESGetDM(snes, &dm)); 370 PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES"); 371 PetscCall(PCSetDM(patch->pc, dm)); 372 PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes)); 373 PetscFunctionReturn(0); 374 } 375 376 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) { 377 SNES_Patch *patch = (SNES_Patch *)snes->data; 378 379 PetscFunctionBegin; 380 PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx)); 381 PetscFunctionReturn(0); 382 } 383 384 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) { 385 SNES_Patch *patch = (SNES_Patch *)snes->data; 386 387 PetscFunctionBegin; 388 PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx)); 389 PetscFunctionReturn(0); 390 } 391 392 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx) { 393 SNES_Patch *patch = (SNES_Patch *)snes->data; 394 395 PetscFunctionBegin; 396 PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx)); 397 PetscFunctionReturn(0); 398 } 399 400 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering) { 401 SNES_Patch *patch = (SNES_Patch *)snes->data; 402 403 PetscFunctionBegin; 404 PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering)); 405 PetscFunctionReturn(0); 406 } 407