xref: /petsc/src/snes/impls/ngmres/ngmresfunc.c (revision bd89dbf26d8a5efecb980364933175da61864cd7)
1 #include <../src/snes/impls/ngmres/snesngmres.h> /*I "petscsnes.h" I*/
2 #include <petscblaslapack.h>
3 
SNESNGMRESGetAdditiveLineSearch_Private(SNES snes,SNESLineSearch * linesearch)4 PetscErrorCode SNESNGMRESGetAdditiveLineSearch_Private(SNES snes, SNESLineSearch *linesearch)
5 {
6   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
7 
8   PetscFunctionBegin;
9   if (!ngmres->additive_linesearch) {
10     const char *optionsprefix;
11     PetscCall(SNESGetOptionsPrefix(snes, &optionsprefix));
12     PetscCall(SNESLineSearchCreate(PetscObjectComm((PetscObject)snes), &ngmres->additive_linesearch));
13     PetscCall(SNESLineSearchSetSNES(ngmres->additive_linesearch, snes));
14     PetscCall(SNESLineSearchSetType(ngmres->additive_linesearch, SNESLINESEARCHSECANT));
15     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, "snes_ngmres_additive_"));
16     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, optionsprefix));
17     PetscCall(PetscObjectIncrementTabLevel((PetscObject)ngmres->additive_linesearch, (PetscObject)snes, 1));
18   }
19   *linesearch = ngmres->additive_linesearch;
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
SNESNGMRESUpdateSubspace_Private(SNES snes,PetscInt ivec,PetscInt l,Vec F,PetscReal fnorm,Vec X)23 PetscErrorCode SNESNGMRESUpdateSubspace_Private(SNES snes, PetscInt ivec, PetscInt l, Vec F, PetscReal fnorm, Vec X)
24 {
25   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
26   Vec         *Fdot   = ngmres->Fdot;
27   Vec         *Xdot   = ngmres->Xdot;
28 
29   PetscFunctionBegin;
30   PetscCheck(ivec <= l, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Cannot update vector %" PetscInt_FMT " with space size %" PetscInt_FMT "!", ivec, l);
31   PetscCall(VecCopy(F, Fdot[ivec]));
32   PetscCall(VecCopy(X, Xdot[ivec]));
33 
34   ngmres->fnorms[ivec] = fnorm;
35   PetscFunctionReturn(PETSC_SUCCESS);
36 }
37 
SNESNGMRESFormCombinedSolution_Private(SNES snes,PetscInt ivec,PetscInt l,Vec XM,Vec FM,PetscReal fMnorm,Vec X,Vec XA,Vec FA)38 PetscErrorCode SNESNGMRESFormCombinedSolution_Private(SNES snes, PetscInt ivec, PetscInt l, Vec XM, Vec FM, PetscReal fMnorm, Vec X, Vec XA, Vec FA)
39 {
40   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
41   PetscInt     i, j;
42   Vec         *Fdot       = ngmres->Fdot;
43   Vec         *Xdot       = ngmres->Xdot;
44   PetscScalar *beta       = ngmres->beta;
45   PetscScalar *xi         = ngmres->xi;
46   PetscScalar  alph_total = 0.;
47   PetscReal    nu;
48   Vec          Y = snes->vec_sol_update;
49   PetscBool    changed_y, changed_w;
50 
51   PetscFunctionBegin;
52   nu = fMnorm * fMnorm;
53 
54   /* construct the right-hand side and xi factors */
55   if (l > 0) {
56     PetscCall(VecMDotBegin(FM, l, Fdot, xi));
57     PetscCall(VecMDotBegin(Fdot[ivec], l, Fdot, beta));
58     PetscCall(VecMDotEnd(FM, l, Fdot, xi));
59     PetscCall(VecMDotEnd(Fdot[ivec], l, Fdot, beta));
60     for (i = 0; i < l; i++) {
61       Q(i, ivec) = beta[i];
62       Q(ivec, i) = beta[i];
63     }
64   } else {
65     Q(0, 0) = ngmres->fnorms[ivec] * ngmres->fnorms[ivec];
66   }
67 
68   for (i = 0; i < l; i++) beta[i] = nu - xi[i];
69 
70   /* construct h */
71   for (j = 0; j < l; j++) {
72     for (i = 0; i < l; i++) H(i, j) = Q(i, j) - xi[i] - xi[j] + nu;
73   }
74   if (l == 1) {
75     /* simply set alpha[0] = beta[0] / H[0, 0] */
76     if (H(0, 0) != 0.) beta[0] = beta[0] / H(0, 0);
77     else beta[0] = 0.;
78   } else {
79     PetscCall(PetscBLASIntCast(l, &ngmres->m));
80     PetscCall(PetscBLASIntCast(l, &ngmres->n));
81     ngmres->info  = 0;
82     ngmres->rcond = -1.;
83     PetscCall(PetscFPTrapPush(PETSC_FP_TRAP_OFF));
84 #if defined(PETSC_USE_COMPLEX)
85     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, ngmres->rwork, &ngmres->info));
86 #else
87     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, &ngmres->info));
88 #endif
89     PetscCall(PetscFPTrapPop());
90     PetscCheck(ngmres->info >= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "Bad argument to GELSS");
91     PetscCheck(ngmres->info <= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD failed to converge");
92   }
93   for (i = 0; i < l; i++) PetscCheck(!PetscIsInfOrNanScalar(beta[i]), PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD generated inconsistent output");
94   alph_total = 0.;
95   for (i = 0; i < l; i++) alph_total += beta[i];
96 
97   PetscCall(VecAXPBY(XA, 1.0 - alph_total, 0.0, XM));
98   PetscCall(VecMAXPY(XA, l, beta, Xdot));
99   /* check the validity of the step */
100   PetscCall(VecWAXPY(Y, -1.0, X, XA));
101   PetscCall(SNESLineSearchPostCheck(snes->linesearch, X, Y, XA, &changed_y, &changed_w));
102   if (!ngmres->approxfunc) {
103     if (snes->npc && snes->npcside == PC_LEFT) {
104       PetscCall(SNESApplyNPC(snes, XA, NULL, FA));
105     } else {
106       PetscCall(SNESComputeFunction(snes, XA, FA));
107     }
108   } else {
109     PetscCall(VecAXPBY(FA, 1.0 - alph_total, 0.0, FM));
110     PetscCall(VecMAXPY(FA, l, beta, Fdot));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
SNESNGMRESNorms_Private(SNES snes,PetscInt l,Vec X,Vec F,Vec XM,Vec FM,Vec XA,Vec FA,Vec D,PetscReal * dnorm,PetscReal * dminnorm,PetscReal * xMnorm,PetscReal * fMnorm,PetscReal * yMnorm,PetscReal * xAnorm,PetscReal * fAnorm,PetscReal * yAnorm)115 PetscErrorCode SNESNGMRESNorms_Private(SNES snes, PetscInt l, Vec X, Vec F, Vec XM, Vec FM, Vec XA, Vec FA, Vec D, PetscReal *dnorm, PetscReal *dminnorm, PetscReal *xMnorm, PetscReal *fMnorm, PetscReal *yMnorm, PetscReal *xAnorm, PetscReal *fAnorm, PetscReal *yAnorm)
116 {
117   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
118   PetscReal    dcurnorm, dmin = -1.0;
119   Vec         *Xdot = ngmres->Xdot;
120   PetscInt     i;
121 
122   PetscFunctionBegin;
123   if (xMnorm) PetscCall(VecNormBegin(XM, NORM_2, xMnorm));
124   if (fMnorm) PetscCall(VecNormBegin(FM, NORM_2, fMnorm));
125   if (yMnorm) {
126     PetscCall(VecWAXPY(D, -1.0, XM, X));
127     PetscCall(VecNormBegin(D, NORM_2, yMnorm));
128   }
129   if (xAnorm) PetscCall(VecNormBegin(XA, NORM_2, xAnorm));
130   if (fAnorm) PetscCall(VecNormBegin(FA, NORM_2, fAnorm));
131   if (yAnorm) {
132     PetscCall(VecWAXPY(D, -1.0, XA, X));
133     PetscCall(VecNormBegin(D, NORM_2, yAnorm));
134   }
135   if (dnorm) {
136     PetscCall(VecWAXPY(D, -1.0, XM, XA));
137     PetscCall(VecNormBegin(D, NORM_2, dnorm));
138   }
139   if (dminnorm) {
140     for (i = 0; i < l; i++) {
141       PetscCall(VecWAXPY(D, -1.0, XA, Xdot[i]));
142       PetscCall(VecNormBegin(D, NORM_2, &ngmres->xnorms[i]));
143     }
144   }
145   if (xMnorm) PetscCall(VecNormEnd(XM, NORM_2, xMnorm));
146   if (fMnorm) PetscCall(VecNormEnd(FM, NORM_2, fMnorm));
147   if (yMnorm) PetscCall(VecNormEnd(D, NORM_2, yMnorm));
148   if (xAnorm) PetscCall(VecNormEnd(XA, NORM_2, xAnorm));
149   if (fAnorm) PetscCall(VecNormEnd(FA, NORM_2, fAnorm));
150   if (yAnorm) PetscCall(VecNormEnd(D, NORM_2, yAnorm));
151   if (dnorm) PetscCall(VecNormEnd(D, NORM_2, dnorm));
152   if (dminnorm) {
153     for (i = 0; i < l; i++) {
154       PetscCall(VecNormEnd(D, NORM_2, &ngmres->xnorms[i]));
155       dcurnorm = ngmres->xnorms[i];
156       if ((dcurnorm < dmin) || (dmin < 0.0)) dmin = dcurnorm;
157     }
158     *dminnorm = dmin;
159   }
160   PetscFunctionReturn(PETSC_SUCCESS);
161 }
162 
SNESNGMRESSelect_Private(SNES snes,PetscInt k_restart,Vec XM,Vec FM,PetscReal xMnorm,PetscReal fMnorm,PetscReal yMnorm,PetscReal objM,Vec XA,Vec FA,PetscReal xAnorm,PetscReal fAnorm,PetscReal yAnorm,PetscReal objA,PetscReal dnorm,PetscReal objmin,PetscReal dminnorm,Vec X,Vec F,Vec Y,PetscReal * xnorm,PetscReal * fnorm,PetscReal * ynorm)163 PetscErrorCode SNESNGMRESSelect_Private(SNES snes, PetscInt k_restart, Vec XM, Vec FM, PetscReal xMnorm, PetscReal fMnorm, PetscReal yMnorm, PetscReal objM, Vec XA, Vec FA, PetscReal xAnorm, PetscReal fAnorm, PetscReal yAnorm, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, Vec X, Vec F, Vec Y, PetscReal *xnorm, PetscReal *fnorm, PetscReal *ynorm)
164 {
165   SNES_NGMRES         *ngmres = (SNES_NGMRES *)snes->data;
166   SNESLineSearchReason lsreason;
167   PetscBool            selectA;
168 
169   PetscFunctionBegin;
170   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
171     /* X = X + \lambda(XA - X) */
172     if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
173     /* Test if is XA - XM is a descent direction: we want < F(XM), XA - XM > not positive
174        If positive, GMRES will be restarted see https://epubs.siam.org/doi/pdf/10.1137/110835530 */
175     PetscCall(VecCopy(FM, F));
176     PetscCall(VecCopy(XM, X));
177     PetscCall(VecWAXPY(Y, -1.0, XA, X));                        /* minus sign since linesearch expects to find Xnew = X - lambda * Y */
178     PetscCall(VecDotRealPart(FM, Y, &ngmres->descent_ls_test)); /* this is actually < F(XM), XM - XA > */
179     *fnorm = fMnorm;
180     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction, select XM */
181       *xnorm = xMnorm;
182       *fnorm = fMnorm;
183       *ynorm = yMnorm;
184       PetscCall(VecWAXPY(Y, -1.0, X, XM));
185       PetscCall(VecCopy(FM, F));
186       PetscCall(VecCopy(XM, X));
187     } else {
188       PetscCall(SNESNGMRESGetAdditiveLineSearch_Private(snes, &ngmres->additive_linesearch));
189       PetscCall(SNESLineSearchApply(ngmres->additive_linesearch, X, F, fnorm, Y));
190       PetscCall(SNESLineSearchGetReason(ngmres->additive_linesearch, &lsreason));
191       if (lsreason == SNES_LINESEARCH_FAILED_FUNCTION_DOMAIN) {
192         PetscCheck(!snes->errorifnotconverged, PetscObjectComm((PetscObject)snes), PETSC_ERR_NOT_CONVERGED, "SNES solver has not converged");
193         snes->reason = SNES_DIVERGED_FUNCTION_DOMAIN;
194         PetscFunctionReturn(PETSC_SUCCESS);
195       }
196       if (lsreason == SNES_LINESEARCH_FAILED_NANORINF) {
197         PetscCheck(!snes->errorifnotconverged, PetscObjectComm((PetscObject)snes), PETSC_ERR_NOT_CONVERGED, "SNES solver has not converged");
198         snes->reason = SNES_DIVERGED_FUNCTION_NANORINF;
199         PetscFunctionReturn(PETSC_SUCCESS);
200       }
201       if (lsreason && ++snes->numFailures >= snes->maxFailures) {
202         PetscCheck(!snes->errorifnotconverged, PetscObjectComm((PetscObject)snes), PETSC_ERR_NOT_CONVERGED, "SNES solver has not converged");
203         snes->reason = SNES_DIVERGED_LINE_SEARCH;
204         PetscFunctionReturn(PETSC_SUCCESS);
205       }
206       PetscCall(SNESLineSearchGetNorms(ngmres->additive_linesearch, xnorm, fnorm, ynorm));
207     }
208     if (ngmres->monitor) {
209       PetscReal        objT = *fnorm;
210       SNESObjectiveFn *objective;
211 
212       PetscCall(SNESGetObjective(snes, &objective, NULL));
213       if (objective) PetscCall(SNESComputeObjective(snes, X, &objT));
214       PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "Additive solution: objective = %e\n", (double)objT));
215     }
216   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
217     /* Conditions for choosing the accelerated answer:
218           Criterion A -- the objective function isn't increased above the minimum by too much
219           Criterion B -- the choice of x^A isn't too close to some other choice
220     */
221     selectA = (PetscBool)(/* A */ (objA < ngmres->gammaA * objmin) && /* B */ (ngmres->epsilonB * dnorm < dminnorm || objA < ngmres->deltaB * objmin));
222 
223     if (selectA) {
224       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_A, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
225       /* copy it over */
226       *xnorm = xAnorm;
227       *fnorm = fAnorm;
228       *ynorm = yAnorm;
229       PetscCall(VecCopy(FA, F));
230       PetscCall(VecCopy(XA, X));
231     } else {
232       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_M, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
233       *xnorm = xMnorm;
234       *fnorm = fMnorm;
235       *ynorm = yMnorm;
236       PetscCall(VecWAXPY(Y, -1.0, X, XM));
237       PetscCall(VecCopy(FM, F));
238       PetscCall(VecCopy(XM, X));
239     }
240   } else { /* none */
241     *xnorm = xAnorm;
242     *fnorm = fAnorm;
243     *ynorm = yAnorm;
244     PetscCall(VecCopy(FA, F));
245     PetscCall(VecCopy(XA, X));
246   }
247   PetscFunctionReturn(PETSC_SUCCESS);
248 }
249 
SNESNGMRESSelectRestart_Private(SNES snes,PetscInt l,PetscReal obj,PetscReal objM,PetscReal objA,PetscReal dnorm,PetscReal objmin,PetscReal dminnorm,PetscBool * selectRestart)250 PetscErrorCode SNESNGMRESSelectRestart_Private(SNES snes, PetscInt l, PetscReal obj, PetscReal objM, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, PetscBool *selectRestart)
251 {
252   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
253 
254   PetscFunctionBegin;
255   *selectRestart = PETSC_FALSE;
256   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
257     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction */
258       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "ascent restart: %e > 0\n", (double)-ngmres->descent_ls_test));
259       *selectRestart = PETSC_TRUE;
260     }
261   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
262     /* difference stagnation restart */
263     if (ngmres->epsilonB * dnorm > dminnorm && objA > ngmres->deltaB * objmin && l > 0) {
264       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "difference restart: %e > %e\n", (double)(ngmres->epsilonB * dnorm), (double)dminnorm));
265       *selectRestart = PETSC_TRUE;
266     }
267     /* residual stagnation restart */
268     if (objA > ngmres->gammaC * objmin) {
269       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "residual restart: %e > %e\n", (double)objA, (double)(ngmres->gammaC * objmin)));
270       *selectRestart = PETSC_TRUE;
271     }
272 
273     /* F_M stagnation restart */
274     if (ngmres->restart_fm_rise && objM > obj) {
275       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "F_M rise restart: %e > %e\n", (double)objM, (double)obj));
276       *selectRestart = PETSC_TRUE;
277     }
278   }
279   PetscFunctionReturn(PETSC_SUCCESS);
280 }
281