1 #include <../src/ksp/ksp/utils/lmvm/lmvm.h> /*I "petscksp.h" I*/
2
3 /*
4 Limited-memory Symmetric-Rank-1 method for approximating both
5 the forward product and inverse application of a Jacobian.
6 */
7
// Bases needed by SR1 algorithms beyond those in Mat_LMVM.
// Entries come in primal/dual pairs (the odd entry is the dual of the preceding even one); code selects the
// entry for the current MatLMVMMode with LMVMModeMap(..., mode).
enum {
  SR1_BASIS_Y_MINUS_BKS = 0, // Y_k - B_k S_k for recursive algorithms
  SR1_BASIS_S_MINUS_HKY = 1, // dual to the above, S_k - H_k Y_k
  SR1_BASIS_COUNT            // number of entries; sizes Mat_LSR1::basis[]
};

// Index type for Mat_LSR1::basis[] (values from the enum above)
typedef PetscInt SR1BasisType;
16
// Products needed by SR1 algorithms beyond those in Mat_LMVM.
// Entries come in primal/dual pairs (the odd entry is the dual of the preceding even one); code selects
// entries with LMVMModeMap(..., mode).
enum {
  SR1_PRODUCTS_YTS_MINUS_STB0S = 0, // stores and factors symm(triu((Y - B_0 S)^T S)) for compact algorithms
  SR1_PRODUCTS_STY_MINUS_YTH0Y = 1, // dual to the above, stores and factors symm(triu((S - H_0 Y)^T Y))
  SR1_PRODUCTS_YTS_MINUS_STBKS = 2, // diagonal (Y_k - B_k S_k)^T S_k values for recursive algorithms
  SR1_PRODUCTS_STY_MINUS_YTHKY = 3, // dual to the above, diagonal (S_k - H_k Y_k)^T Y_k
  SR1_PRODUCTS_COUNT                // number of entries; sizes Mat_LSR1::products[]
};

