1 #include <../src/ksp/ksp/utils/lmvm/symbrdn/symbrdn.h> /*I "petscksp.h" I*/
2 #include <petsc/private/vecimpl.h>
3 #include <petscdevice.h>
4
/* The BFGS update can be written

     B_{k+1} = B_k + y_k (y_k^T s_k)^{-1} y_k^T - B_k s_k (s_k^T B_k s_k)^{-1} s_k^T B_k

   which can be unrolled as a parallel sum over the stored history pairs

     B_{k+1} = B_0 + \sum_i [ y_i (y_i^T s_i)^{-1} y_i^T - (B_i s_i) (s_i^T B_i s_i)^{-1} (B_i s_i)^T ]

   Once the (B_i s_i) vectors and the (y_i^T s_i) and (s_i^T B_i s_i) products have been
   computed, applying B_{k+1} to a vector only requires products with the stored bases and
   diagonal solves.
 */
BFGSKernel_Recursive_Inner(Mat B,MatLMVMMode mode,PetscInt oldest,PetscInt next,Vec X,Vec B0X)15 static PetscErrorCode BFGSKernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec B0X)
16 {
17 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
18 Mat_SymBrdn *lsb = (Mat_SymBrdn *)lmvm->ctx;
19 MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
20 LMBasis BkS = lsb->basis[LMVMModeMap(SYMBROYDEN_BASIS_BKS, mode)];
21 LMProducts YtS;
22 LMProducts StBkS = lsb->products[LMVMModeMap(SYMBROYDEN_PRODUCTS_STBKS, mode)];
23 LMBasis Y;
24 Vec StBkX, YtX;
25
26 PetscFunctionBegin;
27 PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
28 PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &YtS));
29 PetscCall(MatLMVMGetWorkRow(B, &StBkX));
30 PetscCall(MatLMVMGetWorkRow(B, &YtX));
31 PetscCall(LMBasisGEMVH(BkS, oldest, next, 1.0, X, 0.0, StBkX));
32 PetscCall(LMProductsSolve(StBkS, oldest, next, StBkX, StBkX, /* ^H */ PETSC_FALSE));
33 PetscCall(LMBasisGEMV(BkS, oldest, next, -1.0, StBkX, 1.0, B0X));
34 PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YtX));
35 PetscCall(LMProductsSolve(YtS, oldest, next, YtX, YtX, /* ^H */ PETSC_FALSE));
36 PetscCall(LMBasisGEMV(Y, oldest, next, 1.0, YtX, 1.0, B0X));
37 PetscCall(MatLMVMRestoreWorkRow(B, &YtX));
38 PetscCall(MatLMVMRestoreWorkRow(B, &StBkX));
39 PetscFunctionReturn(PETSC_SUCCESS);
40 }
41
/*
  Recursively computes, for the current history window [oldest, next), the vectors B_j s_j
  (stored in the SYMBROYDEN_BASIS_BKS basis for this mode) and the diagonal products
  s_j^T B_j s_j (stored in the SYMBROYDEN_PRODUCTS_STBKS products for this mode).

  Both containers are created on first use. Computation resumes at StBkS->k, so entries left
  valid by LMProductsPrepare() are reused. If the stored recursion started at an index older
  than the current `oldest`, everything is recomputed from `oldest`, because every B_j depends
  on where the recursion starts.

  Input Parameters:
+ B    - the LMVM matrix
- mode - primal or dual mode
*/
static PetscErrorCode BFGSRecursiveBasisUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
  Mat_SymBrdn *lsb = (Mat_SymBrdn *)lmvm->ctx;
  MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
  SymBroydenProductsType StBkS_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_STBKS, mode);
  SymBroydenBasisType BkS_t = LMVMModeMap(SYMBROYDEN_BASIS_BKS, mode);
  LMBasis BkS;
  LMProducts StBkS, YtS;
  PetscInt oldest, start, next;
  PetscInt products_oldest;
  LMBasis S;

  PetscFunctionBegin;
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  // B_j s_j lives in the same space as B_0 s_j, so template the basis on Xprev or Fprev
  // according to which space the B0S basis belongs to
  if (!lsb->basis[BkS_t]) PetscCall(LMBasisCreate(MatLMVMBasisSizeOf(B0S_t) == LMBASIS_S ? lmvm->Xprev : lmvm->Fprev, lmvm->m, &lsb->basis[BkS_t]));
  BkS = lsb->basis[BkS_t];
  if (!lsb->products[StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsb->products[StBkS_t]));
  StBkS = lsb->products[StBkS_t];
  PetscCall(LMProductsPrepare(StBkS, lmvm->J0, oldest, next));
  // presumably the oldest index still held in the m-slot product storage -- NOTE(review): confirm
  products_oldest = PetscMax(0, StBkS->k - lmvm->m);
  if (oldest > products_oldest) {
    // recursion is starting from a different starting index, it must be recomputed
    StBkS->k = oldest;
  }
  // keep the basis cursor in lockstep with the products cursor
  BkS->k = start = StBkS->k;
  if (start == next) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
  // make sure YtS is updated before entering the loop
  PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &YtS));
  for (PetscInt j = start; j < next; j++) {
    Vec p_j, s_j, B0s_j;
    PetscScalar alpha, sjtbjsj;

    PetscCall(LMBasisGetWorkVec(BkS, &p_j));
    // p_j starts as B_0 * s_j
    PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
    PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
    PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));

    // Use the matmult kernel to compute p_j = B_j * p_j; it only reads BkS entries [oldest, j),
    // which earlier iterations have already filled. For j == oldest, B_j = B_0 and no correction is needed.
    PetscCall(LMBasisGetVecRead(S, j, &s_j));
    if (j > oldest) PetscCall(BFGSKernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
    // VecDot(x, y) = y^H x, so this is s_j^H (B_j s_j)
    PetscCall(VecDot(p_j, s_j, &sjtbjsj));
    PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
    // append entry j to both containers, advancing their cursors
    PetscCall(LMProductsInsertNextDiagonalValue(StBkS, j, sjtbjsj));
    PetscCall(LMBasisSetNextVec(BkS, p_j));
    PetscCall(LMBasisRestoreWorkVec(BkS, &p_j));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
98
BFGSKernel_Recursive(Mat B,MatLMVMMode mode,Vec X,Vec Y)99 PETSC_INTERN PetscErrorCode BFGSKernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec Y)
100 {
101 PetscInt oldest, next;
102
103 PetscFunctionBegin;
104 PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, Y));
105 PetscCall(MatLMVMGetRange(B, &oldest, &next));
106 if (next > oldest) {
107 PetscCall(BFGSRecursiveBasisUpdate(B, mode));
108 PetscCall(BFGSKernel_Recursive_Inner(B, mode, oldest, next, X, Y));
109 }
110 PetscFunctionReturn(PETSC_SUCCESS);
111 }
112
/*
  Brings the M00 products (SYMBROYDEN_PRODUCTS_M00 for this mode) up to date for the history
  window [oldest, next) and leaves them Cholesky-factored in place, ready for the
  LMProductsSolve() in BFGSKernel_CompactDense().

  Let D = diag(Y^T S) and let L be the strictly lower triangle of S^T Y. The matrix formed is

    M00 = S^T B_0 S + L D^{-1} L^T,

  which is the Schur complement of -D in the compact-representation middle matrix
  [ S^T B_0 S, L; L^T, -D ]. It is only rebuilt when the stored products are stale (k < next).

  Input Parameters:
+ B    - the LMVM matrix
- mode - primal or dual mode
*/
static PetscErrorCode BFGSCompactDenseProductsUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
  Mat_SymBrdn *lsb = (Mat_SymBrdn *)lmvm->ctx;
  PetscInt oldest, next, k;
  MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
  MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
  SymBroydenProductsType M00_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_M00, mode);
  LMProducts M00, StB0S, YtS, D;
  Mat YtS_local, StB0S_local, M00_local;
  Vec D_local;
  PetscBool local_is_nonempty;

  PetscFunctionBegin;
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  // M00 is dense and must be stored as a full block: discard storage of any other block type
  if (lsb->products[M00_t] && lsb->products[M00_t]->block_type != LMBLOCK_FULL) PetscCall(LMProductsDestroy(&lsb->products[M00_t]));
  if (!lsb->products[M00_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsb->products[M00_t]));
  M00 = lsb->products[M00_t];
  PetscCall(LMProductsPrepare(M00, lmvm->J0, oldest, next));
  // k: presumably the index up to which the stored M00 is valid -- NOTE(review): confirm
  PetscCall(LMProductsGetLocalMatrix(M00, &M00_local, &k, &local_is_nonempty));
  if (k < next) {
    // YtS holds the strict upper triangle of Y^T S (= L^T); D is its diagonal
    PetscCall(MatLMVMGetUpdatedProducts(B, Y_t, S_t, LMBLOCK_STRICT_UPPER_TRIANGLE, &YtS));
    PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &D));
    PetscCall(MatLMVMGetUpdatedProducts(B, S_t, B0S_t, LMBLOCK_UPPER_TRIANGLE, &StB0S));

    PetscCall(LMProductsGetLocalMatrix(StB0S, &StB0S_local, NULL, NULL));
    PetscCall(LMProductsGetLocalMatrix(YtS, &YtS_local, NULL, NULL));
    PetscCall(LMProductsGetLocalDiagonal(D, &D_local));
    // the calls above are outside this branch; only the local dense arithmetic is skipped on ranks holding no block
    if (local_is_nonempty) {
      Vec invD;
      Mat stril_StY;

      // M00_local may hold a previous in-place Cholesky factor; mark it unfactored before overwriting
      PetscCall(MatSetUnfactored(M00_local));
      // NOTE(review): M00_local is overwritten by the MAT_REUSE_MATRIX MatMatMult below, so this copy
      // appears redundant -- confirm before removing
      PetscCall(MatCopy(StB0S_local, M00_local, SAME_NONZERO_PATTERN));
      PetscCall(VecDuplicate(D_local, &invD));
      PetscCall(VecCopy(D_local, invD));
      PetscCall(VecReciprocal(invD));
      // stril_StY = (strict upper of Y^H S)^H = L, the strict lower triangle of S^H Y
      PetscCall(MatTranspose(YtS_local, MAT_INITIAL_MATRIX, &stril_StY));
      if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(stril_StY));

      // scale columns: stril_StY <- L D^{-1}
      PetscCall(MatDiagonalScale(stril_StY, NULL, invD));
      // M00_local <- L D^{-1} L^T
      PetscCall(MatMatMult(stril_StY, YtS_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &M00_local));
      // M00_local += S^T B_0 S (only its upper triangle is requested above)
      PetscCall(MatAXPY(M00_local, 1.0, StB0S_local, UNKNOWN_NONZERO_PATTERN));
      // presumably rebuilds the lower triangle from the upper one for the active window -- NOTE(review): confirm
      PetscCall(LMProductsMakeHermitian(M00_local, oldest, next));
      // presumably puts ones on diagonal entries outside [oldest, next) so the factorization is defined
      PetscCall(LMProductsOnesOnUnusedDiagonal(M00_local, oldest, next));
      PetscCall(MatSetOption(M00_local, MAT_HERMITIAN, PETSC_TRUE));
      PetscCall(MatSetOption(M00_local, MAT_SPD, PETSC_TRUE));
      // factor in place; later solves with M00 reuse this factor
      PetscCall(MatCholeskyFactor(M00_local, NULL, NULL));
      PetscCall(MatDestroy(&stril_StY));
      PetscCall(VecDestroy(&invD));
    }
    PetscCall(LMProductsRestoreLocalDiagonal(D, &D_local));
    PetscCall(LMProductsRestoreLocalMatrix(YtS, &YtS_local, NULL));
    PetscCall(LMProductsRestoreLocalMatrix(StB0S, &StB0S_local, NULL));
  }
  // record M00 as valid through `next`
  PetscCall(LMProductsRestoreLocalMatrix(M00, &M00_local, &next));
  PetscFunctionReturn(PETSC_SUCCESS);
}
172
BFGSKernel_CompactDense(Mat B,MatLMVMMode mode,Vec X,Vec BX)173 PETSC_INTERN PetscErrorCode BFGSKernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
174 {
175 PetscInt oldest, next;
176
177 PetscFunctionBegin;
178 PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
179 PetscCall(MatLMVMGetRange(B, &oldest, &next));
180 if (next > oldest) {
181 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
182 Mat_SymBrdn *bfgs = (Mat_SymBrdn *)lmvm->ctx;
183 MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
184 MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
185 MatLMVMBasisType B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
186 SymBroydenProductsType M00_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_M00, mode);
187 LMBasis S, Y;
188 PetscBool use_B0S;
189 Vec YtX, StB0X, u, v;
190 LMProducts M00, YtS, D;
191
192 PetscCall(BFGSCompactDenseProductsUpdate(B, mode));
193 PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
194 PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
195 PetscCall(MatLMVMGetUpdatedProducts(B, Y_t, S_t, LMBLOCK_STRICT_UPPER_TRIANGLE, &YtS));
196 PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &D));
197 M00 = bfgs->products[M00_t];
198
199 PetscCall(MatLMVMGetWorkRow(B, &YtX));
200 PetscCall(MatLMVMGetWorkRow(B, &StB0X));
201 PetscCall(MatLMVMGetWorkRow(B, &u));
202 PetscCall(MatLMVMGetWorkRow(B, &v));
203
204 PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YtX));
205 PetscCall(SymBroydenCompactDenseKernelUseB0S(B, mode, X, &use_B0S));
206 if (use_B0S) PetscCall(MatLMVMBasisGEMVH(B, B0S_t, oldest, next, 1.0, X, 0.0, StB0X));
207 else PetscCall(LMBasisGEMVH(S, oldest, next, 1.0, BX, 0.0, StB0X));
208
209 PetscCall(LMProductsSolve(D, oldest, next, YtX, YtX, /* ^H */ PETSC_FALSE));
210 PetscCall(LMProductsMult(YtS, oldest, next, 1.0, YtX, 1.0, StB0X, /* ^H */ PETSC_TRUE));
211 PetscCall(LMProductsSolve(M00, oldest, next, StB0X, u, PETSC_FALSE));
212 PetscCall(VecScale(u, -1.0));
213 PetscCall(LMProductsMult(YtS, oldest, next, 1.0, u, 0.0, v, /* ^H */ PETSC_FALSE));
214 PetscCall(LMProductsSolve(D, oldest, next, v, v, PETSC_FALSE));
215 PetscCall(VecAXPY(v, 1.0, YtX));
216
217 PetscCall(LMBasisGEMV(Y, oldest, next, 1.0, v, 1.0, BX));
218 PetscCall(MatLMVMBasisGEMV(B, B0S_t, oldest, next, 1.0, u, 1.0, BX));
219
220 PetscCall(MatLMVMRestoreWorkRow(B, &v));
221 PetscCall(MatLMVMRestoreWorkRow(B, &u));
222 PetscCall(MatLMVMRestoreWorkRow(B, &StB0X));
223 PetscCall(MatLMVMRestoreWorkRow(B, &YtX));
224 }
225 PetscFunctionReturn(PETSC_SUCCESS);
226 }
227
/* Forward product Y = B_k X for the recursive algorithm (installed as lmvm->ops->mult), computed in primal mode. */
static PetscErrorCode MatMult_LMVMBFGS_Recursive(Mat B, Vec X, Vec Y)
{
  PetscFunctionBegin;
  PetscCall(BFGSKernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Y));
  PetscFunctionReturn(PETSC_SUCCESS);
}
234
/* Forward product Y = B_k X via the compact dense representation (installed as lmvm->ops->mult
   for both the dense and compact-dense algorithms), computed in primal mode. */
