xref: /petsc/src/snes/impls/ngmres/ngmresfunc.c (revision f13dfd9ea68e0ddeee984e65c377a1819eab8a8a)
1 #include <../src/snes/impls/ngmres/snesngmres.h> /*I "petscsnes.h" I*/
2 #include <petscblaslapack.h>
3 
4 PetscErrorCode SNESNGMRESGetAdditiveLineSearch_Private(SNES snes, SNESLineSearch *linesearch)
5 {
6   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
7 
8   PetscFunctionBegin;
9   if (!ngmres->additive_linesearch) {
10     const char *optionsprefix;
11     PetscCall(SNESGetOptionsPrefix(snes, &optionsprefix));
12     PetscCall(SNESLineSearchCreate(PetscObjectComm((PetscObject)snes), &ngmres->additive_linesearch));
13     PetscCall(SNESLineSearchSetSNES(ngmres->additive_linesearch, snes));
14     PetscCall(SNESLineSearchSetType(ngmres->additive_linesearch, SNESLINESEARCHL2));
15     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, "snes_ngmres_additive_"));
16     PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, optionsprefix));
17     PetscCall(PetscObjectIncrementTabLevel((PetscObject)ngmres->additive_linesearch, (PetscObject)snes, 1));
18   }
19   *linesearch = ngmres->additive_linesearch;
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
23 PetscErrorCode SNESNGMRESUpdateSubspace_Private(SNES snes, PetscInt ivec, PetscInt l, Vec F, PetscReal fnorm, Vec X)
24 {
25   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
26   Vec         *Fdot   = ngmres->Fdot;
27   Vec         *Xdot   = ngmres->Xdot;
28 
29   PetscFunctionBegin;
30   PetscCheck(ivec <= l, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Cannot update vector %" PetscInt_FMT " with space size %" PetscInt_FMT "!", ivec, l);
31   PetscCall(VecCopy(F, Fdot[ivec]));
32   PetscCall(VecCopy(X, Xdot[ivec]));
33 
34   ngmres->fnorms[ivec] = fnorm;
35   PetscFunctionReturn(PETSC_SUCCESS);
36 }
37 
38 PetscErrorCode SNESNGMRESFormCombinedSolution_Private(SNES snes, PetscInt ivec, PetscInt l, Vec XM, Vec FM, PetscReal fMnorm, Vec X, Vec XA, Vec FA)
39 {
40   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
41   PetscInt     i, j;
42   Vec         *Fdot       = ngmres->Fdot;
43   Vec         *Xdot       = ngmres->Xdot;
44   PetscScalar *beta       = ngmres->beta;
45   PetscScalar *xi         = ngmres->xi;
46   PetscScalar  alph_total = 0.;
47   PetscReal    nu;
48   Vec          Y = snes->vec_sol_update;
49   PetscBool    changed_y, changed_w;
50 
51   PetscFunctionBegin;
52   nu = fMnorm * fMnorm;
53 
54   /* construct the right-hand side and xi factors */
55   if (l > 0) {
56     PetscCall(VecMDotBegin(FM, l, Fdot, xi));
57     PetscCall(VecMDotBegin(Fdot[ivec], l, Fdot, beta));
58     PetscCall(VecMDotEnd(FM, l, Fdot, xi));
59     PetscCall(VecMDotEnd(Fdot[ivec], l, Fdot, beta));
60     for (i = 0; i < l; i++) {
61       Q(i, ivec) = beta[i];
62       Q(ivec, i) = beta[i];
63     }
64   } else {
65     Q(0, 0) = ngmres->fnorms[ivec] * ngmres->fnorms[ivec];
66   }
67 
68   for (i = 0; i < l; i++) beta[i] = nu - xi[i];
69 
70   /* construct h */
71   for (j = 0; j < l; j++) {
72     for (i = 0; i < l; i++) H(i, j) = Q(i, j) - xi[i] - xi[j] + nu;
73   }
74   if (l == 1) {
75     /* simply set alpha[0] = beta[0] / H[0, 0] */
76     if (H(0, 0) != 0.) beta[0] = beta[0] / H(0, 0);
77     else beta[0] = 0.;
78   } else {
79     PetscCall(PetscBLASIntCast(l, &ngmres->m));
80     PetscCall(PetscBLASIntCast(l, &ngmres->n));
81     ngmres->info  = 0;
82     ngmres->rcond = -1.;
83     PetscCall(PetscFPTrapPush(PETSC_FP_TRAP_OFF));
84 #if defined(PETSC_USE_COMPLEX)
85     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, ngmres->rwork, &ngmres->info));
86 #else
87     PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, &ngmres->info));
88 #endif
89     PetscCall(PetscFPTrapPop());
90     PetscCheck(ngmres->info >= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "Bad argument to GELSS");
91     PetscCheck(ngmres->info <= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD failed to converge");
92   }
93   for (i = 0; i < l; i++) PetscCheck(!PetscIsInfOrNanScalar(beta[i]), PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD generated inconsistent output");
94   alph_total = 0.;
95   for (i = 0; i < l; i++) alph_total += beta[i];
96 
97   PetscCall(VecCopy(XM, XA));
98   PetscCall(VecScale(XA, 1. - alph_total));
99   PetscCall(VecMAXPY(XA, l, beta, Xdot));
100   /* check the validity of the step */
101   PetscCall(VecCopy(XA, Y));
102   PetscCall(VecAXPY(Y, -1.0, X));
103   PetscCall(SNESLineSearchPostCheck(snes->linesearch, X, Y, XA, &changed_y, &changed_w));
104   if (!ngmres->approxfunc) {
105     if (snes->npc && snes->npcside == PC_LEFT) {
106       PetscCall(SNESApplyNPC(snes, XA, NULL, FA));
107     } else {
108       PetscCall(SNESComputeFunction(snes, XA, FA));
109     }
110   } else {
111     PetscCall(VecCopy(FM, FA));
112     PetscCall(VecScale(FA, 1. - alph_total));
113     PetscCall(VecMAXPY(FA, l, beta, Fdot));
114   }
115   PetscFunctionReturn(PETSC_SUCCESS);
116 }
117 
118 PetscErrorCode SNESNGMRESNorms_Private(SNES snes, PetscInt l, Vec X, Vec F, Vec XM, Vec FM, Vec XA, Vec FA, Vec D, PetscReal *dnorm, PetscReal *dminnorm, PetscReal *xMnorm, PetscReal *fMnorm, PetscReal *yMnorm, PetscReal *xAnorm, PetscReal *fAnorm, PetscReal *yAnorm)
119 {
120   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
121   PetscReal    dcurnorm, dmin = -1.0;
122   Vec         *Xdot = ngmres->Xdot;
123   PetscInt     i;
124 
125   PetscFunctionBegin;
126   if (xMnorm) PetscCall(VecNormBegin(XM, NORM_2, xMnorm));
127   if (fMnorm) PetscCall(VecNormBegin(FM, NORM_2, fMnorm));
128   if (yMnorm) {
129     PetscCall(VecCopy(X, D));
130     PetscCall(VecAXPY(D, -1.0, XM));
131     PetscCall(VecNormBegin(D, NORM_2, yMnorm));
132   }
133   if (xAnorm) PetscCall(VecNormBegin(XA, NORM_2, xAnorm));
134   if (fAnorm) PetscCall(VecNormBegin(FA, NORM_2, fAnorm));
135   if (yAnorm) {
136     PetscCall(VecCopy(X, D));
137     PetscCall(VecAXPY(D, -1.0, XA));
138     PetscCall(VecNormBegin(D, NORM_2, yAnorm));
139   }
140   if (dnorm) {
141     PetscCall(VecCopy(XA, D));
142     PetscCall(VecAXPY(D, -1.0, XM));
143     PetscCall(VecNormBegin(D, NORM_2, dnorm));
144   }
145   if (dminnorm) {
146     for (i = 0; i < l; i++) {
147       PetscCall(VecCopy(Xdot[i], D));
148       PetscCall(VecAXPY(D, -1.0, XA));
149       PetscCall(VecNormBegin(D, NORM_2, &ngmres->xnorms[i]));
150     }
151   }
152   if (xMnorm) PetscCall(VecNormEnd(XM, NORM_2, xMnorm));
153   if (fMnorm) PetscCall(VecNormEnd(FM, NORM_2, fMnorm));
154   if (yMnorm) PetscCall(VecNormEnd(D, NORM_2, yMnorm));
155   if (xAnorm) PetscCall(VecNormEnd(XA, NORM_2, xAnorm));
156   if (fAnorm) PetscCall(VecNormEnd(FA, NORM_2, fAnorm));
157   if (yAnorm) PetscCall(VecNormEnd(D, NORM_2, yAnorm));
158   if (dnorm) PetscCall(VecNormEnd(D, NORM_2, dnorm));
159   if (dminnorm) {
160     for (i = 0; i < l; i++) {
161       PetscCall(VecNormEnd(D, NORM_2, &ngmres->xnorms[i]));
162       dcurnorm = ngmres->xnorms[i];
163       if ((dcurnorm < dmin) || (dmin < 0.0)) dmin = dcurnorm;
164     }
165     *dminnorm = dmin;
166   }
167   PetscFunctionReturn(PETSC_SUCCESS);
168 }
169 
170 PetscErrorCode SNESNGMRESSelect_Private(SNES snes, PetscInt k_restart, Vec XM, Vec FM, PetscReal xMnorm, PetscReal fMnorm, PetscReal yMnorm, PetscReal objM, Vec XA, Vec FA, PetscReal xAnorm, PetscReal fAnorm, PetscReal yAnorm, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, Vec X, Vec F, Vec Y, PetscReal *xnorm, PetscReal *fnorm, PetscReal *ynorm)
171 {
172   SNES_NGMRES         *ngmres = (SNES_NGMRES *)snes->data;
173   SNESLineSearchReason lssucceed;
174   PetscBool            selectA;
175 
176   PetscFunctionBegin;
177   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
178     /* X = X + \lambda(XA - X) */
179     if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
180     /* Test if is XA - XM is a descent direction: we want < F(XM), XA - XM > not positive
181        If positive, GMRES will be restarted see https://epubs.siam.org/doi/pdf/10.1137/110835530 */
182     PetscCall(VecCopy(FM, F));
183     PetscCall(VecCopy(XM, X));
184     PetscCall(VecWAXPY(Y, -1.0, XA, X));                        /* minus sign since linesearch expects to find Xnew = X - lambda * Y */
185     PetscCall(VecDotRealPart(FM, Y, &ngmres->descent_ls_test)); /* this is actually < F(XM), XM - XA > */
186     *fnorm = fMnorm;
187     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction, select XM */
188       *xnorm = xMnorm;
189       *fnorm = fMnorm;
190       *ynorm = yMnorm;
191       PetscCall(VecWAXPY(Y, -1.0, X, XM));
192       PetscCall(VecCopy(FM, F));
193       PetscCall(VecCopy(XM, X));
194     } else {
195       PetscCall(SNESNGMRESGetAdditiveLineSearch_Private(snes, &ngmres->additive_linesearch));
196       PetscCall(SNESLineSearchApply(ngmres->additive_linesearch, X, F, fnorm, Y));
197       PetscCall(SNESLineSearchGetReason(ngmres->additive_linesearch, &lssucceed));
198       PetscCall(SNESLineSearchGetNorms(ngmres->additive_linesearch, xnorm, fnorm, ynorm));
199       if (lssucceed) {
200         if (++snes->numFailures >= snes->maxFailures) {
201           snes->reason = SNES_DIVERGED_LINE_SEARCH;
202           PetscFunctionReturn(PETSC_SUCCESS);
203         }
204       }
205     }
206     if (ngmres->monitor) {
207       PetscReal objT = *fnorm;
208       PetscErrorCode (*objective)(SNES, Vec, PetscReal *, void *);
209 
210       PetscCall(SNESGetObjective(snes, &objective, NULL));
211       if (objective) PetscCall(SNESComputeObjective(snes, X, &objT));
212       PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "Additive solution: objective = %e\n", (double)objT));
213     }
214   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
215     /* Conditions for choosing the accelerated answer:
216           Criterion A -- the objective function isn't increased above the minimum by too much
217           Criterion B -- the choice of x^A isn't too close to some other choice
218     */
219     selectA = (PetscBool)(/* A */ (objA < ngmres->gammaA * objmin) && /* B */ (ngmres->epsilonB * dnorm < dminnorm || objA < ngmres->deltaB * objmin));
220 
221     if (selectA) {
222       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_A, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
223       /* copy it over */
224       *xnorm = xAnorm;
225       *fnorm = fAnorm;
226       *ynorm = yAnorm;
227       PetscCall(VecCopy(FA, F));
228       PetscCall(VecCopy(XA, X));
229     } else {
230       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_M, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
231       *xnorm = xMnorm;
232       *fnorm = fMnorm;
233       *ynorm = yMnorm;
234       PetscCall(VecWAXPY(Y, -1.0, X, XM));
235       PetscCall(VecCopy(FM, F));
236       PetscCall(VecCopy(XM, X));
237     }
238   } else { /* none */
239     *xnorm = xAnorm;
240     *fnorm = fAnorm;
241     *ynorm = yAnorm;
242     PetscCall(VecCopy(FA, F));
243     PetscCall(VecCopy(XA, X));
244   }
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
248 PetscErrorCode SNESNGMRESSelectRestart_Private(SNES snes, PetscInt l, PetscReal obj, PetscReal objM, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, PetscBool *selectRestart)
249 {
250   SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
251 
252   PetscFunctionBegin;
253   *selectRestart = PETSC_FALSE;
254   if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
255     if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction */
256       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "ascent restart: %e > 0\n", (double)-ngmres->descent_ls_test));
257       *selectRestart = PETSC_TRUE;
258     }
259   } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
260     /* difference stagnation restart */
261     if (ngmres->epsilonB * dnorm > dminnorm && objA > ngmres->deltaB * objmin && l > 0) {
262       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "difference restart: %e > %e\n", (double)(ngmres->epsilonB * dnorm), (double)dminnorm));
263       *selectRestart = PETSC_TRUE;
264     }
265     /* residual stagnation restart */
266     if (objA > ngmres->gammaC * objmin) {
267       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "residual restart: %e > %e\n", (double)objA, (double)(ngmres->gammaC * objmin)));
268       *selectRestart = PETSC_TRUE;
269     }
270 
271     /* F_M stagnation restart */
272     if (ngmres->restart_fm_rise && objM > obj) {
273       if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "F_M rise restart: %e > %e\n", (double)objM, (double)obj));
274       *selectRestart = PETSC_TRUE;
275     }
276   }
277   PetscFunctionReturn(PETSC_SUCCESS);
278 }
279