// Index type for Mat_LSR1::products[] (values from the enum above)
typedef PetscInt SR1ProductsType;
27
// SR1-specific context stored in Mat_LMVM::ctx (allocated in MatCreate_LMVMSR1())
typedef struct {
  LMBasis    basis[SR1_BASIS_COUNT];       // SR1 bases indexed by SR1BasisType; NULL until first needed
  LMProducts products[SR1_PRODUCTS_COUNT]; // SR1 products indexed by SR1ProductsType; NULL until first needed
  Vec        StFprev, SmH0YtFprev;         // cached rows: SmH0YtFprev holds (S - H_0 Y)^T Fprev (see MatUpdate_LMVMSR1());
                                           // StFprev is only zeroed/destroyed in this file
} Mat_LSR1;
33
34 /* The SR1 kernel can be written as
35
36 B_{k+1} = B_k + (y_k - B_k s_k) ((y_k - B_k s_k)^T s_k)^{-1} (y_k - B_k s_k)^T
37
38 this unrolls to a rank-m update
39
40 B_{k+1} = B_0 + \sum_{i = k-m+1}^k (y_i - B_i s_i) ((y_i - B_i s_i)^T s_i)^{-1} (y_i - B_i s_i)^T
41
This inner kernel assumes the (y_i - B_i s_i) vectors and the ((y_i - B_i s_i)^T s_i) products have already been computed and stored
43 */
44
SR1Kernel_Recursive_Inner(Mat B,MatLMVMMode mode,PetscInt oldest,PetscInt next,Vec X,Vec BX)45 static PetscErrorCode SR1Kernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec BX)
46 {
47 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
48 Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
49 SR1BasisType Y_minus_BkS_t = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
50 SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
51 LMBasis Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
52 LMProducts YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
53 Vec YmBkStX;
54
55 PetscFunctionBegin;
56 PetscCall(MatLMVMGetWorkRow(B, &YmBkStX));
57 PetscCall(LMBasisGEMVH(Y_minus_BkS, oldest, next, 1.0, X, 0.0, YmBkStX));
58 PetscCall(LMProductsSolve(YtS_minus_StBkS, oldest, next, YmBkStX, YmBkStX, /* ^H */ PETSC_FALSE));
59 PetscCall(LMBasisGEMV(Y_minus_BkS, oldest, next, 1.0, YmBkStX, 1.0, BX));
60 PetscCall(MatLMVMRestoreWorkRow(B, &YmBkStX));
61 PetscFunctionReturn(PETSC_SUCCESS);
62 }
63
64 /* Recursively compute the (y_i - B_i s_i) vectors and ((y_i - B_i s_i)^T s_i) products */
65
SR1RecursiveBasisUpdate(Mat B,MatLMVMMode mode)66 static PetscErrorCode SR1RecursiveBasisUpdate(Mat B, MatLMVMMode mode)
67 {
68 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
69 Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
70 MatLMVMBasisType B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
71 MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
72 MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
73 SR1BasisType Y_minus_BkS_t = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
74 SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
75 LMBasis Y_minus_BkS;
76 LMProducts YtS_minus_StBkS;
77 PetscInt oldest, next;
78 PetscInt products_oldest;
79 LMBasis S, Y;
80 PetscInt start;
81
82 PetscFunctionBegin;
83 if (!lsr1->basis[Y_minus_BkS_t]) PetscCall(LMBasisCreate(mode == MATLMVM_MODE_PRIMAL ? lmvm->Fprev : lmvm->Xprev, lmvm->m, &lsr1->basis[Y_minus_BkS_t]));
84 Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
85 if (!lsr1->products[YtS_minus_StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsr1->products[YtS_minus_StBkS_t]));
86 YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
87 PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
88 PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
89 PetscCall(MatLMVMGetRange(B, &oldest, &next));
90 // invalidate computed values if J0 has changed
91 PetscCall(LMProductsPrepare(YtS_minus_StBkS, lmvm->J0, oldest, next));
92 products_oldest = PetscMax(0, YtS_minus_StBkS->k - lmvm->m);
93 if (oldest > products_oldest) {
94 // recursion is starting from a different starting index, it must be recomputed
95 YtS_minus_StBkS->k = oldest;
96 }
97 Y_minus_BkS->k = start = YtS_minus_StBkS->k;
98 // recompute each column in Y_minus_BkS in order
99 for (PetscInt j = start; j < next; j++) {
100 Vec s_j, B0s_j, p_j, y_j;
101 PetscScalar alpha, ymbksts;
102
103 PetscCall(LMBasisGetWorkVec(Y_minus_BkS, &p_j));
104
105 // p_j starts as B_0 * s_j
106 PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
107 PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
108 PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));
109
110 // Use the matmult kernel to compute p_j = B_j * p_j
111 PetscCall(LMBasisGetVecRead(S, j, &s_j));
112 // if j == oldest p_j is already correct
113 if (j > oldest) PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
114 PetscCall(LMBasisGetVecRead(Y, j, &y_j));
115 PetscCall(VecAYPX(p_j, -1.0, y_j));
116 PetscCall(VecDot(s_j, p_j, &ymbksts));
117 PetscCall(LMProductsInsertNextDiagonalValue(YtS_minus_StBkS, j, ymbksts));
118 PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
119 PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
120 PetscCall(LMBasisSetNextVec(Y_minus_BkS, p_j));
121 PetscCall(LMBasisRestoreWorkVec(Y_minus_BkS, &p_j));
122 }
123 PetscFunctionReturn(PETSC_SUCCESS);
124 }
125
SR1Kernel_Recursive(Mat B,MatLMVMMode mode,Vec X,Vec BX)126 static PetscErrorCode SR1Kernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec BX)
127 {
128 PetscInt oldest, next;
129
130 PetscFunctionBegin;
131 PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
132 PetscCall(MatLMVMGetRange(B, &oldest, &next));
133 if (next > oldest) {
134 PetscCall(SR1RecursiveBasisUpdate(B, mode));
135 PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, next, X, BX));
136 }
137 PetscFunctionReturn(PETSC_SUCCESS);
138 }
139
/* The SR1 kernel can be written as (see Byrd, Nocedal & Schnabel 1994)

     B_{k+1} = B_0 + (Y - B_0 S) M^{-1} (Y - B_0 S)^T,

   where

     M = diag(S^T Y) + stril(S^T Y) + stril(S^T Y)^T - S^T B_0 S.

   M is symmetric indefinite (stril is the strictly lower triangular part).

   M can be computed by computing triu((Y - B_0 S)^T S) and filling in the lower triangle by symmetry.
*/
151
/* Ensure the factored compact-SR1 middle matrix M is current for the columns [oldest, next).
   In primal mode M = symm(triu((Y - B_0 S)^T S)); in dual mode M = symm(triu((S - H_0 Y)^T Y)).
   M is stored in lsr1->products[].

   The products object is created on first use. When its count k is behind next, the following steps run:
   - the upper-triangular (Y - B_0 S)^T S products are copied into the local matrix;
   - the copy is mirrored to the lower triangle and padded on unused diagonal entries;
   - the result is factored in place.
   The factorization first attempts the symmetric-indefinite (Bunch-Kaufman) path and falls back to LU when that
   is reported as unsupported. */
