xref: /petsc/src/snes/impls/ngmres/ngmresfunc.c (revision 2ff79c18c26c94ed8cb599682f680f231dca6444)
1 #include <../src/snes/impls/ngmres/snesngmres.h> /*I "petscsnes.h" I*/
2 #include <petscblaslapack.h>
3 
4 PetscErrorCode SNESNGMRESGetAdditiveLineSearch_Private(SNES snes, SNESLineSearch *linesearch)
5 {
6   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
7 
8   PetscFunctionBegin;
9   if (!ngmres->additive_linesearch) {
10     const char *optionsprefix;
11     PetscCall(SNESGetOptionsPrefix(snes, &optionsprefix));
12     PetscCall(SNESLineSearchCreate(PetscObjectComm((PetscObject)snes), &ngmres->additive_linesearch));
13     PetscCall(SNESLineSearchSetSNES(ngmres->additive_linesearch, snes));
14     PetscCall(SNESLineSearchSetType(ngmres->additive_linesearch, SNESLINESEARCHSECANT));
15     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, "snes_ngmres_additive_"));
16     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, optionsprefix));
17     PetscCall(PetscObjectIncrementTabLevel((PetscObject)ngmres->additive_linesearch, (PetscObject)snes, 1));
18   }
19   *linesearch = ngmres->additive_linesearch;
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
23 PetscErrorCode SNESNGMRESUpdateSubspace_Private(SNES snes, PetscInt ivec, PetscInt l, Vec F, PetscReal fnorm, Vec X)
24 {
25   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
26   Vec         *Fdot   = ngmres->Fdot;
27   Vec         *Xdot   = ngmres->Xdot;
28 
29   PetscFunctionBegin;
30   PetscCheck(ivec <= l, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Cannot update vector %" PetscInt_FMT " with space size %" PetscInt_FMT "!", ivec, l);
31   PetscCall(VecCopy(F, Fdot[ivec]));
32   PetscCall(VecCopy(X, Xdot[ivec]));
33 
34   ngmres->fnorms[ivec] = fnorm;
35   PetscFunctionReturn(PETSC_SUCCESS);
36 }
37 
38 PetscErrorCode SNESNGMRESFormCombinedSolution_Private(SNES snes, PetscInt ivec, PetscInt l, Vec XM, Vec FM, PetscReal fMnorm, Vec X, Vec XA, Vec FA)
39 {
40   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
41   PetscInt     i, j;
42   Vec         *Fdot       = ngmres->Fdot;
43   Vec         *Xdot       = ngmres->Xdot;
44   PetscScalar *beta       = ngmres->beta;
45   PetscScalar *xi         = ngmres->xi;
46   PetscScalar  alph_total = 0.;
47   PetscReal    nu;
48   Vec          Y = snes->vec_sol_update;
49   PetscBool    changed_y, changed_w;
50 
51   PetscFunctionBegin;
52   nu = fMnorm * fMnorm;
53 
54   /* construct the right-hand side and xi factors */
55   if (l > 0) {
56     PetscCall(VecMDotBegin(FM, l, Fdot, xi));
57     PetscCall(VecMDotBegin(Fdot[ivec], l, Fdot, beta));
58     PetscCall(VecMDotEnd(FM, l, Fdot, xi));
59     PetscCall(VecMDotEnd(Fdot[ivec], l, Fdot, beta));
60     for (i = 0; i < l; i++) {
61       Q(i, ivec) = beta[i];
62       Q(ivec, i) = beta[i];
63     }
64   } else {
65     Q(0, 0) = ngmres->fnorms[ivec] * ngmres->fnorms[ivec];
66   }
67 
68   for (i = 0; i < l; i++) beta[i] = nu - xi[i];
69 
70   /* construct h */
71   for (j = 0; j < l; j++) {
72     for (i = 0; i < l; i++) H(i, j) = Q(i, j) - xi[i] - xi[j] + nu;
73   }
74   if (l == 1) {
75     /* simply set alpha[0] = beta[0] / H[0, 0] */
76     if (H(0, 0) != 0.) beta[0] = beta[0] / H(0, 0);
77     else beta[0] = 0.;
78   } else {
79     PetscCall(PetscBLASIntCast(l, &ngmres->m));
80     PetscCall(PetscBLASIntCast(l, &ngmres->n));
81     ngmres->info  = 0;
82     ngmres->rcond = -1.;
83     PetscCall(PetscFPTrapPush(PETSC_FP_TRAP_OFF));
84 #if defined(PETSC_USE_COMPLEX)
85     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, ngmres->rwork, &ngmres->info));
86 #else
87     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, &ngmres->info));
88 #endif
89     PetscCall(PetscFPTrapPop());
90     PetscCheck(ngmres->info >= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "Bad argument to GELSS");
91     PetscCheck(ngmres->info <= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD failed to converge");
92   }
93   for (i = 0; i < l; i++) PetscCheck(!PetscIsInfOrNanScalar(beta[i]), PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD generated inconsistent output");
94   alph_total = 0.;
95   for (i = 0; i < l; i++) alph_total += beta[i];
96 
97   PetscCall(VecAXPBY(XA, 1.0 - alph_total, 0.0, XM));
98   PetscCall(VecMAXPY(XA, l, beta, Xdot));
99   /* check the validity of the step */
100   PetscCall(VecWAXPY(Y, -1.0, X, XA));
101   PetscCall(SNESLineSearchPostCheck(snes->linesearch, X, Y, XA, &changed_y, &changed_w));
102   if (!ngmres->approxfunc) {
103     if (snes->npc && snes->npcside == PC_LEFT) {
104       PetscCall(SNESApplyNPC(snes, XA, NULL, FA));
105     } else {
106       PetscCall(SNESComputeFunction(snes, XA, FA));
107     }
108   } else {
109     PetscCall(VecAXPBY(FA, 1.0 - alph_total, 0.0, FM));
110     PetscCall(VecMAXPY(FA, l, beta, Fdot));
111   }
112   PetscFunctionReturn(PETSC_SUCCESS);
113 }
114 
115 PetscErrorCode SNESNGMRESNorms_Private(SNES snes, PetscInt l, Vec X, Vec F, Vec XM, Vec FM, Vec XA, Vec FA, Vec D, PetscReal *dnorm, PetscReal *dminnorm, PetscReal *xMnorm, PetscReal *fMnorm, PetscReal *yMnorm, PetscReal *xAnorm, PetscReal *fAnorm, PetscReal *yAnorm)
116 {
117   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
118   PetscReal    dcurnorm, dmin = -1.0;
119   Vec         *Xdot = ngmres->Xdot;
120   PetscInt     i;
121 
122   PetscFunctionBegin;
123   if (xMnorm) PetscCall(VecNormBegin(XM, NORM_2, xMnorm));
124   if (fMnorm) PetscCall(VecNormBegin(FM, NORM_2, fMnorm));
125   if (yMnorm) {
126     PetscCall(VecWAXPY(D, -1.0, XM, X));
127     PetscCall(VecNormBegin(D, NORM_2, yMnorm));
128   }
129   if (xAnorm) PetscCall(VecNormBegin(XA, NORM_2, xAnorm));
130   if (fAnorm) PetscCall(VecNormBegin(FA, NORM_2, fAnorm));
131   if (yAnorm) {
132     PetscCall(VecWAXPY(D, -1.0, XA, X));
133     PetscCall(VecNormBegin(D, NORM_2, yAnorm));
134   }
135   if (dnorm) {
136     PetscCall(VecWAXPY(D, -1.0, XM, XA));
137     PetscCall(VecNormBegin(D, NORM_2, dnorm));
138   }
139   if (dminnorm) {
140     for (i = 0; i < l; i++) {
141       PetscCall(VecWAXPY(D, -1.0, XA, Xdot[i]));
142       PetscCall(VecNormBegin(D, NORM_2, &ngmres->xnorms[i]));
143     }
144   }
145   if (xMnorm) PetscCall(VecNormEnd(XM, NORM_2, xMnorm));
146   if (fMnorm) PetscCall(VecNormEnd(FM, NORM_2, fMnorm));
147   if (yMnorm) PetscCall(VecNormEnd(D, NORM_2, yMnorm));
148   if (xAnorm) PetscCall(VecNormEnd(XA, NORM_2, xAnorm));
149   if (fAnorm) PetscCall(VecNormEnd(FA, NORM_2, fAnorm));
150   if (yAnorm) PetscCall(VecNormEnd(D, NORM_2, yAnorm));
151   if (dnorm) PetscCall(VecNormEnd(D, NORM_2, dnorm));
152   if (dminnorm) {
153     for (i = 0; i < l; i++) {
154       PetscCall(VecNormEnd(D, NORM_2, &ngmres->xnorms[i]));
155       dcurnorm = ngmres->xnorms[i];
156       if ((dcurnorm < dmin) || (dmin < 0.0)) dmin = dcurnorm;
157     }
158     *dminnorm = dmin;
159   }
160   PetscFunctionReturn(PETSC_SUCCESS);
161 }
162 
163 PetscErrorCode SNESNGMRESSelect_Private(SNES snes, PetscInt k_restart, Vec XM, Vec FM, PetscReal xMnorm, PetscReal fMnorm, PetscReal yMnorm, PetscReal objM, Vec XA, Vec FA, PetscReal xAnorm, PetscReal fAnorm, PetscReal yAnorm, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, Vec X, Vec F, Vec Y, PetscReal *xnorm, PetscReal *fnorm, PetscReal *ynorm)
164 {
165   SNES_NGMRES         *ngmres = (SNES_NGMRES *)snes->data;
166   SNESLineSearchReason lssucceed;
167   PetscBool            selectA;
168 
169   PetscFunctionBegin;
170   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
171     /* X = X + \lambda(XA - X) */
172     if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
173     /* Test if is XA - XM is a descent direction: we want < F(XM), XA - XM > not positive
174        If positive, GMRES will be restarted see https://epubs.siam.org/doi/pdf/10.1137/110835530 */
175     PetscCall(VecCopy(FM, F));
176     PetscCall(VecCopy(XM, X));
177     PetscCall(VecWAXPY(Y, -1.0, XA, X));                        /* minus sign since linesearch expects to find Xnew = X - lambda * Y */
178     PetscCall(VecDotRealPart(FM, Y, &ngmres->descent_ls_test)); /* this is actually < F(XM), XM - XA > */
179     *fnorm = fMnorm;
180     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction, select XM */
181       *xnorm = xMnorm;
182       *fnorm = fMnorm;
183       *ynorm = yMnorm;
184       PetscCall(VecWAXPY(Y, -1.0, X, XM));
185       PetscCall(VecCopy(FM, F));
186       PetscCall(VecCopy(XM, X));
187     } else {
188       PetscCall(SNESNGMRESGetAdditiveLineSearch_Private(snes, &ngmres->additive_linesearch));
189       PetscCall(SNESLineSearchApply(ngmres->additive_linesearch, X, F, fnorm, Y));
190       PetscCall(SNESLineSearchGetReason(ngmres->additive_linesearch, &lssucceed));
191       PetscCall(SNESLineSearchGetNorms(ngmres->additive_linesearch, xnorm, fnorm, ynorm));
192       if (lssucceed) {
193         if (++snes->numFailures >= snes->maxFailures) {
194           snes->reason = SNES_DIVERGED_LINE_SEARCH;
195           PetscFunctionReturn(PETSC_SUCCESS);
196         }
197       }
198     }
199     if (ngmres->monitor) {
200       PetscReal        objT = *fnorm;
201       SNESObjectiveFn *objective;
202 
203       PetscCall(SNESGetObjective(snes, &objective, NULL));
204       if (objective) PetscCall(SNESComputeObjective(snes, X, &objT));
205       PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "Additive solution: objective = %e\n", (double)objT));
206     }
207   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
208     /* Conditions for choosing the accelerated answer:
209           Criterion A -- the objective function isn't increased above the minimum by too much
210           Criterion B -- the choice of x^A isn't too close to some other choice
211     */
212     selectA = (PetscBool)(/* A */ (objA < ngmres->gammaA * objmin) && /* B */ (ngmres->epsilonB * dnorm < dminnorm || objA < ngmres->deltaB * objmin));
213 
214     if (selectA) {
215       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_A, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
216       /* copy it over */
217       *xnorm = xAnorm;
218       *fnorm = fAnorm;
219       *ynorm = yAnorm;
220       PetscCall(VecCopy(FA, F));
221       PetscCall(VecCopy(XA, X));
222     } else {
223       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_M, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
224       *xnorm = xMnorm;
225       *fnorm = fMnorm;
226       *ynorm = yMnorm;
227       PetscCall(VecWAXPY(Y, -1.0, X, XM));
228       PetscCall(VecCopy(FM, F));
229       PetscCall(VecCopy(XM, X));
230     }
231   } else { /* none */
232     *xnorm = xAnorm;
233     *fnorm = fAnorm;
234     *ynorm = yAnorm;
235     PetscCall(VecCopy(FA, F));
236     PetscCall(VecCopy(XA, X));
237   }
238   PetscFunctionReturn(PETSC_SUCCESS);
239 }
240 
241 PetscErrorCode SNESNGMRESSelectRestart_Private(SNES snes, PetscInt l, PetscReal obj, PetscReal objM, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, PetscBool *selectRestart)
242 {
243   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
244 
245   PetscFunctionBegin;
246   *selectRestart = PETSC_FALSE;
247   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
248     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction */
249       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "ascent restart: %e > 0\n", (double)-ngmres->descent_ls_test));
250       *selectRestart = PETSC_TRUE;
251     }
252   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
253     /* difference stagnation restart */
254     if (ngmres->epsilonB * dnorm > dminnorm && objA > ngmres->deltaB * objmin && l > 0) {
255       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "difference restart: %e > %e\n", (double)(ngmres->epsilonB * dnorm), (double)dminnorm));
256       *selectRestart = PETSC_TRUE;
257     }
258     /* residual stagnation restart */
259     if (objA > ngmres->gammaC * objmin) {
260       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "residual restart: %e > %e\n", (double)objA, (double)(ngmres->gammaC * objmin)));
261       *selectRestart = PETSC_TRUE;
262     }
263 
264     /* F_M stagnation restart */
265     if (ngmres->restart_fm_rise && objM > obj) {
266       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "F_M rise restart: %e > %e\n", (double)objM, (double)obj));
267       *selectRestart = PETSC_TRUE;
268     }
269   }
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272