1 /* 2 Defines a SNES that can consist of a collection of SNESes on patches of the domain 3 */ 4 #include <petsc/private/vecimpl.h> /* For vec->map */ 5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/ 6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */ 7 #include <petscsf.h> 8 #include <petscsection.h> 9 10 typedef struct { 11 PC pc; /* The linear patch preconditioner */ 12 } SNES_Patch; 13 14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx) 15 { 16 PC pc = (PC)ctx; 17 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 18 PetscInt pt, size, i; 19 const PetscInt *indices; 20 const PetscScalar *X; 21 PetscScalar *XWithAll; 22 23 PetscFunctionBegin; 24 25 /* scatter from x to patch->patchStateWithAll[pt] */ 26 pt = pcpatch->currentPatch; 27 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 28 29 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 30 PetscCall(VecGetArrayRead(x, &X)); 31 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 32 33 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i]; 34 35 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 36 PetscCall(VecRestoreArrayRead(x, &X)); 37 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 38 39 PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt)); 40 PetscFunctionReturn(0); 41 } 42 43 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx) 44 { 45 PC pc = (PC)ctx; 46 PC_PATCH *pcpatch = (PC_PATCH *)pc->data; 47 PetscInt pt, size, i; 48 const PetscInt *indices; 49 const PetscScalar *X; 50 PetscScalar *XWithAll; 51 52 PetscFunctionBegin; 53 /* scatter from x to patch->patchStateWithAll[pt] */ 54 pt = pcpatch->currentPatch; 55 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 56 57 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 58 PetscCall(VecGetArrayRead(x, &X)); 59 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 60 61 for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i]; 62 63 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 64 PetscCall(VecRestoreArrayRead(x, &X)); 65 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 66 67 PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE)); 68 PetscFunctionReturn(0); 69 } 70 71 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc) 72 { 73 PC_PATCH *patch = (PC_PATCH *)pc->data; 74 const char *prefix; 75 PetscInt i, pStart, dof, maxDof = -1; 76 77 PetscFunctionBegin; 78 if (!pc->setupcalled) { 79 PetscCall(PetscMalloc1(patch->npatch, &patch->solver)); 80 PetscCall(PCGetOptionsPrefix(pc, &prefix)); 81 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 82 for (i = 0; i < patch->npatch; ++i) { 83 SNES snes; 84 85 PetscCall(SNESCreate(PETSC_COMM_SELF, &snes)); 86 PetscCall(SNESSetOptionsPrefix(snes, prefix)); 87 PetscCall(SNESAppendOptionsPrefix(snes, "sub_")); 88 PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2)); 89 patch->solver[i] = (PetscObject)snes; 90 91 PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof)); 92 maxDof = PetscMax(maxDof, dof); 93 } 94 PetscCall(VecDuplicate(patch->localUpdate, &patch->localState)); 95 PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual)); 96 PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState)); 97 98 PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll)); 99 PetscCall(VecSetUp(patch->patchStateWithAll)); 100 } 101 for (i = 0; i < patch->npatch; ++i) { 102 SNES snes = (SNES)patch->solver[i]; 103 104 PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc)); 105 PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc)); 106 } 107 if (!pc->setupcalled && patch->optionsSet) 108 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i])); 109 PetscFunctionReturn(0); 110 } 111 112 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate) 113 { 114 PC_PATCH *patch = (PC_PATCH *)pc->data; 115 PetscInt pStart, n; 116 117 PetscFunctionBegin; 118 patch->currentPatch = i; 119 PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0)); 120 121 /* Scatter the overlapped global state to our patch state vector */ 122 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 123 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR)); 124 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL)); 125 126 PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n)); 127 patch->patchState->map->n = n; 128 patch->patchState->map->N = n; 129 patchUpdate->map->n = n; 130 patchUpdate->map->N = n; 131 patchRHS->map->n = n; 132 patchRHS->map->N = n; 133 /* Set initial guess to be current state*/ 134 PetscCall(VecCopy(patch->patchState, patchUpdate)); 135 /* Solve for new state */ 136 PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate)); 137 /* To compute update, subtract off previous state */ 138 PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState)); 139 140 PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0)); 141 PetscFunctionReturn(0); 142 } 143 144 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc) 145 { 146 PC_PATCH *patch = (PC_PATCH *)pc->data; 147 PetscInt i; 148 149 PetscFunctionBegin; 150 if (patch->solver) { 151 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i])); 152 } 153 154 PetscCall(VecDestroy(&patch->patchResidual)); 155 PetscCall(VecDestroy(&patch->patchState)); 156 PetscCall(VecDestroy(&patch->patchStateWithAll)); 157 158 PetscCall(VecDestroy(&patch->localState)); 159 PetscFunctionReturn(0); 160 } 161 162 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc) 163 { 164 PC_PATCH *patch = (PC_PATCH *)pc->data; 165 PetscInt i; 166 167 PetscFunctionBegin; 168 if (patch->solver) { 169 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i])); 170 PetscCall(PetscFree(patch->solver)); 171 } 172 PetscFunctionReturn(0); 173 } 174 175 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart) 176 { 177 PC_PATCH *patch = (PC_PATCH *)pc->data; 178 179 PetscFunctionBegin; 180 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR)); 181 PetscFunctionReturn(0); 182 } 183 184 static PetscErrorCode SNESSetUp_Patch(SNES snes) 185 { 186 SNES_Patch *patch = (SNES_Patch *)snes->data; 187 DM dm; 188 Mat dummy; 189 Vec F; 190 PetscInt n, N; 191 192 PetscFunctionBegin; 193 PetscCall(SNESGetDM(snes, &dm)); 194 PetscCall(PCSetDM(patch->pc, dm)); 195 PetscCall(SNESGetFunction(snes, &F, NULL, NULL)); 196 PetscCall(VecGetLocalSize(F, &n)); 197 PetscCall(VecGetSize(F, &N)); 198 PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy)); 199 PetscCall(PCSetOperators(patch->pc, dummy, dummy)); 200 PetscCall(MatDestroy(&dummy)); 201 PetscCall(PCSetUp(patch->pc)); 202 /* allocate workspace */ 203 PetscFunctionReturn(0); 204 } 205 206 static PetscErrorCode SNESReset_Patch(SNES snes) 207 { 208 SNES_Patch *patch = (SNES_Patch *)snes->data; 209 210 PetscFunctionBegin; 211 PetscCall(PCReset(patch->pc)); 212 PetscFunctionReturn(0); 213 } 214 215 static PetscErrorCode SNESDestroy_Patch(SNES snes) 216 { 217 SNES_Patch *patch = (SNES_Patch *)snes->data; 218 219 PetscFunctionBegin; 220 PetscCall(SNESReset_Patch(snes)); 221 PetscCall(PCDestroy(&patch->pc)); 222 PetscCall(PetscFree(snes->data)); 223 PetscFunctionReturn(0); 224 } 225 226 static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject) 227 { 228 SNES_Patch *patch = (SNES_Patch *)snes->data; 229 const char *prefix; 230 231 PetscFunctionBegin; 232 PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix)); 233 PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix)); 234 PetscCall(PCSetFromOptions(patch->pc)); 235 PetscFunctionReturn(0); 236 } 237 238 static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer) 239 { 240 SNES_Patch *patch = (SNES_Patch *)snes->data; 241 PetscBool iascii; 242 243 PetscFunctionBegin; 244 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii)); 245 if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n")); 246 PetscCall(PetscViewerASCIIPushTab(viewer)); 247 PetscCall(PCView(patch->pc, viewer)); 248 PetscCall(PetscViewerASCIIPopTab(viewer)); 249 PetscFunctionReturn(0); 250 } 251 252 static PetscErrorCode SNESSolve_Patch(SNES snes) 253 { 254 SNES_Patch *patch = (SNES_Patch *)snes->data; 255 PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data; 256 SNESLineSearch ls; 257 Vec rhs, update, state, residual; 258 const PetscScalar *globalState = NULL; 259 PetscScalar *localState = NULL; 260 PetscInt its = 0; 261 PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0; 262 263 PetscFunctionBegin; 264 PetscCall(SNESGetSolution(snes, &state)); 265 PetscCall(SNESGetSolutionUpdate(snes, &update)); 266 PetscCall(SNESGetRhs(snes, &rhs)); 267 268 PetscCall(SNESGetFunction(snes, &residual, NULL, NULL)); 269 PetscCall(SNESGetLineSearch(snes, &ls)); 270 271 PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING)); 272 PetscCall(VecSet(update, 0.0)); 273 PetscCall(SNESComputeFunction(snes, state, residual)); 274 275 PetscCall(VecNorm(state, NORM_2, &xnorm)); 276 PetscCall(VecNorm(residual, NORM_2, &fnorm)); 277 snes->ttol = fnorm * snes->rtol; 278 279 if (snes->ops->converged) { 280 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 281 } else { 282 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 283 } 284 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */ 285 PetscCall(SNESMonitor(snes, its, fnorm)); 286 287 /* The main solver loop */ 288 for (its = 0; its < snes->max_its; its++) { 289 PetscCall(SNESSetIterationNumber(snes, its)); 290 291 /* Scatter state vector to overlapped vector on all patches. 292 The vector pcpatch->localState is scattered to each patch 293 in PCApply_PATCH_Nonlinear. */ 294 PetscCall(VecGetArrayRead(state, &globalState)); 295 PetscCall(VecGetArray(pcpatch->localState, &localState)); 296 PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 297 PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE)); 298 PetscCall(VecRestoreArray(pcpatch->localState, &localState)); 299 PetscCall(VecRestoreArrayRead(state, &globalState)); 300 301 /* The looping over patches happens here */ 302 PetscCall(PCApply(patch->pc, rhs, update)); 303 304 /* Apply a line search. This will often be basic with 305 damping = 1/(max number of patches a dof can be in), 306 but not always */ 307 PetscCall(VecScale(update, -1.0)); 308 PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update)); 309 310 PetscCall(VecNorm(state, NORM_2, &xnorm)); 311 PetscCall(VecNorm(update, NORM_2, &ynorm)); 312 313 if (snes->ops->converged) { 314 PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP); 315 } else { 316 PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL)); 317 } 318 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */ 319 PetscCall(SNESMonitor(snes, its, fnorm)); 320 } 321 322 if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT)); 323 PetscFunctionReturn(0); 324 } 325 326 /*MC 327 SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches 328 329 Level: intermediate 330 331 References: 332 . * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015 333 334 .seealso: `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH` 335 M*/ 336 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes) 337 { 338 SNES_Patch *patch; 339 PC_PATCH *patchpc; 340 SNESLineSearch linesearch; 341 342 PetscFunctionBegin; 343 PetscCall(PetscNew(&patch)); 344 345 snes->ops->solve = SNESSolve_Patch; 346 snes->ops->setup = SNESSetUp_Patch; 347 snes->ops->reset = SNESReset_Patch; 348 snes->ops->destroy = SNESDestroy_Patch; 349 snes->ops->setfromoptions = SNESSetFromOptions_Patch; 350 snes->ops->view = SNESView_Patch; 351 352 PetscCall(SNESGetLineSearch(snes, &linesearch)); 353 if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC)); 354 snes->usesksp = PETSC_FALSE; 355 356 snes->alwayscomputesfinalresidual = PETSC_FALSE; 357 358 snes->data = (void *)patch; 359 PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc)); 360 PetscCall(PCSetType(patch->pc, PCPATCH)); 361 362 patchpc = (PC_PATCH *)patch->pc->data; 363 patchpc->classname = "snes"; 364 patchpc->isNonlinear = PETSC_TRUE; 365 366 patchpc->setupsolver = PCSetUp_PATCH_Nonlinear; 367 patchpc->applysolver = PCApply_PATCH_Nonlinear; 368 patchpc->resetsolver = PCReset_PATCH_Nonlinear; 369 patchpc->destroysolver = PCDestroy_PATCH_Nonlinear; 370 patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear; 371 372 PetscFunctionReturn(0); 373 } 374 375 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes) 376 { 377 SNES_Patch *patch = (SNES_Patch *)snes->data; 378 DM dm; 379 380 PetscFunctionBegin; 381 PetscCall(SNESGetDM(snes, &dm)); 382 PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES"); 383 PetscCall(PCSetDM(patch->pc, dm)); 384 PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes)); 385 PetscFunctionReturn(0); 386 } 387 388 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 389 { 390 SNES_Patch *patch = (SNES_Patch *)snes->data; 391 392 PetscFunctionBegin; 393 PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx)); 394 PetscFunctionReturn(0); 395 } 396 397 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 398 { 399 SNES_Patch *patch = (SNES_Patch *)snes->data; 400 401 PetscFunctionBegin; 402 PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx)); 403 PetscFunctionReturn(0); 404 } 405 406 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx) 407 { 408 SNES_Patch *patch = (SNES_Patch *)snes->data; 409 410 PetscFunctionBegin; 411 PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx)); 412 PetscFunctionReturn(0); 413 } 414 415 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering) 416 { 417 SNES_Patch *patch = (SNES_Patch *)snes->data; 418 419 PetscFunctionBegin; 420 PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering)); 421 PetscFunctionReturn(0); 422 } 423