static PetscErrorCode SR1CompactProductsUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
  Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
  MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType YmB0S_t           = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
  SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
  LMProducts       YtS_minus_StB0S;
  Mat              local;
  PetscInt         oldest, next, k;
  PetscBool        local_is_nonempty;

  PetscFunctionBegin;
  if (!lsr1->products[YtS_minus_StB0S_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsr1->products[YtS_minus_StB0S_t]));
  YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  // given J0 and the current range; presumably invalidates stale values (e.g. after J0 changes) -- TODO confirm
  PetscCall(LMProductsPrepare(YtS_minus_StB0S, lmvm->J0, oldest, next));
  PetscCall(LMProductsGetLocalMatrix(YtS_minus_StB0S, &local, &k, &local_is_nonempty));
  if (YtS_minus_StB0S->k < next) {
    // copy to factor in place
    LMProducts YmB0StS;
    Mat        ymb0sts_local;

    PetscCall(PetscCitationsRegister(ByrdNocedalSchnabelCitation, &ByrdNocedalSchnabelCite));
    YtS_minus_StB0S->k = next;
    // source data: the upper triangle of (Y - B_0 S)^T S (dual: (S - H_0 Y)^T Y)
    PetscCall(MatLMVMGetUpdatedProducts(B, YmB0S_t, S_t, LMBLOCK_UPPER_TRIANGLE, &YmB0StS));
    PetscCall(LMProductsGetLocalMatrix(YmB0StS, &ymb0sts_local, NULL, NULL));
    // NOTE(review): presumably only the ranks that own a nonempty local block copy and factor -- confirm
    if (local_is_nonempty) {
      PetscErrorCode ierr;

      // the previous factorization is discarded before overwriting the matrix with fresh values
      PetscCall(MatSetUnfactored(local));
      PetscCall(MatCopy(ymb0sts_local, local, SAME_NONZERO_PATTERN));
      // fill the lower triangle from the upper one and pad the diagonal of unused slots
      PetscCall(LMProductsMakeHermitian(local, oldest, next));
      PetscCall(LMProductsOnesOnUnusedDiagonal(local, oldest, next));
      PetscCall(MatSetOption(local, MAT_HERMITIAN, PETSC_TRUE));
      // Set not spd so that "Cholesky" factorization is actually the symmetric indefinite Bunch Kaufman factorization
      PetscCall(MatSetOption(local, MAT_SPD, PETSC_FALSE));

      // Return errors silently instead of invoking the default handler, so that an unsupported
      // factorization (PETSC_ERR_SUP) can be caught below; any other error is re-raised by PetscCheck()
      PetscCall(PetscPushErrorHandler(PetscReturnErrorHandler, NULL));
      ierr = MatCholeskyFactor(local, NULL, NULL);
      PetscCall(PetscPopErrorHandler());
      PetscCheck(ierr == PETSC_SUCCESS || ierr == PETSC_ERR_SUP, PETSC_COMM_SELF, ierr, "Error in Bunch-Kaufman factorization");
      // cusolver does not provide Bunch Kaufman, resort to LU if it is unavailable
      if (ierr == PETSC_ERR_SUP) PetscCall(MatLUFactor(local, NULL, NULL, NULL));
    }
    PetscCall(LMProductsRestoreLocalMatrix(YmB0StS, &ymb0sts_local, NULL));
  }
  // hand the local matrix back, passing next as the new count of valid columns
  PetscCall(LMProductsRestoreLocalMatrix(YtS_minus_StB0S, &local, &next));
  PetscFunctionReturn(PETSC_SUCCESS);
}
202
SR1Kernel_CompactDense(Mat B,MatLMVMMode mode,Vec X,Vec BX)203 static PetscErrorCode SR1Kernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
204 {
205 PetscInt oldest, next;
206
207 PetscFunctionBegin;
208 PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
209 PetscCall(MatLMVMGetRange(B, &oldest, &next));
210 if (next > oldest) {
211 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
212 Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
213 MatLMVMBasisType Y_minus_B0S_t = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
214 SR1ProductsType YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
215 LMProducts YtS_minus_StB0S;
216 Vec YmB0StX, v;
217
218 PetscCall(SR1CompactProductsUpdate(B, mode));
219 YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
220 PetscCall(MatLMVMGetWorkRow(B, &YmB0StX));
221 PetscCall(MatLMVMGetWorkRow(B, &v));
222 if (lmvm->do_not_cache_J0_products) {
223 /* the initial (Y - B_0 S)^T x inner product can be computed as Y^T x - S^T (B_0 x)
224 if we are not caching B_0 S products */
225 MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
226 MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
227 LMBasis S, Y;
228
229 PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
230 PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
231 PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YmB0StX));
232 PetscCall(LMBasisGEMVH(S, oldest, next, -1.0, BX, 1.0, YmB0StX));
233 } else PetscCall(MatLMVMBasisGEMVH(B, Y_minus_B0S_t, oldest, next, 1.0, X, 0.0, YmB0StX));
234 PetscCall(LMProductsSolve(YtS_minus_StB0S, oldest, next, YmB0StX, v, PETSC_FALSE));
235 PetscCall(MatLMVMBasisGEMV(B, Y_minus_B0S_t, oldest, next, 1.0, v, 1.0, BX));
236 PetscCall(MatLMVMRestoreWorkRow(B, &v));
237 PetscCall(MatLMVMRestoreWorkRow(B, &YmB0StX));
238 }
239 PetscFunctionReturn(PETSC_SUCCESS);
240 }
241
MatMult_LMVMSR1_CompactDense(Mat B,Vec X,Vec BX)242 static PetscErrorCode MatMult_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
243 {
244 PetscFunctionBegin;
245 PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, BX));
246 PetscFunctionReturn(PETSC_SUCCESS);
247 }
248
MatSolve_LMVMSR1_CompactDense(Mat B,Vec X,Vec BX)249 static PetscErrorCode MatSolve_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
250 {
251 PetscFunctionBegin;
252 PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_DUAL, X, BX));
253 PetscFunctionReturn(PETSC_SUCCESS);
254 }
255
MatMult_LMVMSR1_Recursive(Mat B,Vec X,Vec Z)256 static PetscErrorCode MatMult_LMVMSR1_Recursive(Mat B, Vec X, Vec Z)
257 {
258 PetscFunctionBegin;
259 PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Z));
260 PetscFunctionReturn(PETSC_SUCCESS);
261 }
262
MatSolve_LMVMSR1_Recursive(Mat B,Vec F,Vec dX)263 static PetscErrorCode MatSolve_LMVMSR1_Recursive(Mat B, Vec F, Vec dX)
264 {
265 PetscFunctionBegin;
266 PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_DUAL, F, dX));
267 PetscFunctionReturn(PETSC_SUCCESS);
268 }
269
/* Update the L-SR1 approximation with a new iterate X and its function value F.

   If no history is allowed (lmvm->m == 0), return immediately without saving X and F. Otherwise, when a
   previous (Xprev, Fprev) pair exists:
   - form s = X - Xprev and y = F - Fprev;
   - accept the pair only when |s^T (y - B_k s)| >= eps * ||s|| * ||y - B_k s||;
   - otherwise count a rejection in lmvm->nrejects.
   X and F are then saved as the new (Xprev, Fprev).

   cache_SmH0YtF is on only for non-recursive mult algorithms with cached J0 products, and only if
   lmvm->cache_gradient_products is set. When it is on, the row (S - H_0 Y)^T F is kept in sr1->SmH0YtFprev.
   The next column of the (S - H_0 Y)^T Y products is then formed as (S - H_0 Y)^T F - (S - H_0 Y)^T Fprev. */
