1 #include <../src/snes/impls/ngmres/snesngmres.h> /*I "petscsnes.h" I*/ 2 #include <petscblaslapack.h> 3 4 PetscErrorCode SNESNGMRESGetAdditiveLineSearch_Private(SNES snes, SNESLineSearch *linesearch) 5 { 6 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 7 8 PetscFunctionBegin; 9 if (!ngmres->additive_linesearch) { 10 const char *optionsprefix; 11 PetscCall(SNESGetOptionsPrefix(snes, &optionsprefix)); 12 PetscCall(SNESLineSearchCreate(PetscObjectComm((PetscObject)snes), &ngmres->additive_linesearch)); 13 PetscCall(SNESLineSearchSetSNES(ngmres->additive_linesearch, snes)); 14 PetscCall(SNESLineSearchSetType(ngmres->additive_linesearch, SNESLINESEARCHSECANT)); 15 PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, "snes_ngmres_additive_")); 16 PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, optionsprefix)); 17 PetscCall(PetscObjectIncrementTabLevel((PetscObject)ngmres->additive_linesearch, (PetscObject)snes, 1)); 18 } 19 *linesearch = ngmres->additive_linesearch; 20 PetscFunctionReturn(PETSC_SUCCESS); 21 } 22 23 PetscErrorCode SNESNGMRESUpdateSubspace_Private(SNES snes, PetscInt ivec, PetscInt l, Vec F, PetscReal fnorm, Vec X) 24 { 25 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 26 Vec *Fdot = ngmres->Fdot; 27 Vec *Xdot = ngmres->Xdot; 28 29 PetscFunctionBegin; 30 PetscCheck(ivec <= l, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Cannot update vector %" PetscInt_FMT " with space size %" PetscInt_FMT "!", ivec, l); 31 PetscCall(VecCopy(F, Fdot[ivec])); 32 PetscCall(VecCopy(X, Xdot[ivec])); 33 34 ngmres->fnorms[ivec] = fnorm; 35 PetscFunctionReturn(PETSC_SUCCESS); 36 } 37 38 PetscErrorCode SNESNGMRESFormCombinedSolution_Private(SNES snes, PetscInt ivec, PetscInt l, Vec XM, Vec FM, PetscReal fMnorm, Vec X, Vec XA, Vec FA) 39 { 40 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 41 PetscInt i, j; 42 Vec *Fdot = ngmres->Fdot; 43 Vec *Xdot = ngmres->Xdot; 44 PetscScalar *beta = ngmres->beta; 45 PetscScalar *xi = ngmres->xi; 46 PetscScalar alph_total = 0.; 47 PetscReal nu; 48 Vec Y = snes->vec_sol_update; 49 PetscBool changed_y, changed_w; 50 51 PetscFunctionBegin; 52 nu = fMnorm * fMnorm; 53 54 /* construct the right-hand side and xi factors */ 55 if (l > 0) { 56 PetscCall(VecMDotBegin(FM, l, Fdot, xi)); 57 PetscCall(VecMDotBegin(Fdot[ivec], l, Fdot, beta)); 58 PetscCall(VecMDotEnd(FM, l, Fdot, xi)); 59 PetscCall(VecMDotEnd(Fdot[ivec], l, Fdot, beta)); 60 for (i = 0; i < l; i++) { 61 Q(i, ivec) = beta[i]; 62 Q(ivec, i) = beta[i]; 63 } 64 } else { 65 Q(0, 0) = ngmres->fnorms[ivec] * ngmres->fnorms[ivec]; 66 } 67 68 for (i = 0; i < l; i++) beta[i] = nu - xi[i]; 69 70 /* construct h */ 71 for (j = 0; j < l; j++) { 72 for (i = 0; i < l; i++) H(i, j) = Q(i, j) - xi[i] - xi[j] + nu; 73 } 74 if (l == 1) { 75 /* simply set alpha[0] = beta[0] / H[0, 0] */ 76 if (H(0, 0) != 0.) beta[0] = beta[0] / H(0, 0); 77 else beta[0] = 0.; 78 } else { 79 PetscCall(PetscBLASIntCast(l, &ngmres->m)); 80 PetscCall(PetscBLASIntCast(l, &ngmres->n)); 81 ngmres->info = 0; 82 ngmres->rcond = -1.; 83 PetscCall(PetscFPTrapPush(PETSC_FP_TRAP_OFF)); 84 #if defined(PETSC_USE_COMPLEX) 85 PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, ngmres->rwork, &ngmres->info)); 86 #else 87 PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, &ngmres->info)); 88 #endif 89 PetscCall(PetscFPTrapPop()); 90 PetscCheck(ngmres->info >= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "Bad argument to GELSS"); 91 PetscCheck(ngmres->info <= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD failed to converge"); 92 } 93 for (i = 0; i < l; i++) PetscCheck(!PetscIsInfOrNanScalar(beta[i]), PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD generated inconsistent output"); 94 alph_total = 0.; 95 for (i = 0; i < l; i++) alph_total += beta[i]; 96 97 PetscCall(VecAXPBY(XA, 1.0 - alph_total, 0.0, XM)); 98 PetscCall(VecMAXPY(XA, l, beta, Xdot)); 99 /* check the validity of the step */ 100 PetscCall(VecWAXPY(Y, -1.0, X, XA)); 101 PetscCall(SNESLineSearchPostCheck(snes->linesearch, X, Y, XA, &changed_y, &changed_w)); 102 if (!ngmres->approxfunc) { 103 if (snes->npc && snes->npcside == PC_LEFT) { 104 PetscCall(SNESApplyNPC(snes, XA, NULL, FA)); 105 } else { 106 PetscCall(SNESComputeFunction(snes, XA, FA)); 107 } 108 } else { 109 PetscCall(VecAXPBY(FA, 1.0 - alph_total, 0.0, FM)); 110 PetscCall(VecMAXPY(FA, l, beta, Fdot)); 111 } 112 PetscFunctionReturn(PETSC_SUCCESS); 113 } 114 115 PetscErrorCode SNESNGMRESNorms_Private(SNES snes, PetscInt l, Vec X, Vec F, Vec XM, Vec FM, Vec XA, Vec FA, Vec D, PetscReal *dnorm, PetscReal *dminnorm, PetscReal *xMnorm, PetscReal *fMnorm, PetscReal *yMnorm, PetscReal *xAnorm, PetscReal *fAnorm, PetscReal *yAnorm) 116 { 117 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 118 PetscReal dcurnorm, dmin = -1.0; 119 Vec *Xdot = ngmres->Xdot; 120 PetscInt i; 121 122 PetscFunctionBegin; 123 if (xMnorm) PetscCall(VecNormBegin(XM, NORM_2, xMnorm)); 124 if (fMnorm) PetscCall(VecNormBegin(FM, NORM_2, fMnorm)); 125 if (yMnorm) { 126 PetscCall(VecWAXPY(D, -1.0, XM, X)); 127 PetscCall(VecNormBegin(D, NORM_2, yMnorm)); 128 } 129 if (xAnorm) PetscCall(VecNormBegin(XA, NORM_2, xAnorm)); 130 if (fAnorm) PetscCall(VecNormBegin(FA, NORM_2, fAnorm)); 131 if (yAnorm) { 132 PetscCall(VecWAXPY(D, -1.0, XA, X)); 133 PetscCall(VecNormBegin(D, NORM_2, yAnorm)); 134 } 135 if (dnorm) { 136 PetscCall(VecWAXPY(D, -1.0, XM, XA)); 137 PetscCall(VecNormBegin(D, NORM_2, dnorm)); 138 } 139 if (dminnorm) { 140 for (i = 0; i < l; i++) { 141 PetscCall(VecWAXPY(D, -1.0, XA, Xdot[i])); 142 PetscCall(VecNormBegin(D, NORM_2, &ngmres->xnorms[i])); 143 } 144 } 145 if (xMnorm) PetscCall(VecNormEnd(XM, NORM_2, xMnorm)); 146 if (fMnorm) PetscCall(VecNormEnd(FM, NORM_2, fMnorm)); 147 if (yMnorm) PetscCall(VecNormEnd(D, NORM_2, yMnorm)); 148 if (xAnorm) PetscCall(VecNormEnd(XA, NORM_2, xAnorm)); 149 if (fAnorm) PetscCall(VecNormEnd(FA, NORM_2, fAnorm)); 150 if (yAnorm) PetscCall(VecNormEnd(D, NORM_2, yAnorm)); 151 if (dnorm) PetscCall(VecNormEnd(D, NORM_2, dnorm)); 152 if (dminnorm) { 153 for (i = 0; i < l; i++) { 154 PetscCall(VecNormEnd(D, NORM_2, &ngmres->xnorms[i])); 155 dcurnorm = ngmres->xnorms[i]; 156 if ((dcurnorm < dmin) || (dmin < 0.0)) dmin = dcurnorm; 157 } 158 *dminnorm = dmin; 159 } 160 PetscFunctionReturn(PETSC_SUCCESS); 161 } 162 163 PetscErrorCode SNESNGMRESSelect_Private(SNES snes, PetscInt k_restart, Vec XM, Vec FM, PetscReal xMnorm, PetscReal fMnorm, PetscReal yMnorm, PetscReal objM, Vec XA, Vec FA, PetscReal xAnorm, PetscReal fAnorm, PetscReal yAnorm, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, Vec X, Vec F, Vec Y, PetscReal *xnorm, PetscReal *fnorm, PetscReal *ynorm) 164 { 165 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 166 SNESLineSearchReason lssucceed; 167 PetscBool selectA; 168 169 PetscFunctionBegin; 170 if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) { 171 /* X = X + \lambda(XA - X) */ 172 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm)); 173 /* Test if is XA - XM is a descent direction: we want < F(XM), XA - XM > not positive 174 If positive, GMRES will be restarted see https://epubs.siam.org/doi/pdf/10.1137/110835530 */ 175 PetscCall(VecCopy(FM, F)); 176 PetscCall(VecCopy(XM, X)); 177 PetscCall(VecWAXPY(Y, -1.0, XA, X)); /* minus sign since linesearch expects to find Xnew = X - lambda * Y */ 178 PetscCall(VecDotRealPart(FM, Y, &ngmres->descent_ls_test)); /* this is actually < F(XM), XM - XA > */ 179 *fnorm = fMnorm; 180 if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction, select XM */ 181 *xnorm = xMnorm; 182 *fnorm = fMnorm; 183 *ynorm = yMnorm; 184 PetscCall(VecWAXPY(Y, -1.0, X, XM)); 185 PetscCall(VecCopy(FM, F)); 186 PetscCall(VecCopy(XM, X)); 187 } else { 188 PetscCall(SNESNGMRESGetAdditiveLineSearch_Private(snes, &ngmres->additive_linesearch)); 189 PetscCall(SNESLineSearchApply(ngmres->additive_linesearch, X, F, fnorm, Y)); 190 PetscCall(SNESLineSearchGetReason(ngmres->additive_linesearch, &lssucceed)); 191 PetscCall(SNESLineSearchGetNorms(ngmres->additive_linesearch, xnorm, fnorm, ynorm)); 192 if (lssucceed) { 193 if (++snes->numFailures >= snes->maxFailures) { 194 snes->reason = SNES_DIVERGED_LINE_SEARCH; 195 PetscFunctionReturn(PETSC_SUCCESS); 196 } 197 } 198 } 199 if (ngmres->monitor) { 200 PetscReal objT = *fnorm; 201 SNESObjectiveFn *objective; 202 203 PetscCall(SNESGetObjective(snes, &objective, NULL)); 204 if (objective) PetscCall(SNESComputeObjective(snes, X, &objT)); 205 PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "Additive solution: objective = %e\n", (double)objT)); 206 } 207 } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) { 208 /* Conditions for choosing the accelerated answer: 209 Criterion A -- the objective function isn't increased above the minimum by too much 210 Criterion B -- the choice of x^A isn't too close to some other choice 211 */ 212 selectA = (PetscBool)(/* A */ (objA < ngmres->gammaA * objmin) && /* B */ (ngmres->epsilonB * dnorm < dminnorm || objA < ngmres->deltaB * objmin)); 213 214 if (selectA) { 215 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_A, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm)); 216 /* copy it over */ 217 *xnorm = xAnorm; 218 *fnorm = fAnorm; 219 *ynorm = yAnorm; 220 PetscCall(VecCopy(FA, F)); 221 PetscCall(VecCopy(XA, X)); 222 } else { 223 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_M, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm)); 224 *xnorm = xMnorm; 225 *fnorm = fMnorm; 226 *ynorm = yMnorm; 227 PetscCall(VecWAXPY(Y, -1.0, X, XM)); 228 PetscCall(VecCopy(FM, F)); 229 PetscCall(VecCopy(XM, X)); 230 } 231 } else { /* none */ 232 *xnorm = xAnorm; 233 *fnorm = fAnorm; 234 *ynorm = yAnorm; 235 PetscCall(VecCopy(FA, F)); 236 PetscCall(VecCopy(XA, X)); 237 } 238 PetscFunctionReturn(PETSC_SUCCESS); 239 } 240 241 PetscErrorCode SNESNGMRESSelectRestart_Private(SNES snes, PetscInt l, PetscReal obj, PetscReal objM, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, PetscBool *selectRestart) 242 { 243 SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data; 244 245 PetscFunctionBegin; 246 *selectRestart = PETSC_FALSE; 247 if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) { 248 if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction */ 249 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "ascent restart: %e > 0\n", (double)-ngmres->descent_ls_test)); 250 *selectRestart = PETSC_TRUE; 251 } 252 } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) { 253 /* difference stagnation restart */ 254 if (ngmres->epsilonB * dnorm > dminnorm && objA > ngmres->deltaB * objmin && l > 0) { 255 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "difference restart: %e > %e\n", (double)(ngmres->epsilonB * dnorm), (double)dminnorm)); 256 *selectRestart = PETSC_TRUE; 257 } 258 /* residual stagnation restart */ 259 if (objA > ngmres->gammaC * objmin) { 260 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "residual restart: %e > %e\n", (double)objA, (double)(ngmres->gammaC * objmin))); 261 *selectRestart = PETSC_TRUE; 262 } 263 264 /* F_M stagnation restart */ 265 if (ngmres->restart_fm_rise && objM > obj) { 266 if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "F_M rise restart: %e > %e\n", (double)objM, (double)obj)); 267 *selectRestart = PETSC_TRUE; 268 } 269 } 270 PetscFunctionReturn(PETSC_SUCCESS); 271 } 272