1 /* 2 Defines a SNES that can consist of a collection of SNESes on patches of the domain 3 */ 4 #include <petsc/private/vecimpl.h> /* For vec->map */ 5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/ 6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */ 7 #include <petscsf.h> 8 #include <petscsection.h> 9 10 typedef struct { 11 PC pc; /* The linear patch preconditioner */ 12 } SNES_Patch; 13 14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx) 15 { 16 PC pc = (PC)ctx; 17 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 18 PetscInt pt, size, i; 19 const PetscInt *indices; 20 const PetscScalar *X; 21 PetscScalar *XWithAll; 22 23 PetscFunctionBegin; 24 /* scatter from x to patch->patchStateWithAll[pt] */ 25 pt = pcpatch->currentPatch; 26 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 27 28 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 29 PetscCall(VecGetArrayRead(x, &X)); 30 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 31 32 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i]; 33 34 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 35 PetscCall(VecRestoreArrayRead(x, &X)); 36 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 37 38 PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt)); 39 PetscFunctionReturn(PETSC_SUCCESS); 40 } 41 42 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx) 43 { 44 PC pc = (PC)ctx; 45 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 46 PetscInt pt, size, i; 47 const PetscInt *indices; 48 const PetscScalar *X; 49 PetscScalar *XWithAll; 50 51 PetscFunctionBegin; 52 /* scatter from x to patch->patchStateWithAll[pt] */ 53 pt = pcpatch->currentPatch; 54 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 55 56 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 57 PetscCall(VecGetArrayRead(x, &X)); 58 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 59 60 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i]; 61 62 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 63 PetscCall(VecRestoreArrayRead(x, &X)); 64 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 65 66 PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE)); 67 PetscFunctionReturn(PETSC_SUCCESS); 68 } 69 70 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc) 71 { 72 PC_PATCH *patch = (PC_PATCH *)pc->data; 73 const char *prefix; 74 PetscInt i, pStart, dof, maxDof = -1; 75 76 PetscFunctionBegin; 77 if (!pc->setupcalled) { 78 PetscCall(PetscMalloc1(patch->npatch, &patch->solver)); 79 PetscCall(PCGetOptionsPrefix(pc, &prefix)); 80 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 81 for (i = 0; i < patch->npatch; ++i) { 82 SNES snes; 83 84 PetscCall(SNESCreate(PETSC_COMM_SELF, &snes)); 85 PetscCall(SNESSetOptionsPrefix(snes, prefix)); 86 PetscCall(SNESAppendOptionsPrefix(snes, "sub_")); 87 PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2)); 88 patch->solver[i] = (PetscObject)snes; 89 90 PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof)); 91 maxDof = PetscMax(maxDof, dof); 92 } 93 PetscCall(VecDuplicate(patch->localUpdate, &patch->localState)); 94 PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual)); 95 PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState)); 96 97 PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll)); 98 PetscCall(VecSetUp(patch->patchStateWithAll)); 99 } 100 for (i = 0; i < patch->npatch; ++i) { 101 SNES snes = (SNES)patch->solver[i]; 102 103 PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc)); 104 PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc)); 105 } 106 if (!pc->setupcalled && patch->optionsSet) 107 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i])); 108 PetscFunctionReturn(PETSC_SUCCESS); 109 } 110 111 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate) 112 { 113 PC_PATCH *patch = (PC_PATCH *)pc->data; 114 PetscInt pStart, n; 115 116 PetscFunctionBegin; 117 patch->currentPatch = i; 118 PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0)); 119 120 /* Scatter the overlapped global state to our patch state vector */ 121 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 122 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR)); 123 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL)); 124 125 PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n)); 126 patch->patchState->map->n = n; 127 patch->patchState->map->N = n; 128 patchUpdate->map->n = n; 129 patchUpdate->map->N = n; 130 patchRHS->map->n = n; 131 patchRHS->map->N = n; 132 /* Set initial guess to be current state*/ 133 PetscCall(VecCopy(patch->patchState, patchUpdate)); 134 /* Solve for new state */ 135 PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate)); 136 /* To compute update, subtract off previous state */ 137 PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState)); 138 139 PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0)); 140 PetscFunctionReturn(PETSC_SUCCESS); 141 } 142 143 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc) 144 { 145 PC_PATCH *patch = (PC_PATCH *)pc->data; 146 PetscInt i; 147 148 PetscFunctionBegin; 149 if (patch->solver) { 150 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i])); 151 } 152 153 PetscCall(VecDestroy(&patch->patchResidual)); 154 PetscCall(VecDestroy(&patch->patchState)); 155 PetscCall(VecDestroy(&patch->patchStateWithAll)); 156 157 PetscCall(VecDestroy(&patch->localState)); 158 PetscFunctionReturn(PETSC_SUCCESS); 159 } 160 161 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc) 162 { 163 PC_PATCH *patch = (PC_PATCH *)pc->data; 164 PetscInt i; 165 166 PetscFunctionBegin; 167 if (patch->solver) { 168 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i])); 169 PetscCall(PetscFree(patch->solver)); 170 } 171 PetscFunctionReturn(PETSC_SUCCESS); 172 } 173 174 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart) 175 { 176 PC_PATCH *patch = (PC_PATCH *)pc->data; 177 178 PetscFunctionBegin; 179 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR)); 180 PetscFunctionReturn(PETSC_SUCCESS); 181 } 182 183 static PetscErrorCode SNESSetUp_Patch(SNES snes) 184 { 185 SNES_Patch *patch = (SNES_Patch *)snes->data; 186 DM dm; 187 Mat dummy; 188 Vec F; 189 PetscInt n, N; 190 191 PetscFunctionBegin; 192 PetscCall(SNESGetDM(snes, &dm)); 193 PetscCall(PCSetDM(patch->pc, dm)); 194 PetscCall(SNESGetFunction(snes, &F, NULL, NULL)); 195 PetscCall(VecGetLocalSize(F, &n)); 196 PetscCall(VecGetSize(F, &N)); 197 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy)); 198 PetscCall(PCSetOperators(patch->pc, dummy, dummy)); 199 PetscCall(MatDestroy(&dummy)); 200 PetscCall(PCSetUp(patch->pc)); 201 /* allocate workspace */ 202 PetscFunctionReturn(PETSC_SUCCESS); 203 } 204 205 static PetscErrorCode SNESReset_Patch(SNES snes) 206 { 207 SNES_Patch *patch = (SNES_Patch *)snes->data; 208 209 PetscFunctionBegin; 210 PetscCall(PCReset(patch->pc)); 211 PetscFunctionReturn(PETSC_SUCCESS); 212 } 213 214 static PetscErrorCode SNESDestroy_Patch(SNES snes) 215 { 216 SNES_Patch *patch = (SNES_Patch *)snes->data; 217 218 PetscFunctionBegin; 219 PetscCall(SNESReset_Patch(snes)); 220 PetscCall(PCDestroy(&patch->pc)); 221 PetscCall(PetscFree(snes->data)); 222 PetscFunctionReturn(PETSC_SUCCESS); 223 } 224 225 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems PetscOptionsObject) 226 { 227 SNES_Patch *patch = (SNES_Patch *)snes->data; 228 const char *prefix; 229 230 PetscFunctionBegin; 231 PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix)); 232 PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix)); 233 PetscCall(PCSetFromOptions(patch->pc)); 234 PetscFunctionReturn(PETSC_SUCCESS); 235 } 236 237 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer) 238 { 239 SNES_Patch *patch = (SNES_Patch *)snes->data; 240 PetscBool isascii; 241 242 PetscFunctionBegin; 243 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 244 if (isascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n")); 245 PetscCall(PetscViewerASCIIPushTab(viewer)); 246 PetscCall(PCView(patch->pc, viewer)); 247 PetscCall(PetscViewerASCIIPopTab(viewer)); 248 PetscFunctionReturn(PETSC_SUCCESS); 249 } 250 251 static PetscErrorCode SNESSolve_Patch(SNES snes) 252 { 253 SNES_Patch *patch = (SNES_Patch *)snes->data; 254 PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data; 255 SNESLineSearch ls; 256 Vec rhs, update, state, residual; 257 const PetscScalar *globalState = NULL; 258 PetscScalar *localState = NULL; 259 PetscInt its = 0; 260 PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0; 261 262 PetscFunctionBegin; 263 PetscCall(SNESGetSolution(snes, &state)); 264 PetscCall(SNESGetSolutionUpdate(snes, &update)); 265 PetscCall(SNESGetRhs(snes, &rhs)); 266 267 PetscCall(SNESGetFunction(snes, &residual, NULL, NULL)); 268 PetscCall(SNESGetLineSearch(snes, &ls)); 269 270 PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING)); 271 PetscCall(VecSet(update, 0.0)); 272 PetscCall(SNESComputeFunction(snes, state, residual)); 273 274 PetscCall(VecNorm(state, NORM_2, &xnorm)); 275 PetscCall(VecNorm(residual, NORM_2, &fnorm)); 276 snes->ttol = fnorm * snes->rtol; 277 278 if (snes->ops->converged) { 279 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 280 } else { 281 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 282 } 283 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */ 284 PetscCall(SNESMonitor(snes, its, fnorm)); 285 286 /* The main solver loop */ 287 for (its = 0; its < snes->max_its; its++) { 288 PetscCall(SNESSetIterationNumber(snes, its)); 289 290 /* Scatter state vector to overlapped vector on all patches. 291 The vector pcpatch->localState is scattered to each patch 292 in PCApply_PATCH_Nonlinear. */ 293 PetscCall(VecGetArrayRead(state, &globalState)); 294 PetscCall(VecGetArray(pcpatch->localState, &localState)); 295 PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 296 PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 297 PetscCall(VecRestoreArray(pcpatch->localState, &localState)); 298 PetscCall(VecRestoreArrayRead(state, &globalState)); 299 300 /* The looping over patches happens here */ 301 PetscCall(PCApply(patch->pc, rhs, update)); 302 303 /* Apply a line search. This will often be basic with 304 damping = 1/(max number of patches a dof can be in), 305 but not always */ 306 PetscCall(VecScale(update, -1.0)); 307 PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update)); 308 309 PetscCall(VecNorm(state, NORM_2, &xnorm)); 310 PetscCall(VecNorm(update, NORM_2, &ynorm)); 311 312 if (snes->ops->converged) { 313 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 314 } else { 315 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 316 } 317 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */ 318 PetscCall(SNESMonitor(snes, its, fnorm)); 319 } 320 321 if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT)); 322 PetscFunctionReturn(PETSC_SUCCESS); 323 } 324 325 /*MC 326 SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches {cite}`bruneknepleysmithtu15` 327 328 Level: intermediate 329 330 .seealso: [](ch_snes), `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH` 331 M*/ 332 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes) 333 { 334 SNES_Patch *patch; 335 PC_PATCH *patchpc; 336 SNESLineSearch linesearch; 337 338 PetscFunctionBegin; 339 PetscCall(PetscNew(&patch)); 340 341 snes->ops->solve = SNESSolve_Patch; 342 snes->ops->setup = SNESSetUp_Patch; 343 snes->ops->reset = SNESReset_Patch; 344 snes->ops->destroy = SNESDestroy_Patch; 345 snes->ops->setfromoptions = SNESSetFromOptions_Patch; 346 snes->ops->view = SNESView_Patch; 347 348 PetscCall(SNESGetLineSearch(snes, &linesearch)); 349 if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC)); 350 snes->usesksp = PETSC_FALSE; 351 352 snes->alwayscomputesfinalresidual = PETSC_FALSE; 353 354 PetscCall(SNESParametersInitialize(snes)); 355 356 snes->data = (void *)patch; 357 PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc)); 358 PetscCall(PCSetType(patch->pc, PCPATCH)); 359 360 patchpc = (PC_PATCH *)patch->pc->data; 361 patchpc->classname = "snes"; 362 patchpc->isNonlinear = PETSC_TRUE; 363 364 patchpc->setupsolver = PCSetUp_PATCH_Nonlinear; 365 patchpc->applysolver = PCApply_PATCH_Nonlinear; 366 patchpc->resetsolver = PCReset_PATCH_Nonlinear; 367 patchpc->destroysolver = PCDestroy_PATCH_Nonlinear; 368 patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear; 369 PetscFunctionReturn(PETSC_SUCCESS); 370 } 371 372 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes) 373 { 374 SNES_Patch *patch = (SNES_Patch *)snes->data; 375 DM dm; 376 377 PetscFunctionBegin; 378 PetscCall(SNESGetDM(snes, &dm)); 379 PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES"); 380 PetscCall(PCSetDM(patch->pc, dm)); 381 PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes)); 382 PetscFunctionReturn(PETSC_SUCCESS); 383 } 384 385 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 386 { 387 SNES_Patch *patch = (SNES_Patch *)snes->data; 388 389 PetscFunctionBegin; 390 PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx)); 391 PetscFunctionReturn(PETSC_SUCCESS); 392 } 393 394 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 395 { 396 SNES_Patch *patch = (SNES_Patch *)snes->data; 397 398 PetscFunctionBegin; 399 PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx)); 400 PetscFunctionReturn(PETSC_SUCCESS); 401 } 402 403 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx) 404 { 405 SNES_Patch *patch = (SNES_Patch *)snes->data; 406 407 PetscFunctionBegin; 408 PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx)); 409 PetscFunctionReturn(PETSC_SUCCESS); 410 } 411 412 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering) 413 { 414 SNES_Patch *patch = (SNES_Patch *)snes->data; 415 416 PetscFunctionBegin; 417 PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering)); 418 PetscFunctionReturn(PETSC_SUCCESS); 419 } 420