xref: /petsc/src/ksp/ksp/utils/lmvm/sr1/sr1.c (revision 8577b683712d1cca1e9b8fdaa9ae028364224dad)
1 #include <../src/ksp/ksp/utils/lmvm/lmvm.h> /*I "petscksp.h" I*/
2 
3 /*
4   Limited-memory Symmetric-Rank-1 method for approximating both
5   the forward product and inverse application of a Jacobian.
6 */
7 
8 // bases needed by SR1 algorithms beyond those in Mat_LMVM
9 enum {
10   SR1_BASIS_Y_MINUS_BKS = 0, // Y_k - B_k S_k for recursive algorithms
11   SR1_BASIS_S_MINUS_HKY = 1, // dual to the above, S_k - H_k Y_k
12   SR1_BASIS_COUNT
13 };
14 
15 typedef PetscInt SR1BasisType;
16 
17 // products needed by SR1 algorithms beyond those in Mat_LMVM
18 enum {
19   SR1_PRODUCTS_YTS_MINUS_STB0S = 0, // stores and factors symm(triu((Y - B_0 S)^T S)) for compact algorithms
20   SR1_PRODUCTS_STY_MINUS_YTH0Y = 1, // dual to the above, stores and factors symm(triu((S - H_0 Y)^T Y))
21   SR1_PRODUCTS_YTS_MINUS_STBKS = 2, // diagonal (Y_k - B_k S_k)^T S_k values for recursive algorithms
22   SR1_PRODUCTS_STY_MINUS_YTHKY = 3, // dual to the above, diagonal (S_k - H_k Y_k)^T Y_k
23   SR1_PRODUCTS_COUNT
24 };
25 
26 typedef PetscInt SR1ProductsType;
27 
28 typedef struct {
29   LMBasis    basis[SR1_BASIS_COUNT];
30   LMProducts products[SR1_PRODUCTS_COUNT];
31   Vec        StFprev, SmH0YtFprev;
32 } Mat_LSR1;
33 
34 /* The SR1 kernel can be written as
35 
36      B_{k+1} = B_k + (y_k - B_k s_k) ((y_k - B_k s_k)^T s_k)^{-1} (y_k - B_k s_k)^T
37 
38    this unrolls to a rank-m update
39 
40      B_{k+1} = B_0 + \sum_{i = k-m+1}^k (y_i - B_i s_i) ((y_i - B_i s_i)^T s_i)^{-1} (y_i - B_i s_i)^T
41 
42    This inner kernel assumes the (y_i - B_i s_i) vectors and the ((y_i - B_i s_i)^T s_i) products have been computed
43  */
44 
SR1Kernel_Recursive_Inner(Mat B,MatLMVMMode mode,PetscInt oldest,PetscInt next,Vec X,Vec BX)45 static PetscErrorCode SR1Kernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec BX)
46 {
47   Mat_LMVM       *lmvm              = (Mat_LMVM *)B->data;
48   Mat_LSR1       *lsr1              = (Mat_LSR1 *)lmvm->ctx;
49   SR1BasisType    Y_minus_BkS_t     = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
50   SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
51   LMBasis         Y_minus_BkS       = lsr1->basis[Y_minus_BkS_t];
52   LMProducts      YtS_minus_StBkS   = lsr1->products[YtS_minus_StBkS_t];
53   Vec             YmBkStX;
54 
55   PetscFunctionBegin;
56   PetscCall(MatLMVMGetWorkRow(B, &YmBkStX));
57   PetscCall(LMBasisGEMVH(Y_minus_BkS, oldest, next, 1.0, X, 0.0, YmBkStX));
58   PetscCall(LMProductsSolve(YtS_minus_StBkS, oldest, next, YmBkStX, YmBkStX, /* ^H */ PETSC_FALSE));
59   PetscCall(LMBasisGEMV(Y_minus_BkS, oldest, next, 1.0, YmBkStX, 1.0, BX));
60   PetscCall(MatLMVMRestoreWorkRow(B, &YmBkStX));
61   PetscFunctionReturn(PETSC_SUCCESS);
62 }
63 
64 /* Recursively compute the (y_i - B_i s_i) vectors and ((y_i - B_i s_i)^T s_i) products */
65 
SR1RecursiveBasisUpdate(Mat B,MatLMVMMode mode)66 static PetscErrorCode SR1RecursiveBasisUpdate(Mat B, MatLMVMMode mode)
67 {
68   Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
69   Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
70   MatLMVMBasisType B0S_t             = LMVMModeMap(LMBASIS_B0S, mode);
71   MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
72   MatLMVMBasisType Y_t               = LMVMModeMap(LMBASIS_Y, mode);
73   SR1BasisType     Y_minus_BkS_t     = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
74   SR1ProductsType  YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
75   LMBasis          Y_minus_BkS;
76   LMProducts       YtS_minus_StBkS;
77   PetscInt         oldest, next;
78   PetscInt         products_oldest;
79   LMBasis          S, Y;
80   PetscInt         start;
81 
82   PetscFunctionBegin;
83   if (!lsr1->basis[Y_minus_BkS_t]) PetscCall(LMBasisCreate(mode == MATLMVM_MODE_PRIMAL ? lmvm->Fprev : lmvm->Xprev, lmvm->m, &lsr1->basis[Y_minus_BkS_t]));
84   Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
85   if (!lsr1->products[YtS_minus_StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsr1->products[YtS_minus_StBkS_t]));
86   YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
87   PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
88   PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
89   PetscCall(MatLMVMGetRange(B, &oldest, &next));
90   // invalidate computed values if J0 has changed
91   PetscCall(LMProductsPrepare(YtS_minus_StBkS, lmvm->J0, oldest, next));
92   products_oldest = PetscMax(0, YtS_minus_StBkS->k - lmvm->m);
93   if (oldest > products_oldest) {
94     // recursion is starting from a different starting index, it must be recomputed
95     YtS_minus_StBkS->k = oldest;
96   }
97   Y_minus_BkS->k = start = YtS_minus_StBkS->k;
98   // recompute each column in Y_minus_BkS in order
99   for (PetscInt j = start; j < next; j++) {
100     Vec         s_j, B0s_j, p_j, y_j;
101     PetscScalar alpha, ymbksts;
102 
103     PetscCall(LMBasisGetWorkVec(Y_minus_BkS, &p_j));
104 
105     // p_j starts as B_0 * s_j
106     PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
107     PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
108     PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));
109 
110     // Use the matmult kernel to compute p_j = B_j * p_j
111     PetscCall(LMBasisGetVecRead(S, j, &s_j));
112     // if j == oldest p_j is already correct
113     if (j > oldest) PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
114     PetscCall(LMBasisGetVecRead(Y, j, &y_j));
115     PetscCall(VecAYPX(p_j, -1.0, y_j));
116     PetscCall(VecDot(s_j, p_j, &ymbksts));
117     PetscCall(LMProductsInsertNextDiagonalValue(YtS_minus_StBkS, j, ymbksts));
118     PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
119     PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
120     PetscCall(LMBasisSetNextVec(Y_minus_BkS, p_j));
121     PetscCall(LMBasisRestoreWorkVec(Y_minus_BkS, &p_j));
122   }
123   PetscFunctionReturn(PETSC_SUCCESS);
124 }
125 
SR1Kernel_Recursive(Mat B,MatLMVMMode mode,Vec X,Vec BX)126 static PetscErrorCode SR1Kernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec BX)
127 {
128   PetscInt oldest, next;
129 
130   PetscFunctionBegin;
131   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
132   PetscCall(MatLMVMGetRange(B, &oldest, &next));
133   if (next > oldest) {
134     PetscCall(SR1RecursiveBasisUpdate(B, mode));
135     PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, next, X, BX));
136   }
137   PetscFunctionReturn(PETSC_SUCCESS);
138 }
139 
/* The SR1 kernel can be written as (See Byrd, Schnabel & Nocedal 1994)

     B_{k+1} = B_0 + (Y - B_0 S) (diag(S^T Y) + stril(S^T Y) + stril(S^T Y)^T - S^T B_0 S)^{-1} (Y - B_0 S)^T
                                 \___________________________ ___________________________/
                                                             V
                                                             M

   M is symmetric indefinite (stril is the strictly lower triangular part)

   M can be computed by computing triu((Y - B_0 S)^T S) and filling in the lower triangle
 */