static PetscErrorCode MatUpdate_LMVMSR1(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm          = (Mat_LMVM *)B->data;
  Mat_LSR1 *sr1           = (Mat_LSR1 *)lmvm->ctx;
  PetscBool cache_SmH0YtF = (lmvm->mult_alg != MAT_LMVM_MULT_RECURSIVE && !lmvm->do_not_cache_J0_products) ? lmvm->cache_gradient_products : PETSC_FALSE;

  PetscFunctionBegin;
  if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
  if (lmvm->prev_set) {
    PetscReal   snorm, pnorm;
    PetscScalar sktw;
    Vec         work;
    Vec         Fprev_old       = NULL;
    Vec         SmH0YtFprev_old = NULL;
    LMProducts  SmH0YtY         = NULL;
    PetscInt    oldest, next;
    LMBasis     SmH0Y = NULL;
    LMBasis     Y;

    // history range before this update
    PetscCall(MatLMVMGetRange(B, &oldest, &next));
    if (cache_SmH0YtF) {
      // Keep a copy of Fprev (it is overwritten below). Check out the next column of (S - H_0 Y)^T Y and
      // pre-fill it with (S - H_0 Y)^T Fprev.
      PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
      if (!sr1->SmH0YtFprev) PetscCall(LMBasisCreateRow(SmH0Y, &sr1->SmH0YtFprev));
      PetscCall(LMBasisGetWorkVec(SmH0Y, &Fprev_old));
      PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_S_MINUS_H0Y, LMBASIS_Y, LMBLOCK_UPPER_TRIANGLE, &SmH0YtY));
      PetscCall(LMProductsGetNextColumn(SmH0YtY, &SmH0YtFprev_old));
      PetscCall(VecCopy(lmvm->Fprev, Fprev_old));
      // The basis still reports our row as its cached product. That row was set against the previous F,
      // which is now Fprev, so it can be reused directly.
      if (sr1->SmH0YtFprev == SmH0Y->cached_product) {
        PetscCall(VecCopy(sr1->SmH0YtFprev, SmH0YtFprev_old));
      } else {
        if (next > oldest) {
          // need to recalculate
          PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        } else {
          PetscCall(VecZeroEntries(SmH0YtFprev_old));
        }
      }
    }

    /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
    PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
    PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));

    /* See if the updates can be accepted
       NOTE: This tests abs(S[k]^T (Y[k] - B_k*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_k*S[k])

       Note that this test is flawed because this is a **limited memory** SR1 method: we are testing

         abs(S[k]^T (Y[k] - B_{k,m}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m}*S[k])

       when the oldest pair of vectors in the definition of B_{k,m}, (s_{k-m}, y_{k-m}), will be dropped if we add a new
       pair.  To really ensure that B_{k+1} = B_{k+1,m} is nonsingular, you need to test

         abs(S[k]^T (Y[k] - B_{k,m-1}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m-1}*S[k])

       But the product B_{k,m-1}*S[k] is not readily computable (see e.g. Lu, Xuehua, "A study of the limited memory SR1
       method in practice", 1996).
     */
    PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_Y, &Y, NULL, NULL));
    PetscCall(LMBasisGetWorkVec(Y, &work));
    // work = y - B_k s, using the current (pre-update) approximation
    PetscCall(MatMult(B, lmvm->Xprev, work));
    PetscCall(VecAYPX(work, -1.0, lmvm->Fprev));
    PetscCall(VecDot(lmvm->Xprev, work, &sktw));
    PetscCall(VecNorm(lmvm->Xprev, NORM_2, &snorm));
    PetscCall(VecNorm(work, NORM_2, &pnorm));
    PetscCall(LMBasisRestoreWorkVec(Y, &work));
    if (PetscAbsReal(PetscRealPart(sktw)) >= lmvm->eps * snorm * pnorm) {
      /* Update is good, accept it */
      PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
      if (cache_SmH0YtF) {
        PetscInt oldest_new, next_new;

        PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
        PetscCall(MatLMVMGetRange(B, &oldest_new, &next_new));
        // Fill in (S - H_0 Y)^T Fprev for the newly added column(s) [next, next_new); older entries were set above
        PetscCall(LMBasisGEMVH(SmH0Y, next, next_new, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        // new cached row (S - H_0 Y)^T F, registered with the basis for reuse at the next update
        PetscCall(LMBasisGEMVH(SmH0Y, oldest_new, next_new, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
        // SmH0YtFprev_old = (S - H_0 Y)^T F - (S - H_0 Y)^T Fprev = (S - H_0 Y)^T y: the next products column
        PetscCall(VecAXPBY(SmH0YtFprev_old, 1.0, -1.0, sr1->SmH0YtFprev));
        PetscCall(LMProductsRestoreNextColumn(SmH0YtY, &SmH0YtFprev_old));
      }
    } else {
      /* Update is bad, skip it */
      lmvm->nrejects++;
      /* NOTE(review): SmH0YtFprev_old was checked out with LMProductsGetNextColumn() above, but
         LMProductsRestoreNextColumn() is only called on the accept path. Confirm that leaving the column
         checked out here is safe for the LMProducts implementation. */
      if (cache_SmH0YtF) {
        // we still need to update the cached product
        PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
      }
    }
    if (cache_SmH0YtF) PetscCall(LMBasisRestoreWorkVec(SmH0Y, &Fprev_old));
  }
  /* Save the solution and function to be used in the next update */
  PetscCall(VecCopy(X, lmvm->Xprev));
  PetscCall(VecCopy(F, lmvm->Fprev));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
367
MatReset_LMVMSR1(Mat B,MatLMVMResetMode mode)368 static PetscErrorCode MatReset_LMVMSR1(Mat B, MatLMVMResetMode mode)
369 {
370 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
371 Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
372
373 PetscFunctionBegin;
374 if (MatLMVMResetClearsBases(mode)) {
375 for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisDestroy(&lsr1->basis[i]));
376 for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsDestroy(&lsr1->products[i]));
377 PetscCall(VecDestroy(&lsr1->StFprev));
378 PetscCall(VecDestroy(&lsr1->SmH0YtFprev));
379 } else {
380 for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisReset(lsr1->basis[i]));
381 for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsReset(lsr1->products[i]));
382 if (lsr1->StFprev) PetscCall(VecZeroEntries(lsr1->StFprev));
383 if (lsr1->SmH0YtFprev) PetscCall(VecZeroEntries(lsr1->SmH0YtFprev));
384 }
385 PetscFunctionReturn(PETSC_SUCCESS);
386 }
387
MatDestroy_LMVMSR1(Mat B)388 static PetscErrorCode MatDestroy_LMVMSR1(Mat B)
389 {
390 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
391
392 PetscFunctionBegin;
393 PetscCall(MatReset_LMVMSR1(B, MAT_LMVM_RESET_ALL));
394 PetscCall(PetscFree(lmvm->ctx));
395 PetscCall(MatDestroy_LMVM(B));
396 PetscFunctionReturn(PETSC_SUCCESS);
397 }
398
MatLMVMSetMultAlgorithm_SR1(Mat B)399 static PetscErrorCode MatLMVMSetMultAlgorithm_SR1(Mat B)
400 {
401 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
402
403 PetscFunctionBegin;
404 switch (lmvm->mult_alg) {
405 case MAT_LMVM_MULT_RECURSIVE:
406 lmvm->ops->mult = MatMult_LMVMSR1_Recursive;
407 lmvm->ops->solve = MatSolve_LMVMSR1_Recursive;
408 break;
409 case MAT_LMVM_MULT_DENSE:
410 case MAT_LMVM_MULT_COMPACT_DENSE:
411 lmvm->ops->mult = MatMult_LMVMSR1_CompactDense;
412 lmvm->ops->solve = MatSolve_LMVMSR1_CompactDense;
413 break;
414 }
415 lmvm->ops->multht = lmvm->ops->mult;
416 lmvm->ops->solveht = lmvm->ops->solve;
417 PetscFunctionReturn(PETSC_SUCCESS);
418 }
419
MatCreate_LMVMSR1(Mat B)420 PetscErrorCode MatCreate_LMVMSR1(Mat B)
421 {
422 Mat_LMVM *lmvm;
423 Mat_LSR1 *lsr1;
424
425 PetscFunctionBegin;
426 PetscCall(MatCreate_LMVM(B));
427 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMSR1));
428 PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
429 B->ops->destroy = MatDestroy_LMVMSR1;
430
431 lmvm = (Mat_LMVM *)B->data;
432 lmvm->ops->reset = MatReset_LMVMSR1;
433 lmvm->ops->update = MatUpdate_LMVMSR1;
434 lmvm->ops->setmultalgorithm = MatLMVMSetMultAlgorithm_SR1;
435 lmvm->cache_gradient_products = PETSC_TRUE;
436 PetscCall(MatLMVMSetMultAlgorithm_SR1(B));
437 PetscCall(PetscNew(&lsr1));
438 lmvm->ctx = (void *)lsr1;
439 PetscFunctionReturn(PETSC_SUCCESS);
440 }
441
442 /*@
443 MatCreateLMVMSR1 - Creates a limited-memory Symmetric-Rank-1 approximation
444 matrix used for a Jacobian. L-SR1 is symmetric by construction, but is not
445 guaranteed to be positive-definite.
446
447 To use the L-SR1 matrix with other vector types, the matrix must be
448 created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
449 This ensures that the internal storage and work vectors are duplicated from the
450 correct type of vector.
451
452 Collective
453
454 Input Parameters:
455 + comm - MPI communicator
456 . n - number of local rows for storage vectors
457 - N - global size of the storage vectors
458
459 Output Parameter:
460 . B - the matrix
461
462 Options Database Keys:
463 + -mat_lmvm_hist_size - the number of history vectors to keep
464 . -mat_lmvm_mult_algorithm - the algorithm to use for multiplication (recursive, dense, compact_dense)
465 . -mat_lmvm_cache_J0_products - whether products between the base Jacobian J0 and history vectors should be cached or recomputed
466 . -mat_lmvm_eps - (developer) numerical zero tolerance for testing when an update should be skipped
467 - -mat_lmvm_debug - (developer) perform internal debugging checks
468
469 Level: intermediate
470
471 Note:
472 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
473 paradigm instead of this routine directly.
474
475 .seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMSR1`, `MatCreateLMVMBFGS()`, `MatCreateLMVMDFP()`,
476 `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
477 @*/
MatCreateLMVMSR1(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)478 PetscErrorCode MatCreateLMVMSR1(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
479 {
480 PetscFunctionBegin;
481 PetscCall(KSPInitializePackage());
482 PetscCall(MatCreate(comm, B));
483 PetscCall(MatSetSizes(*B, n, n, N, N));
484 PetscCall(MatSetType(*B, MATLMVMSR1));
485 PetscCall(MatSetUp(*B));
486 PetscFunctionReturn(PETSC_SUCCESS);
487 }
488