static PetscErrorCode MatMult_LMVMBFGS_CompactDense(Mat B, Vec X, Vec Y)
{
  PetscFunctionBegin;
  PetscCall(BFGSKernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, Y));
  PetscFunctionReturn(PETSC_SUCCESS);
}
241
/* Inverse application HX = B_k^{-1} X for the recursive algorithm (installed as lmvm->ops->solve).
   The inverse of a BFGS matrix has the DFP form with s and y exchanged, so the DFP kernel is run in dual mode. */
static PetscErrorCode MatSolve_LMVMBFGS_Recursive(Mat B, Vec X, Vec HX)
{
  PetscFunctionBegin;
  PetscCall(DFPKernel_Recursive(B, MATLMVM_MODE_DUAL, X, HX));
  PetscFunctionReturn(PETSC_SUCCESS);
}
248
/* Inverse application HX = B_k^{-1} X for the compact-dense algorithm (installed as lmvm->ops->solve).
   The BFGS inverse is a DFP update in dual mode (s and y exchanged), so the DFP compact-dense kernel is reused. */
static PetscErrorCode MatSolve_LMVMBFGS_CompactDense(Mat B, Vec X, Vec HX)
{
  PetscFunctionBegin;
  PetscCall(DFPKernel_CompactDense(B, MATLMVM_MODE_DUAL, X, HX));
  PetscFunctionReturn(PETSC_SUCCESS);
}
255
/* Inverse application HX = B_k^{-1} X used when mult_alg is MAT_LMVM_MULT_DENSE (installed as lmvm->ops->solve).
   Runs the DFP dense kernel in dual mode, since the BFGS inverse is a DFP update with s and y exchanged. */
