1 /* 2 Defines a SNES that can consist of a collection of SNESes on patches of the domain 3 */ 4 #include <petsc/private/vecimpl.h> /* For vec->map */ 5 #include <petsc/private/snesimpl.h> /*I "petscsnes.h" I*/ 6 #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */ 7 #include <petscsf.h> 8 #include <petscsection.h> 9 10 typedef struct { 11 PC pc; /* The linear patch preconditioner */ 12 } SNES_Patch; 13 14 static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx) 15 { 16 PC pc = (PC) ctx; 17 PC_PATCH *pcpatch = (PC_PATCH *) pc->data; 18 PetscInt pt, size, i; 19 const PetscInt *indices; 20 const PetscScalar *X; 21 PetscScalar *XWithAll; 22 23 PetscFunctionBegin; 24 25 /* scatter from x to patch->patchStateWithAll[pt] */ 26 pt = pcpatch->currentPatch; 27 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 28 29 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 30 PetscCall(VecGetArrayRead(x, &X)); 31 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 32 33 for (i = 0; i < size; ++i) { 34 XWithAll[indices[i]] = X[i]; 35 } 36 37 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 38 PetscCall(VecRestoreArrayRead(x, &X)); 39 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 40 41 PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt)); 42 PetscFunctionReturn(0); 43 } 44 45 static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx) 46 { 47 PC pc = (PC) ctx; 48 PC_PATCH *pcpatch = (PC_PATCH *) pc->data; 49 PetscInt pt, size, i; 50 const PetscInt *indices; 51 const PetscScalar *X; 52 PetscScalar *XWithAll; 53 54 PetscFunctionBegin; 55 /* scatter from x to patch->patchStateWithAll[pt] */ 56 pt = pcpatch->currentPatch; 57 PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size)); 58 59 PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 60 PetscCall(VecGetArrayRead(x, &X)); 61 PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll)); 62 63 for (i = 0; i < size; ++i) { 64 XWithAll[indices[i]] = X[i]; 65 } 66 67 PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll)); 68 PetscCall(VecRestoreArrayRead(x, &X)); 69 PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices)); 70 71 PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE)); 72 PetscFunctionReturn(0); 73 } 74 75 static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc) 76 { 77 PC_PATCH *patch = (PC_PATCH *) pc->data; 78 const char *prefix; 79 PetscInt i, pStart, dof, maxDof = -1; 80 81 PetscFunctionBegin; 82 if (!pc->setupcalled) { 83 PetscCall(PetscMalloc1(patch->npatch, &patch->solver)); 84 PetscCall(PCGetOptionsPrefix(pc, &prefix)); 85 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 86 for (i = 0; i < patch->npatch; ++i) { 87 SNES snes; 88 89 PetscCall(SNESCreate(PETSC_COMM_SELF, &snes)); 90 PetscCall(SNESSetOptionsPrefix(snes, prefix)); 91 PetscCall(SNESAppendOptionsPrefix(snes, "sub_")); 92 PetscCall(PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2)); 93 PetscCall(PetscLogObjectParent((PetscObject) pc, (PetscObject) snes)); 94 patch->solver[i] = (PetscObject) snes; 95 96 PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof)); 97 maxDof = PetscMax(maxDof, dof); 98 } 99 PetscCall(VecDuplicate(patch->localUpdate, &patch->localState)); 100 PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual)); 101 PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState)); 102 103 PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll)); 104 PetscCall(VecSetUp(patch->patchStateWithAll)); 105 } 106 for (i = 0; i < patch->npatch; ++i) { 107 SNES snes = (SNES) patch->solver[i]; 108 109 PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc)); 110 PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc)); 111 } 112 if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES) patch->solver[i])); 113 PetscFunctionReturn(0); 114 } 115 116 static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate) 117 { 118 PC_PATCH *patch = (PC_PATCH *) pc->data; 119 PetscInt pStart, n; 120 121 PetscFunctionBegin; 122 patch->currentPatch = i; 123 PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0)); 124 125 /* Scatter the overlapped global state to our patch state vector */ 126 PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL)); 127 PetscCall(PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR)); 128 PetscCall(PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL)); 129 130 PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n)); 131 patch->patchState->map->n = n; 132 patch->patchState->map->N = n; 133 patchUpdate->map->n = n; 134 patchUpdate->map->N = n; 135 patchRHS->map->n = n; 136 patchRHS->map->N = n; 137 /* Set initial guess to be current state*/ 138 PetscCall(VecCopy(patch->patchState, patchUpdate)); 139 /* Solve for new state */ 140 PetscCall(SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate)); 141 /* To compute update, subtract off previous state */ 142 PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState)); 143 144 PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0)); 145 PetscFunctionReturn(0); 146 } 147 148 static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc) 149 { 150 PC_PATCH *patch = (PC_PATCH *) pc->data; 151 PetscInt i; 152 153 PetscFunctionBegin; 154 if (patch->solver) { 155 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES) patch->solver[i])); 156 } 157 158 PetscCall(VecDestroy(&patch->patchResidual)); 159 PetscCall(VecDestroy(&patch->patchState)); 160 PetscCall(VecDestroy(&patch->patchStateWithAll)); 161 162 PetscCall(VecDestroy(&patch->localState)); 163 PetscFunctionReturn(0); 164 } 165 166 static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc) 167 { 168 PC_PATCH *patch = (PC_PATCH *) pc->data; 169 PetscInt i; 170 171 PetscFunctionBegin; 172 if (patch->solver) { 173 for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *) &patch->solver[i])); 174 PetscCall(PetscFree(patch->solver)); 175 } 176 PetscFunctionReturn(0); 177 } 178 179 static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart) 180 { 181 PC_PATCH *patch = (PC_PATCH *) pc->data; 182 183 PetscFunctionBegin; 184 PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR)); 185 PetscFunctionReturn(0); 186 } 187 188 static PetscErrorCode SNESSetUp_Patch(SNES snes) 189 { 190 SNES_Patch *patch = (SNES_Patch *) snes->data; 191 DM dm; 192 Mat dummy; 193 Vec F; 194 PetscInt n, N; 195 196 PetscFunctionBegin; 197 PetscCall(SNESGetDM(snes, &dm)); 198 PetscCall(PCSetDM(patch->pc, dm)); 199 PetscCall(SNESGetFunction(snes, &F, NULL, NULL)); 200 PetscCall(VecGetLocalSize(F, &n)); 201 PetscCall(VecGetSize(F, &N)); 202 PetscCall(MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy)); 203 PetscCall(PCSetOperators(patch->pc, dummy, dummy)); 204 PetscCall(MatDestroy(&dummy)); 205 PetscCall(PCSetUp(patch->pc)); 206 /* allocate workspace */ 207 PetscFunctionReturn(0); 208 } 209 210 static PetscErrorCode SNESReset_Patch(SNES snes) 211 { 212 SNES_Patch *patch = (SNES_Patch *) snes->data; 213 214 PetscFunctionBegin; 215 PetscCall(PCReset(patch->pc)); 216 PetscFunctionReturn(0); 217 } 218 219 static PetscErrorCode SNESDestroy_Patch(SNES snes) 220 { 221 SNES_Patch *patch = (SNES_Patch *) snes->data; 222 223 PetscFunctionBegin; 224 PetscCall(SNESReset_Patch(snes)); 225 PetscCall(PCDestroy(&patch->pc)); 226 PetscCall(PetscFree(snes->data)); 227 PetscFunctionReturn(0); 228 } 229 230 static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes) 231 { 232 SNES_Patch *patch = (SNES_Patch *) snes->data; 233 const char *prefix; 234 235 PetscFunctionBegin; 236 PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix)); 237 PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix)); 238 PetscCall(PCSetFromOptions(patch->pc)); 239 PetscFunctionReturn(0); 240 } 241 242 static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer) 243 { 244 SNES_Patch *patch = (SNES_Patch *) snes->data; 245 PetscBool iascii; 246 247 PetscFunctionBegin; 248 PetscCall(PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii)); 249 if (iascii) { 250 PetscCall(PetscViewerASCIIPrintf(viewer,"SNESPATCH\n")); 251 } 252 PetscCall(PetscViewerASCIIPushTab(viewer)); 253 PetscCall(PCView(patch->pc, viewer)); 254 PetscCall(PetscViewerASCIIPopTab(viewer)); 255 PetscFunctionReturn(0); 256 } 257 258 static PetscErrorCode SNESSolve_Patch(SNES snes) 259 { 260 SNES_Patch *patch = (SNES_Patch *) snes->data; 261 PC_PATCH *pcpatch = (PC_PATCH *) patch->pc->data; 262 SNESLineSearch ls; 263 Vec rhs, update, state, residual; 264 const PetscScalar *globalState = NULL; 265 PetscScalar *localState = NULL; 266 PetscInt its = 0; 267 PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0; 268 269 PetscFunctionBegin; 270 PetscCall(SNESGetSolution(snes, &state)); 271 PetscCall(SNESGetSolutionUpdate(snes, &update)); 272 PetscCall(SNESGetRhs(snes, &rhs)); 273 274 PetscCall(SNESGetFunction(snes, &residual, NULL, NULL)); 275 PetscCall(SNESGetLineSearch(snes, &ls)); 276 277 PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING)); 278 PetscCall(VecSet(update, 0.0)); 279 PetscCall(SNESComputeFunction(snes, state, residual)); 280 281 PetscCall(VecNorm(state, NORM_2, &xnorm)); 282 PetscCall(VecNorm(residual, NORM_2, &fnorm)); 283 snes->ttol = fnorm*snes->rtol; 284 285 if (snes->ops->converged) { 286 PetscCall((*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP)); 287 } else { 288 PetscCall(SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL)); 289 } 290 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */ 291 PetscCall(SNESMonitor(snes, its, fnorm)); 292 293 /* The main solver loop */ 294 for (its = 0; its < snes->max_its; its++) { 295 296 PetscCall(SNESSetIterationNumber(snes, its)); 297 298 /* Scatter state vector to overlapped vector on all patches. 299 The vector pcpatch->localState is scattered to each patch 300 in PCApply_PATCH_Nonlinear. */ 301 PetscCall(VecGetArrayRead(state, &globalState)); 302 PetscCall(VecGetArray(pcpatch->localState, &localState)); 303 PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE)); 304 PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE)); 305 PetscCall(VecRestoreArray(pcpatch->localState, &localState)); 306 PetscCall(VecRestoreArrayRead(state, &globalState)); 307 308 /* The looping over patches happens here */ 309 PetscCall(PCApply(patch->pc, rhs, update)); 310 311 /* Apply a line search. This will often be basic with 312 damping = 1/(max number of patches a dof can be in), 313 but not always */ 314 PetscCall(VecScale(update, -1.0)); 315 PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update)); 316 317 PetscCall(VecNorm(state, NORM_2, &xnorm)); 318 PetscCall(VecNorm(update, NORM_2, &ynorm)); 319 320 if (snes->ops->converged) { 321 PetscCall((*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP)); 322 } else { 323 PetscCall(SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL)); 324 } 325 PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */ 326 PetscCall(SNESMonitor(snes, its, fnorm)); 327 } 328 329 if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT)); 330 PetscFunctionReturn(0); 331 } 332 333 /*MC 334 SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches 335 336 Level: intermediate 337 338 .seealso: `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, 339 `PCPATCH` 340 341 References: 342 . * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015 343 344 M*/ 345 PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes) 346 { 347 SNES_Patch *patch; 348 PC_PATCH *patchpc; 349 SNESLineSearch linesearch; 350 351 PetscFunctionBegin; 352 PetscCall(PetscNewLog(snes, &patch)); 353 354 snes->ops->solve = SNESSolve_Patch; 355 snes->ops->setup = SNESSetUp_Patch; 356 snes->ops->reset = SNESReset_Patch; 357 snes->ops->destroy = SNESDestroy_Patch; 358 snes->ops->setfromoptions = SNESSetFromOptions_Patch; 359 snes->ops->view = SNESView_Patch; 360 361 PetscCall(SNESGetLineSearch(snes,&linesearch)); 362 if (!((PetscObject)linesearch)->type_name) { 363 PetscCall(SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC)); 364 } 365 snes->usesksp = PETSC_FALSE; 366 367 snes->alwayscomputesfinalresidual = PETSC_FALSE; 368 369 snes->data = (void *) patch; 370 PetscCall(PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc)); 371 PetscCall(PCSetType(patch->pc, PCPATCH)); 372 373 patchpc = (PC_PATCH*) patch->pc->data; 374 patchpc->classname = "snes"; 375 patchpc->isNonlinear = PETSC_TRUE; 376 377 patchpc->setupsolver = PCSetUp_PATCH_Nonlinear; 378 patchpc->applysolver = PCApply_PATCH_Nonlinear; 379 patchpc->resetsolver = PCReset_PATCH_Nonlinear; 380 patchpc->destroysolver = PCDestroy_PATCH_Nonlinear; 381 patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear; 382 383 PetscFunctionReturn(0); 384 } 385 386 PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, 387 const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes) 388 { 389 SNES_Patch *patch = (SNES_Patch *) snes->data; 390 DM dm; 391 392 PetscFunctionBegin; 393 PetscCall(SNESGetDM(snes, &dm)); 394 PetscCheck(dm,PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES"); 395 PetscCall(PCSetDM(patch->pc, dm)); 396 PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes)); 397 PetscFunctionReturn(0); 398 } 399 400 PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 401 { 402 SNES_Patch *patch = (SNES_Patch *) snes->data; 403 404 PetscFunctionBegin; 405 PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx)); 406 PetscFunctionReturn(0); 407 } 408 409 PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx) 410 { 411 SNES_Patch *patch = (SNES_Patch *) snes->data; 412 413 PetscFunctionBegin; 414 PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx)); 415 PetscFunctionReturn(0); 416 } 417 418 PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx) 419 { 420 SNES_Patch *patch = (SNES_Patch *) snes->data; 421 422 PetscFunctionBegin; 423 PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx)); 424 PetscFunctionReturn(0); 425 } 426 427 PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering) 428 { 429 SNES_Patch *patch = (SNES_Patch *) snes->data; 430 431 PetscFunctionBegin; 432 PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering)); 433 PetscFunctionReturn(0); 434 } 435