151 
SR1CompactProductsUpdate(Mat B,MatLMVMMode mode)152 static PetscErrorCode SR1CompactProductsUpdate(Mat B, MatLMVMMode mode)
153 {
154   Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
155   Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
156   MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
157   MatLMVMBasisType YmB0S_t           = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
158   SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
159   LMProducts       YtS_minus_StB0S;
160   Mat              local;
161   PetscInt         oldest, next, k;
162   PetscBool        local_is_nonempty;
163 
164   PetscFunctionBegin;
165   if (!lsr1->products[YtS_minus_StB0S_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsr1->products[YtS_minus_StB0S_t]));
166   YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
167   PetscCall(MatLMVMGetRange(B, &oldest, &next));
168   PetscCall(LMProductsPrepare(YtS_minus_StB0S, lmvm->J0, oldest, next));
169   PetscCall(LMProductsGetLocalMatrix(YtS_minus_StB0S, &local, &k, &local_is_nonempty));
170   if (YtS_minus_StB0S->k < next) {
171     // copy to factor in place
172     LMProducts YmB0StS;
173     Mat        ymb0sts_local;
174 
175     PetscCall(PetscCitationsRegister(ByrdNocedalSchnabelCitation, &ByrdNocedalSchnabelCite));
176     YtS_minus_StB0S->k = next;
177     PetscCall(MatLMVMGetUpdatedProducts(B, YmB0S_t, S_t, LMBLOCK_UPPER_TRIANGLE, &YmB0StS));
178     PetscCall(LMProductsGetLocalMatrix(YmB0StS, &ymb0sts_local, NULL, NULL));
179     if (local_is_nonempty) {
180       PetscErrorCode ierr;
181 
182       PetscCall(MatSetUnfactored(local));
183       PetscCall(MatCopy(ymb0sts_local, local, SAME_NONZERO_PATTERN));
184       PetscCall(LMProductsMakeHermitian(local, oldest, next));
185       PetscCall(LMProductsOnesOnUnusedDiagonal(local, oldest, next));
186       PetscCall(MatSetOption(local, MAT_HERMITIAN, PETSC_TRUE));
187       // Set not spd so that "Cholesky" factorization is actually the symmetric indefinite Bunch Kaufman factorization
188       PetscCall(MatSetOption(local, MAT_SPD, PETSC_FALSE));
189 
190       PetscCall(PetscPushErrorHandler(PetscReturnErrorHandler, NULL));
191       ierr = MatCholeskyFactor(local, NULL, NULL);
192       PetscCall(PetscPopErrorHandler());
193       PetscCheck(ierr == PETSC_SUCCESS || ierr == PETSC_ERR_SUP, PETSC_COMM_SELF, ierr, "Error in Bunch-Kaufman factorization");
194       // cusolver does not provide Bunch Kaufman, resort to LU if it is unavailable
195       if (ierr == PETSC_ERR_SUP) PetscCall(MatLUFactor(local, NULL, NULL, NULL));
196     }
197     PetscCall(LMProductsRestoreLocalMatrix(YmB0StS, &ymb0sts_local, NULL));
198   }
199   PetscCall(LMProductsRestoreLocalMatrix(YtS_minus_StB0S, &local, &next));
200   PetscFunctionReturn(PETSC_SUCCESS);
201 }
202 
SR1Kernel_CompactDense(Mat B,MatLMVMMode mode,Vec X,Vec BX)203 static PetscErrorCode SR1Kernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
204 {
205   PetscInt oldest, next;
206 
207   PetscFunctionBegin;
208   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
209   PetscCall(MatLMVMGetRange(B, &oldest, &next));
210   if (next > oldest) {
211     Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
212     Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
213     MatLMVMBasisType Y_minus_B0S_t     = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
214     SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
215     LMProducts       YtS_minus_StB0S;
216     Vec              YmB0StX, v;
217 
218     PetscCall(SR1CompactProductsUpdate(B, mode));
219     YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
220     PetscCall(MatLMVMGetWorkRow(B, &YmB0StX));
221     PetscCall(MatLMVMGetWorkRow(B, &v));
222     if (lmvm->do_not_cache_J0_products) {
223       /* the initial (Y - B_0 S)^T x inner product can be computed as Y^T x - S^T (B_0 x)
224          if we are not caching B_0 S products */
225       MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
226       MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
227       LMBasis          S, Y;
228 
229       PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
230       PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
231       PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YmB0StX));
232       PetscCall(LMBasisGEMVH(S, oldest, next, -1.0, BX, 1.0, YmB0StX));
233     } else PetscCall(MatLMVMBasisGEMVH(B, Y_minus_B0S_t, oldest, next, 1.0, X, 0.0, YmB0StX));
234     PetscCall(LMProductsSolve(YtS_minus_StB0S, oldest, next, YmB0StX, v, PETSC_FALSE));
235     PetscCall(MatLMVMBasisGEMV(B, Y_minus_B0S_t, oldest, next, 1.0, v, 1.0, BX));
236     PetscCall(MatLMVMRestoreWorkRow(B, &v));
237     PetscCall(MatLMVMRestoreWorkRow(B, &YmB0StX));
238   }
239   PetscFunctionReturn(PETSC_SUCCESS);
240 }
241 
MatMult_LMVMSR1_CompactDense(Mat B,Vec X,Vec BX)242 static PetscErrorCode MatMult_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
243 {
244   PetscFunctionBegin;
245   PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, BX));
246   PetscFunctionReturn(PETSC_SUCCESS);
247 }
248 
MatSolve_LMVMSR1_CompactDense(Mat B,Vec X,Vec BX)249 static PetscErrorCode MatSolve_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
250 {
251   PetscFunctionBegin;
252   PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_DUAL, X, BX));
253   PetscFunctionReturn(PETSC_SUCCESS);
254 }
255 
MatMult_LMVMSR1_Recursive(Mat B,Vec X,Vec Z)256 static PetscErrorCode MatMult_LMVMSR1_Recursive(Mat B, Vec X, Vec Z)
257 {
258   PetscFunctionBegin;
259   PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Z));
260   PetscFunctionReturn(PETSC_SUCCESS);
261 }
262 
MatSolve_LMVMSR1_Recursive(Mat B,Vec F,Vec dX)263 static PetscErrorCode MatSolve_LMVMSR1_Recursive(Mat B, Vec F, Vec dX)
264 {
265   PetscFunctionBegin;
266   PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_DUAL, F, dX));
267   PetscFunctionReturn(PETSC_SUCCESS);
268 }
269 
MatUpdate_LMVMSR1(Mat B,Vec X,Vec F)270 static PetscErrorCode MatUpdate_LMVMSR1(Mat B, Vec X, Vec F)
271 {
272   Mat_LMVM *lmvm          = (Mat_LMVM *)B->data;
273   Mat_LSR1 *sr1           = (Mat_LSR1 *)lmvm->ctx;
274   PetscBool cache_SmH0YtF = (lmvm->mult_alg != MAT_LMVM_MULT_RECURSIVE && !lmvm->do_not_cache_J0_products) ? lmvm->cache_gradient_products : PETSC_FALSE;
275 
276   PetscFunctionBegin;
277   if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
278   if (lmvm->prev_set) {
279     PetscReal   snorm, pnorm;
280     PetscScalar sktw;
281     Vec         work;
282     Vec         Fprev_old       = NULL;
283     Vec         SmH0YtFprev_old = NULL;
284     LMProducts  SmH0YtY         = NULL;
285     PetscInt    oldest, next;
286     LMBasis     SmH0Y = NULL;
287     LMBasis     Y;
288 
289     PetscCall(MatLMVMGetRange(B, &oldest, &next));
290     if (cache_SmH0YtF) {
291       PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
292       if (!sr1->SmH0YtFprev) PetscCall(LMBasisCreateRow(SmH0Y, &sr1->SmH0YtFprev));
293       PetscCall(LMBasisGetWorkVec(SmH0Y, &Fprev_old));
294       PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_S_MINUS_H0Y, LMBASIS_Y, LMBLOCK_UPPER_TRIANGLE, &SmH0YtY));
295       PetscCall(LMProductsGetNextColumn(SmH0YtY, &SmH0YtFprev_old));
296       PetscCall(VecCopy(lmvm->Fprev, Fprev_old));
297       if (sr1->SmH0YtFprev == SmH0Y->cached_product) {
298         PetscCall(VecCopy(sr1->SmH0YtFprev, SmH0YtFprev_old));
299       } else {
300         if (next > oldest) {
301           // need to recalculate
302           PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
303         } else {
304           PetscCall(VecZeroEntries(SmH0YtFprev_old));
305         }
306       }
307     }
308 
309     /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
310     PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
311     PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));
312 
313     /* See if the updates can be accepted
314        NOTE: This tests abs(S[k]^T (Y[k] - B_k*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_k*S[k])
315 
316        Note that this test is flawed because this is a **limited memory** SR1 method: we are testing
317 
318          abs(S[k]^T (Y[k] - B_{k,m}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m}*S[k])
319 
320        when the oldest pair of vectors in the definition of B_{k,m}, (s_{k-m}, y_{k-m}), will be dropped if we add a new
321        pair.  To really ensure that B_{k+1} = B_{k+1,m} is nonsingular, you need to test
322 
323          abs(S[k]^T (Y[k] - B_{k,m-1}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m-1}*S[k])
324 
325        But the product B_{k,m-1}*S[k] is not readily computable (see e.g. Lu, Xuehua, "A study of the limited memory SR1
326        method in practice", 1996).
327      */
328     PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_Y, &Y, NULL, NULL));
329     PetscCall(LMBasisGetWorkVec(Y, &work));
330     PetscCall(MatMult(B, lmvm->Xprev, work));
331     PetscCall(VecAYPX(work, -1.0, lmvm->Fprev));
332     PetscCall(VecDot(lmvm->Xprev, work, &sktw));
333     PetscCall(VecNorm(lmvm->Xprev, NORM_2, &snorm));
334     PetscCall(VecNorm(work, NORM_2, &pnorm));
335     PetscCall(LMBasisRestoreWorkVec(Y, &work));
336     if (PetscAbsReal(PetscRealPart(sktw)) >= lmvm->eps * snorm * pnorm) {
337       /* Update is good, accept it */
338       PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
339       if (cache_SmH0YtF) {
340         PetscInt oldest_new, next_new;
341 
342         PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
343         PetscCall(MatLMVMGetRange(B, &oldest_new, &next_new));
344         PetscCall(LMBasisGEMVH(SmH0Y, next, next_new, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
345         PetscCall(LMBasisGEMVH(SmH0Y, oldest_new, next_new, 1.0, F, 0.0, sr1->SmH0YtFprev));
346         PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
347         PetscCall(VecAXPBY(SmH0YtFprev_old, 1.0, -1.0, sr1->SmH0YtFprev));
348         PetscCall(LMProductsRestoreNextColumn(SmH0YtY, &SmH0YtFprev_old));
349       }
350     } else {
351       /* Update is bad, skip it */
352       lmvm->nrejects++;
353       if (cache_SmH0YtF) {
354         // we still need to update the cached product
355         PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, F, 0.0, sr1->SmH0YtFprev));
356         PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
357       }
358     }
359     if (cache_SmH0YtF) PetscCall(LMBasisRestoreWorkVec(SmH0Y, &Fprev_old));
360   }
361   /* Save the solution and function to be used in the next update */
362   PetscCall(VecCopy(X, lmvm->Xprev));
363   PetscCall(VecCopy(F, lmvm->Fprev));
364   lmvm->prev_set = PETSC_TRUE;
365   PetscFunctionReturn(PETSC_SUCCESS);
366 }
367 
MatReset_LMVMSR1(Mat B,MatLMVMResetMode mode)368 static PetscErrorCode MatReset_LMVMSR1(Mat B, MatLMVMResetMode mode)
369 {
370   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
371   Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
372 
373   PetscFunctionBegin;
374   if (MatLMVMResetClearsBases(mode)) {
375     for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisDestroy(&lsr1->basis[i]));
376     for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsDestroy(&lsr1->products[i]));
377     PetscCall(VecDestroy(&lsr1->StFprev));
378     PetscCall(VecDestroy(&lsr1->SmH0YtFprev));
379   } else {
380     for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisReset(lsr1->basis[i]));
381     for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsReset(lsr1->products[i]));
382     if (lsr1->StFprev) PetscCall(VecZeroEntries(lsr1->StFprev));
383     if (lsr1->SmH0YtFprev) PetscCall(VecZeroEntries(lsr1->SmH0YtFprev));
384   }
385   PetscFunctionReturn(PETSC_SUCCESS);
386 }
387 
MatDestroy_LMVMSR1(Mat B)388 static PetscErrorCode MatDestroy_LMVMSR1(Mat B)
389 {
390   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
391 
392   PetscFunctionBegin;
393   PetscCall(MatReset_LMVMSR1(B, MAT_LMVM_RESET_ALL));
394   PetscCall(PetscFree(lmvm->ctx));
395   PetscCall(MatDestroy_LMVM(B));
396   PetscFunctionReturn(PETSC_SUCCESS);
397 }
398 
MatLMVMSetMultAlgorithm_SR1(Mat B)399 static PetscErrorCode MatLMVMSetMultAlgorithm_SR1(Mat B)
400 {
401   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
402 
403   PetscFunctionBegin;
404   switch (lmvm->mult_alg) {
405   case MAT_LMVM_MULT_RECURSIVE:
406     lmvm->ops->mult  = MatMult_LMVMSR1_Recursive;
407     lmvm->ops->solve = MatSolve_LMVMSR1_Recursive;
408     break;
409   case MAT_LMVM_MULT_DENSE:
410   case MAT_LMVM_MULT_COMPACT_DENSE:
411     lmvm->ops->mult  = MatMult_LMVMSR1_CompactDense;
412     lmvm->ops->solve = MatSolve_LMVMSR1_CompactDense;
413     break;
414   }
415   lmvm->ops->multht  = lmvm->ops->mult;
416   lmvm->ops->solveht = lmvm->ops->solve;
417   PetscFunctionReturn(PETSC_SUCCESS);
418 }
419 
MatCreate_LMVMSR1(Mat B)420 PetscErrorCode MatCreate_LMVMSR1(Mat B)
421 {
422   Mat_LMVM *lmvm;
423   Mat_LSR1 *lsr1;
424 
425   PetscFunctionBegin;
426   PetscCall(MatCreate_LMVM(B));
427   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMSR1));
428   PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
429   B->ops->destroy = MatDestroy_LMVMSR1;
430 
431   lmvm                          = (Mat_LMVM *)B->data;
432   lmvm->ops->reset              = MatReset_LMVMSR1;
433   lmvm->ops->update             = MatUpdate_LMVMSR1;
434   lmvm->ops->setmultalgorithm   = MatLMVMSetMultAlgorithm_SR1;
435   lmvm->cache_gradient_products = PETSC_TRUE;
436   PetscCall(MatLMVMSetMultAlgorithm_SR1(B));
437   PetscCall(PetscNew(&lsr1));
438   lmvm->ctx = (void *)lsr1;
439   PetscFunctionReturn(PETSC_SUCCESS);
440 }
441 
442 /*@
443   MatCreateLMVMSR1 - Creates a limited-memory Symmetric-Rank-1 approximation
444   matrix used for a Jacobian. L-SR1 is symmetric by construction, but is not
445   guaranteed to be positive-definite.
446 
447   To use the L-SR1 matrix with other vector types, the matrix must be
448   created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
449   This ensures that the internal storage and work vectors are duplicated from the
450   correct type of vector.
451 
452   Collective
453 
454   Input Parameters:
455 + comm - MPI communicator
456 . n    - number of local rows for storage vectors
457 - N    - global size of the storage vectors
458 
459   Output Parameter:
460 . B - the matrix
461 
462   Options Database Keys:
463 + -mat_lmvm_hist_size         - the number of history vectors to keep
464 . -mat_lmvm_mult_algorithm    - the algorithm to use for multiplication (recursive, dense, compact_dense)
465 . -mat_lmvm_cache_J0_products - whether products between the base Jacobian J0 and history vectors should be cached or recomputed
466 . -mat_lmvm_eps               - (developer) numerical zero tolerance for testing when an update should be skipped
467 - -mat_lmvm_debug             - (developer) perform internal debugging checks
468 
469   Level: intermediate
470 
471   Note:
472   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
473   paradigm instead of this routine directly.
474 
475 .seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMSR1`, `MatCreateLMVMBFGS()`, `MatCreateLMVMDFP()`,
476           `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
477 @*/
MatCreateLMVMSR1(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)478 PetscErrorCode MatCreateLMVMSR1(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
479 {
480   PetscFunctionBegin;
481   PetscCall(KSPInitializePackage());
482   PetscCall(MatCreate(comm, B));
483   PetscCall(MatSetSizes(*B, n, n, N, N));
484   PetscCall(MatSetType(*B, MATLMVMSR1));
485   PetscCall(MatSetUp(*B));
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488