static PetscErrorCode MatSolve_LMVMBFGS_Dense(Mat B, Vec X, Vec HX)
{
  PetscFunctionBegin;
  PetscCall(DFPKernel_Dense(B, MATLMVM_MODE_DUAL, X, HX));
  PetscFunctionReturn(PETSC_SUCCESS);
}
262
MatSetFromOptions_LMVMBFGS(Mat B,PetscOptionItems PetscOptionsObject)263 static PetscErrorCode MatSetFromOptions_LMVMBFGS(Mat B, PetscOptionItems PetscOptionsObject)
264 {
265 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
266 Mat_SymBrdn *lbfgs = (Mat_SymBrdn *)lmvm->ctx;
267
268 PetscFunctionBegin;
269 PetscCall(MatSetFromOptions_LMVM(B, PetscOptionsObject));
270 PetscOptionsHeadBegin(PetscOptionsObject, "L-BFGS method for approximating SPD Jacobian actions (MATLMVMBFGS)");
271 PetscCall(SymBroydenRescaleSetFromOptions(B, lbfgs->rescale, PetscOptionsObject));
272 PetscOptionsHeadEnd();
273 PetscFunctionReturn(PETSC_SUCCESS);
274 }
275
MatLMVMSetMultAlgorithm_BFGS(Mat B)276 static PetscErrorCode MatLMVMSetMultAlgorithm_BFGS(Mat B)
277 {
278 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
279
280 PetscFunctionBegin;
281 switch (lmvm->mult_alg) {
282 case MAT_LMVM_MULT_RECURSIVE:
283 lmvm->ops->mult = MatMult_LMVMBFGS_Recursive;
284 lmvm->ops->solve = MatSolve_LMVMBFGS_Recursive;
285 break;
286 case MAT_LMVM_MULT_DENSE:
287 lmvm->ops->mult = MatMult_LMVMBFGS_CompactDense;
288 lmvm->ops->solve = MatSolve_LMVMBFGS_Dense;
289 break;
290 case MAT_LMVM_MULT_COMPACT_DENSE:
291 lmvm->ops->mult = MatMult_LMVMBFGS_CompactDense;
292 lmvm->ops->solve = MatSolve_LMVMBFGS_CompactDense;
293 break;
294 }
295 lmvm->ops->multht = lmvm->ops->mult;
296 lmvm->ops->solveht = lmvm->ops->solve;
297 PetscFunctionReturn(PETSC_SUCCESS);
298 }
299
MatCreate_LMVMBFGS(Mat B)300 PetscErrorCode MatCreate_LMVMBFGS(Mat B)
301 {
302 Mat_LMVM *lmvm;
303 Mat_SymBrdn *lbfgs;
304
305 PetscFunctionBegin;
306 PetscCall(MatCreate_LMVMSymBrdn(B));
307 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMBFGS));
308 B->ops->setfromoptions = MatSetFromOptions_LMVMBFGS;
309
310 lmvm = (Mat_LMVM *)B->data;
311 lmvm->ops->setmultalgorithm = MatLMVMSetMultAlgorithm_BFGS;
312 PetscCall(MatLMVMSetMultAlgorithm_BFGS(B));
313
314 lbfgs = (Mat_SymBrdn *)lmvm->ctx;
315
316 lbfgs->phi_scalar = 0.0;
317 lbfgs->psi_scalar = 1.0;
318 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetPhi_C", NULL));
319 PetscFunctionReturn(PETSC_SUCCESS);
320 }
321
/*@
  MatCreateLMVMBFGS - Creates a limited-memory Broyden-Fletcher-Goldfarb-Shanno (BFGS)
  matrix used for approximating Jacobians. L-BFGS is symmetric positive-definite by
  construction, and is commonly used to approximate Hessians in optimization
  problems.

  To use the L-BFGS matrix with other vector types, the matrix must be
  created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
  This ensures that the internal storage and work vectors are duplicated from the
  correct type of vector.

  Collective

  Input Parameters:
+ comm - MPI communicator
. n    - number of local rows for storage vectors
- N    - global size of the storage vectors

  Output Parameter:
. B - the matrix

  Options Database Keys:
+ -mat_lmvm_scale_type - (developer) type of scaling applied to J0 (none, scalar, diagonal)
. -mat_lmvm_theta      - (developer) convex ratio between BFGS and DFP components of the diagonal J0 scaling
. -mat_lmvm_rho        - (developer) update limiter for the J0 scaling
. -mat_lmvm_alpha      - (developer) coefficient factor for the quadratic subproblem in J0 scaling
. -mat_lmvm_beta       - (developer) exponential factor for the diagonal J0 scaling
- -mat_lmvm_sigma_hist - (developer) number of past updates to use in J0 scaling

  Level: intermediate

  Note:
  It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
  paradigm instead of this routine directly.

.seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMBFGS`, `MatCreateLMVMDFP()`, `MatCreateLMVMSR1()`,
          `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
@*/
PetscErrorCode MatCreateLMVMBFGS(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
{
  PetscFunctionBegin;
  // make sure the KSP package (which registers the LMVM matrix types) is initialized
  PetscCall(KSPInitializePackage());
  PetscCall(MatCreate(comm, B));
  // square operator: row and column layouts match the storage vectors
  PetscCall(MatSetSizes(*B, n, n, N, N));
  PetscCall(MatSetType(*B, MATLMVMBFGS));
  PetscCall(MatSetUp(*B));
  PetscFunctionReturn(PETSC_SUCCESS);
}
370