1 #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h> /*I "petscksp.h" I*/
2 #include <../src/ksp/ksp/utils/lmvm/blas_cyclic/blas_cyclic.h>
3 #include <petscblaslapack.h>
4 #include <petscmat.h>
5 #include <petscsys.h>
6 #include <petscsystypes.h>
7 #include <petscis.h>
8 #include <petscoptions.h>
9 #include <petscdevice.h>
10 #include <petsc/private/deviceimpl.h>
11
/* Forward declarations of the mult/solve kernels so that they can be referenced before their
   definitions (e.g. MatMult_LMVMDQN() calls MatMult_LMVMDDFP() and MatSolve_LMVMDQN() calls
   MatSolve_LMVMDBFGS()) */
static PetscErrorCode MatMult_LMVMDQN(Mat, Vec, Vec);
static PetscErrorCode MatMult_LMVMDBFGS(Mat, Vec, Vec);
static PetscErrorCode MatMult_LMVMDDFP(Mat, Vec, Vec);
static PetscErrorCode MatSolve_LMVMDQN(Mat, Vec, Vec);
static PetscErrorCode MatSolve_LMVMDBFGS(Mat, Vec, Vec);
static PetscErrorCode MatSolve_LMVMDDFP(Mat, Vec, Vec);
18
/* Slot occupied by absolute update number idx in a circular buffer holding m columns
   (callers guarantee m > 0) */
static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
{
  PetscInt slot = idx % m;

  return slot;
}
23
/* Position of absolute update number idx when the stored updates are numbered oldest-first,
   given that num_updates updates have been accepted and at most m are kept */
static inline PetscInt history_index(PetscInt m, PetscInt num_updates, PetscInt idx)
{
  PetscInt stored = PetscMin(m, num_updates);
  PetscInt offset = idx - num_updates;

  return stored + offset;
}
28
/* Absolute index of the oldest update that is still stored after idx updates, when at most m are kept */
static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
{
  PetscInt overflow = idx - m;

  return overflow > 0 ? overflow : 0;
}
33
MatView_LMVMDQN(Mat B,PetscViewer pv)34 static PetscErrorCode MatView_LMVMDQN(Mat B, PetscViewer pv)
35 {
36 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
37 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
38
39 PetscBool isascii;
40
41 PetscFunctionBegin;
42 PetscCall(PetscObjectTypeCompare((PetscObject)pv, PETSCVIEWERASCII, &isascii));
43 PetscCall(MatView_LMVM(B, pv));
44 PetscCall(SymBroydenRescaleView(lqn->rescale, pv));
45 if (isascii) PetscCall(PetscViewerASCIIPrintf(pv, "Counts: S x : %" PetscInt_FMT ", S^T x : %" PetscInt_FMT ", Y x : %" PetscInt_FMT ", Y^T x: %" PetscInt_FMT "\n", lqn->S_count, lqn->St_count, lqn->Y_count, lqn->Yt_count));
46 PetscFunctionReturn(PETSC_SUCCESS);
47 }
48
MatLMVMDQNResetDestructive(Mat B)49 static PetscErrorCode MatLMVMDQNResetDestructive(Mat B)
50 {
51 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
52 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
53
54 PetscFunctionBegin;
55 PetscCall(MatDestroy(&lqn->HY));
56 PetscCall(MatDestroy(&lqn->BS));
57 PetscCall(MatDestroy(&lqn->StY_triu));
58 PetscCall(MatDestroy(&lqn->YtS_triu));
59 PetscCall(VecDestroy(&lqn->StFprev));
60 PetscCall(VecDestroy(&lqn->Fprev_ref));
61 lqn->Fprev_state = 0;
62 PetscCall(MatDestroy(&lqn->YtS_triu_strict));
63 PetscCall(MatDestroy(&lqn->StY_triu_strict));
64 PetscCall(MatDestroy(&lqn->StBS));
65 PetscCall(MatDestroy(&lqn->YtHY));
66 PetscCall(MatDestroy(&lqn->J));
67 PetscCall(MatDestroy(&lqn->temp_mat));
68 PetscCall(VecDestroy(&lqn->diag_vec));
69 PetscCall(VecDestroy(&lqn->diag_vec_recycle_order));
70 PetscCall(VecDestroy(&lqn->inv_diag_vec));
71 PetscCall(VecDestroy(&lqn->column_work));
72 PetscCall(VecDestroy(&lqn->column_work2));
73 PetscCall(VecDestroy(&lqn->rwork1));
74 PetscCall(VecDestroy(&lqn->rwork2));
75 PetscCall(VecDestroy(&lqn->rwork3));
76 PetscCall(VecDestroy(&lqn->rwork2_local));
77 PetscCall(VecDestroy(&lqn->rwork3_local));
78 PetscCall(VecDestroy(&lqn->cyclic_work_vec));
79 PetscCall(VecDestroyVecs(lmvm->m, &lqn->PQ));
80 PetscCall(PetscFree(lqn->stp));
81 PetscCall(PetscFree(lqn->yts));
82 PetscCall(PetscFree(lqn->ytq));
83 lqn->allocated = PETSC_FALSE;
84 PetscFunctionReturn(PETSC_SUCCESS);
85 }
86
MatReset_LMVMDQN_Internal(Mat B,MatLMVMResetMode mode)87 static PetscErrorCode MatReset_LMVMDQN_Internal(Mat B, MatLMVMResetMode mode)
88 {
89 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
90 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
91
92 PetscFunctionBegin;
93 lqn->watchdog = 0;
94 lqn->needPQ = PETSC_TRUE;
95 lqn->num_updates = 0;
96 lqn->num_mult_updates = 0;
97 if (MatLMVMResetClearsBases(mode)) PetscCall(MatLMVMDQNResetDestructive(B));
98 else {
99 if (lqn->BS) PetscCall(MatZeroEntries(lqn->BS));
100 if (lqn->HY) PetscCall(MatZeroEntries(lqn->HY));
101 if (lqn->StY_triu) { /* Set to identity by default so it is invertible */
102 PetscCall(MatZeroEntries(lqn->StY_triu));
103 PetscCall(MatShift(lqn->StY_triu, 1.0));
104 }
105 if (lqn->YtS_triu) {
106 PetscCall(MatZeroEntries(lqn->YtS_triu));
107 PetscCall(MatShift(lqn->YtS_triu, 1.0));
108 }
109 if (lqn->YtS_triu_strict) PetscCall(MatZeroEntries(lqn->YtS_triu_strict));
110 if (lqn->StY_triu_strict) PetscCall(MatZeroEntries(lqn->StY_triu_strict));
111 if (lqn->StBS) {
112 PetscCall(MatZeroEntries(lqn->StBS));
113 PetscCall(MatShift(lqn->StBS, 1.0));
114 }
115 if (lqn->YtHY) {
116 PetscCall(MatZeroEntries(lqn->YtHY));
117 PetscCall(MatShift(lqn->YtHY, 1.0));
118 }
119 if (lqn->Fprev_ref) PetscCall(VecDestroy(&lqn->Fprev_ref));
120 lqn->Fprev_state = 0;
121 if (lqn->StFprev) PetscCall(VecZeroEntries(lqn->StFprev));
122 }
123 PetscFunctionReturn(PETSC_SUCCESS);
124 }
125
MatReset_LMVMDQN(Mat B,MatLMVMResetMode mode)126 static PetscErrorCode MatReset_LMVMDQN(Mat B, MatLMVMResetMode mode)
127 {
128 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
129 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
130
131 PetscFunctionBegin;
132 PetscCall(SymBroydenRescaleReset(B, lqn->rescale, mode));
133 PetscCall(MatReset_LMVMDQN_Internal(B, mode));
134 PetscFunctionReturn(PETSC_SUCCESS);
135 }
136
MatAllocate_LMVMDQN_Internal(Mat B)137 static PetscErrorCode MatAllocate_LMVMDQN_Internal(Mat B)
138 {
139 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
140 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
141
142 PetscFunctionBegin;
143 if (!lqn->allocated) {
144 if (lmvm->m > 0) {
145 PetscMPIInt rank;
146 PetscInt n, N, m, M;
147 PetscBool is_dbfgs, is_ddfp, is_dqn;
148 VecType vec_type;
149 MPI_Comm comm = PetscObjectComm((PetscObject)B);
150 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
151
152 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
153 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
154 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
155
156 PetscCallMPI(MPI_Comm_rank(comm, &rank));
157 PetscCall(MatGetSize(B, &N, NULL));
158 PetscCall(MatGetLocalSize(B, &n, NULL));
159 M = lmvm->m;
160 m = (rank == 0) ? M : 0;
161
162 /* For DBFGS: Create data needed for MatSolve() eagerly; data needed for MatMult() will be created on demand
163 * For DDFP : Create data needed for MatMult() eagerly; data needed for MatSolve() will be created on demand
164 * For DQN : Create all data eagerly */
165 PetscCall(VecGetType(lmvm->Xprev, &vec_type));
166 if (is_dqn) {
167 PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->StY_triu));
168 PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->YtS_triu));
169 PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->diag_vec, &lqn->rwork1));
170 PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->rwork2, &lqn->rwork3));
171 } else if (is_ddfp) {
172 PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->YtS_triu));
173 PetscCall(MatDuplicate(Sfull, MAT_SHARE_NONZERO_PATTERN, &lqn->HY));
174 PetscCall(MatCreateVecs(lqn->YtS_triu, &lqn->diag_vec, &lqn->rwork1));
175 PetscCall(MatCreateVecs(lqn->YtS_triu, &lqn->rwork2, &lqn->rwork3));
176 } else if (is_dbfgs) {
177 PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->StY_triu));
178 PetscCall(MatDuplicate(Sfull, MAT_SHARE_NONZERO_PATTERN, &lqn->BS));
179 PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->diag_vec, &lqn->rwork1));
180 PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->rwork2, &lqn->rwork3));
181 } else {
182 SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatAllocate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
183 }
184 /* initialize StY_triu and YtS_triu to identity, if they exist, so it is invertible */
185 if (lqn->StY_triu) {
186 PetscCall(MatZeroEntries(lqn->StY_triu));
187 PetscCall(MatShift(lqn->StY_triu, 1.0));
188 }
189 if (lqn->YtS_triu) {
190 PetscCall(MatZeroEntries(lqn->YtS_triu));
191 PetscCall(MatShift(lqn->YtS_triu, 1.0));
192 }
193 if (lqn->use_recursive && (is_dbfgs || is_ddfp)) {
194 PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lqn->PQ));
195 PetscCall(VecDuplicate(lmvm->Xprev, &lqn->column_work2));
196 PetscCall(PetscMalloc1(lmvm->m, &lqn->yts));
197 if (is_dbfgs) {
198 PetscCall(PetscMalloc1(lmvm->m, &lqn->stp));
199 } else if (is_ddfp) {
200 PetscCall(PetscMalloc1(lmvm->m, &lqn->ytq));
201 }
202 }
203 PetscCall(VecDuplicate(lqn->rwork2, &lqn->cyclic_work_vec));
204 PetscCall(VecZeroEntries(lqn->rwork1));
205 PetscCall(VecZeroEntries(lqn->rwork2));
206 PetscCall(VecZeroEntries(lqn->rwork3));
207 PetscCall(VecZeroEntries(lqn->diag_vec));
208 }
209 PetscCall(VecDuplicate(lmvm->Xprev, &lqn->column_work));
210 lqn->allocated = PETSC_TRUE;
211 }
212 PetscFunctionReturn(PETSC_SUCCESS);
213 }
214
MatSetUp_LMVMDQN(Mat B)215 static PetscErrorCode MatSetUp_LMVMDQN(Mat B)
216 {
217 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
218 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
219
220 PetscFunctionBegin;
221 PetscCall(MatSetUp_LMVM(B));
222 PetscCall(SymBroydenRescaleInitializeJ0(B, lqn->rescale));
223 PetscCall(MatAllocate_LMVMDQN_Internal(B));
224 PetscFunctionReturn(PETSC_SUCCESS);
225 }
226
MatSetFromOptions_LMVMDQN(Mat B,PetscOptionItems PetscOptionsObject)227 static PetscErrorCode MatSetFromOptions_LMVMDQN(Mat B, PetscOptionItems PetscOptionsObject)
228 {
229 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
230 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
231 PetscBool is_dbfgs, is_ddfp, is_dqn;
232
233 PetscFunctionBegin;
234 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
235 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
236 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
237 PetscCall(MatSetFromOptions_LMVM(B, PetscOptionsObject));
238 PetscOptionsHeadBegin(PetscOptionsObject, "Dense symmetric Broyden method for approximating SPD Jacobian actions");
239 if (is_dqn) {
240 PetscCall(PetscOptionsEnum("-mat_lqn_type", "Implementation options for L-QN", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
241 } else if (is_dbfgs) {
242 PetscCall(PetscOptionsBool("-mat_lbfgs_recursive", "Use recursive formulation for MatMult_LMVMDBFGS, instead of Cholesky", "", lqn->use_recursive, &lqn->use_recursive, NULL));
243 PetscCall(PetscOptionsEnum("-mat_lbfgs_type", "Implementation options for L-BFGS", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
244 } else if (is_ddfp) {
245 PetscCall(PetscOptionsBool("-mat_ldfp_recursive", "Use recursive formulation for MatSolve_LMVMDDFP, instead of Cholesky", "", lqn->use_recursive, &lqn->use_recursive, NULL));
246 PetscCall(PetscOptionsEnum("-mat_ldfp_type", "Implementation options for L-DFP", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
247 } else {
248 SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatSetFromOptions_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
249 }
250 PetscCall(SymBroydenRescaleSetFromOptions(B, lqn->rescale, PetscOptionsObject));
251 PetscOptionsHeadEnd();
252 PetscFunctionReturn(PETSC_SUCCESS);
253 }
254
MatDestroy_LMVMDQN(Mat B)255 static PetscErrorCode MatDestroy_LMVMDQN(Mat B)
256 {
257 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
258 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
259
260 PetscFunctionBegin;
261 PetscCall(SymBroydenRescaleDestroy(&lqn->rescale));
262 PetscCall(MatReset_LMVMDQN_Internal(B, MAT_LMVM_RESET_ALL));
263 PetscCall(PetscFree(lqn->workscalar));
264 PetscCall(PetscFree(lmvm->ctx));
265 PetscCall(MatDestroy_LMVM(B));
266 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", NULL));
267 PetscFunctionReturn(PETSC_SUCCESS);
268 }
269
/*
  MatUpdate_LMVMDQN - Try to add the pair (s, y) = (X - Xprev, F - Fprev) to the dense history.

  The pair is accepted only if Re(s^T y) > curvtol, where curvtol = eps * Re(y^T y) (or 0 when
  Re(y^T y) < eps). On acceptance:
    - the new column S^T y of StY_triu (DBFGS/DQN) and/or Y^T s of YtS_triu (DDFP/DQN) is formed,
    - the diagonal of the relevant triangular matrix is cached in diag_vec (and in recycle order in
      diag_vec_recycle_order, which in the in-place strategy is an extra reference to diag_vec),
    - the J0 rescaling is updated.
  A rejected pair increments nrejects and the watchdog; the watchdog is cleared by any accepted pair,
  and once it exceeds max_seq_rejects MatLMVMReset(B, PETSC_FALSE) is called.
  In all cases X and F become the new Xprev/Fprev, and a reference to F together with its object
  state is kept so that a later MatSolve with the same, unmodified F can reuse the cached S^T F
  (StFprev) instead of recomputing it.
*/
static PetscErrorCode MatUpdate_LMVMDQN(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
  Mat_DQN  *lqn   = (Mat_DQN *)lmvm->ctx;
  Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
  Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;

  PetscBool          is_ddfp, is_dbfgs, is_dqn;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS); /* no history is kept: nothing to do */
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
  PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
  if (lmvm->prev_set) {
    Vec         FX[2];
    PetscScalar dotFX[2];
    PetscScalar stFprev;
    PetscScalar curvature, yTy;
    PetscReal   curvtol;

    /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
    PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
    /* Test if the updates can be accepted */
    FX[0] = lmvm->Fprev; /* dotFX[0] = s^T Fprev */
    FX[1] = F;           /* dotFX[1] = s^T F */
    /* both dot products are taken before Fprev is overwritten by y on the next line */
    PetscCall(VecMDot(lmvm->Xprev, 2, FX, dotFX));
    PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));
    PetscCall(VecDot(lmvm->Fprev, lmvm->Fprev, &yTy));
    stFprev   = PetscConj(dotFX[0]);
    curvature = PetscConj(dotFX[1] - dotFX[0]); /* s^T y */
    if (PetscRealPart(yTy) < lmvm->eps) {
      curvtol = 0.0;
    } else {
      curvtol = lmvm->eps * PetscRealPart(yTy);
    }
    if (PetscRealPart(curvature) > curvtol) {
      PetscInt m     = lmvm->m;
      PetscInt k     = lmvm->k;
      PetscInt h_old = k - oldest_update(m, k);         /* number of stored pairs before this update */
      PetscInt h_new = k + 1 - oldest_update(m, k + 1); /* number of stored pairs after this update */
      PetscInt idx   = recycle_index(m, k);             /* slot that receives the new pair */

      /* Update is good, accept it */
      PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
      lqn->num_updates++;
      lqn->watchdog = 0;
      lqn->needPQ   = PETSC_TRUE;

      /* history is full and stored oldest-first: shift the kept part to make room for the new entry */
      if (h_old == m && lqn->strategy == MAT_LMVM_DENSE_REORDER) {
        if (is_dqn) {
          PetscCall(MatMove_LR3(B, lqn->StY_triu, m - 1));
          PetscCall(MatMove_LR3(B, lqn->YtS_triu, m - 1));
        } else if (is_dbfgs) {
          PetscCall(MatMove_LR3(B, lqn->StY_triu, m - 1));
        } else if (is_ddfp) {
          PetscCall(MatMove_LR3(B, lqn->YtS_triu, m - 1));
        } else {
          SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatUpdate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
        }
      }

      if (lqn->use_recursive && (is_dbfgs || is_ddfp)) lqn->yts[idx] = PetscRealPart(curvature);

      if (is_dqn || is_dbfgs) { /* implement the scheme of Byrd, Nocedal, and Schnabel to save a MatMultTranspose call in the common case the *
                                 * H_k is immediately applied to F after being updated. The S^T y computation can be split up as S^T (F - F_prev) */
        PetscInt     local_n;
        PetscScalar *StFprev;
        PetscMemType memtype;
        PetscInt     StYidx;

        StYidx = (lqn->strategy == MAT_LMVM_DENSE_REORDER) ? history_index(m, lqn->num_updates, k) : idx;
        if (!lqn->StFprev) PetscCall(VecDuplicate(lqn->rwork1, &lqn->StFprev));
        /* StFprev holds S^T Fprev for the old columns; fill in the entry for the new s (= s^T Fprev).
           Only ranks owning entries of the small vectors write; device memory is written with a copy */
        PetscCall(VecGetLocalSize(lqn->StFprev, &local_n));
        PetscCall(VecGetArrayAndMemType(lqn->StFprev, &StFprev, &memtype));
        if (local_n) {
          if (PetscMemTypeHost(memtype)) {
            StFprev[idx] = stFprev;
          } else {
            PetscCall(PetscDeviceRegisterMemory(&stFprev, PETSC_MEMTYPE_HOST, 1 * sizeof(stFprev)));
            PetscCall(PetscDeviceRegisterMemory(StFprev, memtype, local_n * sizeof(*StFprev)));
            PetscCall(PetscDeviceArrayCopy(dctx, &StFprev[idx], &stFprev, 1));
          }
        }
        PetscCall(VecRestoreArrayAndMemType(lqn->StFprev, &StFprev));

        {
          Vec this_sy_col;
          /* Now StFprev is updated for the new S vector. Write -StFprev into the appropriate column */
          PetscCall(MatDenseGetColumnVecWrite(lqn->StY_triu, StYidx, &this_sy_col));
          PetscCall(VecAXPBY(this_sy_col, -1.0, 0.0, lqn->StFprev));

          /* Now compute the new StFprev */
          PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, lqn->StFprev, 0, h_new));
          lqn->St_count++;

          /* Now add StFprev: this_sy_col == S^T (F - Fprev) == S^T y */
          PetscCall(VecAXPY(this_sy_col, 1.0, lqn->StFprev));

          if (lqn->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, this_sy_col, lqn->num_updates, lqn->cyclic_work_vec));
          PetscCall(MatDenseRestoreColumnVecWrite(lqn->StY_triu, StYidx, &this_sy_col));
        }
      }

      if (is_ddfp || is_dqn) {
        PetscInt YtSidx;

        YtSidx = (lqn->strategy == MAT_LMVM_DENSE_REORDER) ? history_index(m, lqn->num_updates, k) : idx;

        {
          Vec this_ys_col;

          /* new column of YtS_triu: Y^T s over the h_new stored columns */
          PetscCall(MatDenseGetColumnVecWrite(lqn->YtS_triu, YtSidx, &this_ys_col));
          PetscCall(MatMultHermitianTransposeColumnRange(Yfull, lmvm->Xprev, this_ys_col, 0, h_new));
          lqn->Yt_count++;

          if (lqn->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, this_ys_col, lqn->num_updates, lqn->cyclic_work_vec));
          PetscCall(MatDenseRestoreColumnVecWrite(lqn->YtS_triu, YtSidx, &this_ys_col));
        }
      }

      /* cache the diagonal of the triangular matrix used by this variant */
      if (is_dbfgs || is_dqn) {
        PetscCall(MatGetDiagonal(lqn->StY_triu, lqn->diag_vec));
      } else if (is_ddfp) {
        PetscCall(MatGetDiagonal(lqn->YtS_triu, lqn->diag_vec));
      } else {
        SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatUpdate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
      }

      if (lqn->strategy == MAT_LMVM_DENSE_REORDER) {
        /* diag_vec is in history order; keep a separate copy permuted to recycle order */
        if (!lqn->diag_vec_recycle_order) PetscCall(VecDuplicate(lqn->diag_vec, &lqn->diag_vec_recycle_order));
        PetscCall(VecCopy(lqn->diag_vec, lqn->diag_vec_recycle_order));
        PetscCall(VecHistoryOrderToRecycleOrder(B, lqn->diag_vec_recycle_order, lqn->num_updates, lqn->cyclic_work_vec));
      } else {
        /* in-place strategy: both orders coincide, so diag_vec_recycle_order is an extra reference to diag_vec */
        if (!lqn->diag_vec_recycle_order) {
          PetscCall(PetscObjectReference((PetscObject)lqn->diag_vec));
          lqn->diag_vec_recycle_order = lqn->diag_vec;
        }
      }

      PetscCall(SymBroydenRescaleUpdate(B, lqn->rescale));
    } else {
      /* Update is bad, skip it */
      ++lmvm->nrejects;
      ++lqn->watchdog;
      PetscInt m = lmvm->m;
      PetscInt k = lmvm->k;
      PetscInt h = k - oldest_update(m, k);

      /* we still have to maintain StFprev */
      if (!lqn->StFprev) PetscCall(VecDuplicate(lqn->rwork1, &lqn->StFprev));
      PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, lqn->StFprev, 0, h));
      lqn->St_count++;
    }
  }

  /* too many consecutive rejected pairs: reset */
  if (lqn->watchdog > lqn->max_seq_rejects) PetscCall(MatLMVMReset(B, PETSC_FALSE));

  /* Save the solution and function to be used in the next update */
  PetscCall(VecCopy(X, lmvm->Xprev));
  PetscCall(VecCopy(F, lmvm->Fprev));
  /* remember which F (and in which state) StFprev corresponds to; reference before destroy in case F == Fprev_ref */
  PetscCall(PetscObjectReference((PetscObject)F));
  PetscCall(VecDestroy(&lqn->Fprev_ref));
  lqn->Fprev_ref = F;
  PetscCall(PetscObjectStateGet((PetscObject)F, &lqn->Fprev_state));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
440
MatDestroyThenCopy(Mat src,Mat * dst)441 static PetscErrorCode MatDestroyThenCopy(Mat src, Mat *dst)
442 {
443 PetscFunctionBegin;
444 PetscCall(MatDestroy(dst));
445 if (src) PetscCall(MatDuplicate(src, MAT_COPY_VALUES, dst));
446 PetscFunctionReturn(PETSC_SUCCESS);
447 }
448
VecDestroyThenCopy(Vec src,Vec * dst)449 static PetscErrorCode VecDestroyThenCopy(Vec src, Vec *dst)
450 {
451 PetscFunctionBegin;
452 PetscCall(VecDestroy(dst));
453 if (src) {
454 PetscCall(VecDuplicate(src, dst));
455 PetscCall(VecCopy(src, *dst));
456 }
457 PetscFunctionReturn(PETSC_SUCCESS);
458 }
459
MatCopy_LMVMDQN(Mat B,Mat M,MatStructure str)460 static PetscErrorCode MatCopy_LMVMDQN(Mat B, Mat M, MatStructure str)
461 {
462 Mat_LMVM *bdata = (Mat_LMVM *)B->data;
463 Mat_DQN *blqn = (Mat_DQN *)bdata->ctx;
464 Mat_LMVM *mdata = (Mat_LMVM *)M->data;
465 Mat_DQN *mlqn = (Mat_DQN *)mdata->ctx;
466 PetscInt i;
467 PetscBool is_dbfgs, is_ddfp, is_dqn;
468
469 PetscFunctionBegin;
470 PetscCall(SymBroydenRescaleCopy(blqn->rescale, mlqn->rescale));
471 mlqn->num_updates = blqn->num_updates;
472 mlqn->num_mult_updates = blqn->num_mult_updates;
473 mlqn->dense_type = blqn->dense_type;
474 mlqn->strategy = blqn->strategy;
475 mlqn->S_count = 0;
476 mlqn->St_count = 0;
477 mlqn->Y_count = 0;
478 mlqn->Yt_count = 0;
479 mlqn->watchdog = blqn->watchdog;
480 mlqn->max_seq_rejects = blqn->max_seq_rejects;
481 mlqn->use_recursive = blqn->use_recursive;
482 mlqn->needPQ = blqn->needPQ;
483 if (blqn->allocated) {
484 PetscCall(MatAllocate_LMVMDQN_Internal(M));
485 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
486 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
487 PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
488 PetscCall(MatDestroyThenCopy(blqn->HY, &mlqn->BS));
489 PetscCall(VecDestroyThenCopy(blqn->StFprev, &mlqn->StFprev));
490 PetscCall(MatDestroyThenCopy(blqn->StY_triu, &mlqn->StY_triu));
491 PetscCall(MatDestroyThenCopy(blqn->StY_triu_strict, &mlqn->StY_triu_strict));
492 PetscCall(MatDestroyThenCopy(blqn->YtS_triu, &mlqn->YtS_triu));
493 PetscCall(MatDestroyThenCopy(blqn->YtS_triu_strict, &mlqn->YtS_triu_strict));
494 PetscCall(MatDestroyThenCopy(blqn->YtHY, &mlqn->YtHY));
495 PetscCall(MatDestroyThenCopy(blqn->StBS, &mlqn->StBS));
496 PetscCall(MatDestroyThenCopy(blqn->J, &mlqn->J));
497 PetscCall(VecDestroyThenCopy(blqn->diag_vec, &mlqn->diag_vec));
498 PetscCall(VecDestroyThenCopy(blqn->diag_vec_recycle_order, &mlqn->diag_vec_recycle_order));
499 PetscCall(VecDestroyThenCopy(blqn->inv_diag_vec, &mlqn->inv_diag_vec));
500 if (blqn->use_recursive && (is_dbfgs || is_ddfp)) {
501 for (i = 0; i < bdata->m; i++) {
502 PetscCall(VecDestroyThenCopy(blqn->PQ[i], &mlqn->PQ[i]));
503 mlqn->yts[i] = blqn->yts[i];
504 if (is_dbfgs) {
505 mlqn->stp[i] = blqn->stp[i];
506 } else if (is_ddfp) {
507 mlqn->ytq[i] = blqn->ytq[i];
508 }
509 }
510 }
511 }
512 PetscCall(PetscObjectReference((PetscObject)blqn->Fprev_ref));
513 PetscCall(VecDestroy(&mlqn->Fprev_ref));
514 mlqn->Fprev_ref = blqn->Fprev_ref;
515 mlqn->Fprev_state = blqn->Fprev_state;
516 PetscFunctionReturn(PETSC_SUCCESS);
517 }
518
MatMult_LMVMDQN(Mat B,Vec X,Vec Z)519 static PetscErrorCode MatMult_LMVMDQN(Mat B, Vec X, Vec Z)
520 {
521 PetscFunctionBegin;
522 PetscCall(MatMult_LMVMDDFP(B, X, Z));
523 PetscFunctionReturn(PETSC_SUCCESS);
524 }
525
MatSolve_LMVMDQN(Mat H,Vec F,Vec dX)526 static PetscErrorCode MatSolve_LMVMDQN(Mat H, Vec F, Vec dX)
527 {
528 PetscFunctionBegin;
529 PetscCall(MatSolve_LMVMDBFGS(H, F, dX));
530 PetscFunctionReturn(PETSC_SUCCESS);
531 }
532
MatLMVMSymBroydenSetDelta_LMVMDQN(Mat B,PetscScalar delta)533 static PetscErrorCode MatLMVMSymBroydenSetDelta_LMVMDQN(Mat B, PetscScalar delta)
534 {
535 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
536 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
537
538 PetscFunctionBegin;
539 PetscCall(SymBroydenRescaleSetDelta(B, lqn->rescale, PetscAbsReal(PetscRealPart(delta))));
540 PetscFunctionReturn(PETSC_SUCCESS);
541 }
542
543 /*
544 This dense representation uses Davidon-Fletcher-Powell (DFP) for MatMult,
545 and Broyden-Fletcher-Goldfarb-Shanno (BFGS) for MatSolve. This implementation
  avoids a costly Cholesky factorization, at the cost of a duality gap: MatMult and
  MatSolve are no longer exact inverses of each other.
547 Please refer to MatLMVMDDFP and MatLMVMDBFGS for more information.
548 */
MatCreate_LMVMDQN(Mat B)549 PetscErrorCode MatCreate_LMVMDQN(Mat B)
550 {
551 Mat_LMVM *lmvm;
552 Mat_DQN *lqn;
553
554 PetscFunctionBegin;
555 PetscCall(MatCreate_LMVM(B));
556 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDQN));
557 PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
558 PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
559 PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
560 B->ops->view = MatView_LMVMDQN;
561 B->ops->setup = MatSetUp_LMVMDQN;
562 B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
563 B->ops->destroy = MatDestroy_LMVMDQN;
564
565 lmvm = (Mat_LMVM *)B->data;
566 lmvm->ops->reset = MatReset_LMVMDQN;
567 lmvm->ops->update = MatUpdate_LMVMDQN;
568 lmvm->ops->mult = MatMult_LMVMDQN;
569 lmvm->ops->solve = MatSolve_LMVMDQN;
570 lmvm->ops->copy = MatCopy_LMVMDQN;
571
572 lmvm->ops->multht = lmvm->ops->mult;
573 lmvm->ops->solveht = lmvm->ops->solve;
574
575 PetscCall(PetscNew(&lqn));
576 lmvm->ctx = (void *)lqn;
577 lqn->allocated = PETSC_FALSE;
578 lqn->use_recursive = PETSC_FALSE;
579 lqn->needPQ = PETSC_FALSE;
580 lqn->watchdog = 0;
581 lqn->max_seq_rejects = lmvm->m / 2;
582 lqn->strategy = MAT_LMVM_DENSE_INPLACE;
583
584 PetscCall(SymBroydenRescaleCreate(&lqn->rescale));
585 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
586 PetscFunctionReturn(PETSC_SUCCESS);
587 }
588
589 /*@
590 MatCreateLMVMDQN - Creates a dense representation of the limited-memory
591 Quasi-Newton approximation to a Hessian.
592
593 Collective
594
595 Input Parameters:
596 + comm - MPI communicator
597 . n - number of local rows for storage vectors
598 - N - global size of the storage vectors
599
600 Output Parameter:
601 . B - the matrix
602
603 Level: advanced
604
605 Note:
606 It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
607 paradigm instead of this routine directly.
608
609 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDBFGS`, `MATLMVMDDFP`, `MatCreateLMVMDDFP()`, `MatCreateLMVMDBFGS()`
610 @*/
MatCreateLMVMDQN(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)611 PetscErrorCode MatCreateLMVMDQN(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
612 {
613 PetscFunctionBegin;
614 PetscCall(KSPInitializePackage());
615 PetscCall(MatCreate(comm, B));
616 PetscCall(MatSetSizes(*B, n, n, N, N));
617 PetscCall(MatSetType(*B, MATLMVMDQN));
618 PetscCall(MatSetUp(*B));
619 PetscFunctionReturn(PETSC_SUCCESS);
620 }
621
MatDQNApplyJ0Fwd(Mat B,Vec X,Vec Z)622 static PetscErrorCode MatDQNApplyJ0Fwd(Mat B, Vec X, Vec Z)
623 {
624 PetscFunctionBegin;
625 PetscCall(MatLMVMApplyJ0Fwd(B, X, Z));
626 PetscFunctionReturn(PETSC_SUCCESS);
627 }
628
MatDQNApplyJ0Inv(Mat B,Vec F,Vec dX)629 static PetscErrorCode MatDQNApplyJ0Inv(Mat B, Vec F, Vec dX)
630 {
631 PetscFunctionBegin;
632 PetscCall(MatLMVMApplyJ0Inv(B, F, dX));
633 PetscFunctionReturn(PETSC_SUCCESS);
634 }
635
636 /* This is not Bunch-Kaufman LDLT: here L is strictly lower triangular part of STY */
/*
  MatGetLDLT - Form result = L^H D^{-1} L, with L = YtS_triu_strict and D^{-1} = inv_diag_vec.

  This is not a Bunch-Kaufman LDL^T factorization: L is the strictly triangular part of the
  Y^T S products maintained by MatLMVMDBFGSUpdateMultData(). temp_mat is lazily created scratch.
*/
static PetscErrorCode MatGetLDLT(Mat B, Mat result)
{
  Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
  Mat_DQN  *lbfgs = (Mat_DQN *)lmvm->ctx;
  PetscInt  m_local;

  PetscFunctionBegin;
  /* temp_mat <- D^{-1} L (MatDiagonalScale with a left vector scales the rows) */
  if (!lbfgs->temp_mat) PetscCall(MatDuplicate(lbfgs->YtS_triu_strict, MAT_SHARE_NONZERO_PATTERN, &lbfgs->temp_mat));
  PetscCall(MatCopy(lbfgs->YtS_triu_strict, lbfgs->temp_mat, SAME_NONZERO_PATTERN));
  PetscCall(MatDiagonalScale(lbfgs->temp_mat, lbfgs->inv_diag_vec, NULL));
  PetscCall(MatGetLocalSize(result, &m_local, NULL));
  // need to conjugate and conjugate again because we have MatTransposeMatMult but not MatHermitianTransposeMatMult()
  // conj(L^T conj(D^{-1} L)) = L^H D^{-1} L; temp_mat is left conjugated, which is fine since it is recopied on every call
  PetscCall(MatConjugate(lbfgs->temp_mat));
  if (m_local) {
    /* Only the local blocks are multiplied. NOTE(review): this relies on the small matrices having
       all their rows on one rank (as created in MatAllocate_LMVMDQN_Internal) - confirm if that changes */
    Mat temp_local, YtS_local, result_local;
    PetscCall(MatDenseGetLocalMatrix(lbfgs->YtS_triu_strict, &YtS_local));
    PetscCall(MatDenseGetLocalMatrix(lbfgs->temp_mat, &temp_local));
    PetscCall(MatDenseGetLocalMatrix(result, &result_local));
    PetscCall(MatTransposeMatMult(YtS_local, temp_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &result_local));
  }
  PetscCall(MatConjugate(result)); /* second conjugation (collective, so done on every rank) */
  PetscFunctionReturn(PETSC_SUCCESS);
}
660
/*
  MatLMVMDBFGSUpdateMultData - Bring the data needed by the BFGS MatMult up to date.

  On first use, creates YtS_triu_strict, StBS and J (duplicates of StY_triu's layout, StBS set to
  the identity) and recreates BS with the layout of Y. Returns immediately when no update happened
  since the last call (num_mult_updates == k). Otherwise:
    - recomputes B_0 s_j and S^T B_0 s_j for every stored s_j (B_0 may have changed),
    - in the reorder strategy, shifts the still-valid part of YtS_triu_strict,
    - computes the columns Y^T s_j for the new updates and zeroes the row with the same index,
    - forms J = L^H D^{-1} L + S^T B_0 S (see MatGetLDLT) and Cholesky-factors it in place on
      the ranks that own rows of J.
*/
static PetscErrorCode MatLMVMDBFGSUpdateMultData(Mat B)
{
  Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
  Mat_DQN  *lbfgs = (Mat_DQN *)lmvm->ctx;
  PetscInt  m     = lmvm->m, m_local;
  PetscInt  k     = lmvm->k;
  PetscInt  h     = k - oldest_update(m, k); /* number of stored pairs */
  PetscInt  j_0;
  PetscInt  prev_oldest;
  Mat       J_local;
  Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
  Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;

  PetscFunctionBegin;
  /* first call: create the MatMult-only storage */
  if (!lbfgs->YtS_triu_strict) {
    PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->YtS_triu_strict));
    PetscCall(MatDestroy(&lbfgs->StBS));
    PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->StBS));
    PetscCall(MatDestroy(&lbfgs->J));
    PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->J));
    PetscCall(MatDestroy(&lbfgs->BS));
    PetscCall(MatDuplicate(Yfull, MAT_SHARE_NONZERO_PATTERN, &lbfgs->BS));
    PetscCall(MatShift(lbfgs->StBS, 1.0));
    lbfgs->num_mult_updates = oldest_update(m, k);
  }
  /* nothing new since the last call */
  if (lbfgs->num_mult_updates == k) PetscFunctionReturn(PETSC_SUCCESS);

  /* B_0 may have been updated, we must recompute B_0 S and S^T B_0 S */
  for (PetscInt j = oldest_update(m, k); j < k; j++) {
    Vec      s_j;
    Vec      Bs_j;
    Vec      StBs_j;
    PetscInt S_idx    = recycle_index(m, j);
    PetscInt StBS_idx = lbfgs->strategy == MAT_LMVM_DENSE_INPLACE ? S_idx : history_index(m, k, j);

    PetscCall(MatDenseGetColumnVecWrite(lbfgs->BS, S_idx, &Bs_j));
    PetscCall(MatDenseGetColumnVecRead(Sfull, S_idx, &s_j));
    PetscCall(MatDQNApplyJ0Fwd(B, s_j, Bs_j));
    PetscCall(MatDenseRestoreColumnVecRead(Sfull, S_idx, &s_j));
    PetscCall(MatDenseGetColumnVecWrite(lbfgs->StBS, StBS_idx, &StBs_j));
    PetscCall(MatMultHermitianTransposeColumnRange(Sfull, Bs_j, StBs_j, 0, h));
    lbfgs->St_count++;
    if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, StBs_j, lbfgs->num_updates, lbfgs->cyclic_work_vec));
    PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->StBS, StBS_idx, &StBs_j));
    PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->BS, S_idx, &Bs_j));
  }
  prev_oldest = oldest_update(m, lbfgs->num_mult_updates);
  if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER && prev_oldest < oldest_update(m, k)) {
    /* move the YtS entries that have been computed and need to be kept back up */
    PetscInt m_keep = m - (oldest_update(m, k) - prev_oldest);

    PetscCall(MatMove_LR3(B, lbfgs->YtS_triu_strict, m_keep));
  }
  PetscCall(MatGetLocalSize(lbfgs->YtS_triu_strict, &m_local, NULL));
  /* only the columns of updates not yet seen by MatMult need to be computed */
  j_0 = PetscMax(lbfgs->num_mult_updates, oldest_update(m, k));
  for (PetscInt j = j_0; j < k; j++) {
    PetscInt S_idx   = recycle_index(m, j);
    PetscInt YtS_idx = lbfgs->strategy == MAT_LMVM_DENSE_INPLACE ? S_idx : history_index(m, k, j);
    Vec      s_j, Yts_j;

    PetscCall(MatDenseGetColumnVecRead(Sfull, S_idx, &s_j));
    PetscCall(MatDenseGetColumnVecWrite(lbfgs->YtS_triu_strict, YtS_idx, &Yts_j));
    PetscCall(MatMultHermitianTransposeColumnRange(Yfull, s_j, Yts_j, 0, h));
    lbfgs->Yt_count++;
    if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, Yts_j, lbfgs->num_updates, lbfgs->cyclic_work_vec));
    PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->YtS_triu_strict, YtS_idx, &Yts_j));
    PetscCall(MatDenseRestoreColumnVecRead(Sfull, S_idx, &s_j));
    /* zero the corresponding row (presumably to keep only the strictly triangular part and to clear
       entries left over from a recycled slot) */
    if (m_local > 0) {
      Mat YtS_local, YtS_row;

      PetscCall(MatDenseGetLocalMatrix(lbfgs->YtS_triu_strict, &YtS_local));
      PetscCall(MatDenseGetSubMatrix(YtS_local, YtS_idx, YtS_idx + 1, PETSC_DECIDE, PETSC_DECIDE, &YtS_row));
      PetscCall(MatZeroEntries(YtS_row));
      PetscCall(MatDenseRestoreSubMatrix(YtS_local, &YtS_row));
    }
  }
  /* D^{-1} from the cached diagonal */
  if (!lbfgs->inv_diag_vec) PetscCall(VecDuplicate(lbfgs->diag_vec, &lbfgs->inv_diag_vec));
  PetscCall(VecCopy(lbfgs->diag_vec, lbfgs->inv_diag_vec));
  PetscCall(VecReciprocal(lbfgs->inv_diag_vec));
  /* J was factored in place on the previous call: mark it unfactored before overwriting it */
  PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
  PetscCall(MatSetFactorType(J_local, MAT_FACTOR_NONE));
  /* J = L^H D^{-1} L + S^T B_0 S */
  PetscCall(MatGetLDLT(B, lbfgs->J));
  PetscCall(MatAXPY(lbfgs->J, 1.0, lbfgs->StBS, SAME_NONZERO_PATTERN));
  if (m_local) {
    PetscCall(MatSetOption(J_local, MAT_SPD, PETSC_TRUE));
    PetscCall(MatCholeskyFactor(J_local, NULL, NULL));
  }
  lbfgs->num_mult_updates = lbfgs->num_updates;
  PetscFunctionReturn(PETSC_SUCCESS);
}
752
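/* Solves dX = H F, where H is the compact representation of the inverse BFGS approximation
 *
 *   H = [ I | -S R^{-T} ] [  I  | 0 ] [ H_0 | 0 ] [ I | Y ] [      I      ]
 *                         [-----+---] [-----+---] [---+---] [-------------]
 *                         [ Y^T | I ] [  0  | D ] [ 0 | I ] [ -R^{-1} S^T ]
 *
 *     = (I - S R^{-T} Y^T) H_0 (I - Y R^{-1} S^T) + S R^{-T} D R^{-1} S^T,
 *
 * with R = triu(S^T Y) (StY_triu) and D = diag(S^T Y) (diag_vec). */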
757
MatSolve_LMVMDBFGS(Mat H,Vec F,Vec dX)758 static PetscErrorCode MatSolve_LMVMDBFGS(Mat H, Vec F, Vec dX)
759 {
760 Mat_LMVM *lmvm = (Mat_LMVM *)H->data;
761 Mat_DQN *lbfgs = (Mat_DQN *)lmvm->ctx;
762 Vec rwork1 = lbfgs->rwork1;
763 PetscInt m = lmvm->m;
764 PetscInt k = lmvm->k;
765 PetscInt h = k - oldest_update(m, k);
766 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
767 Mat Yfull = lmvm->basis[LMBASIS_Y]->vecs;
768 PetscObjectState Fstate;
769
770 PetscFunctionBegin;
771 VecCheckSameSize(F, 2, dX, 3);
772 VecCheckMatCompatible(H, dX, 3, F, 2);
773
774 /* Block Version */
775 if (!lbfgs->num_updates) {
776 PetscCall(MatDQNApplyJ0Inv(H, F, dX));
777 PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
778 }
779
780 PetscCall(PetscObjectStateGet((PetscObject)F, &Fstate));
781 if (F == lbfgs->Fprev_ref && Fstate == lbfgs->Fprev_state) {
782 PetscCall(VecCopy(lbfgs->StFprev, rwork1));
783 } else {
784 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, rwork1, 0, h));
785 lbfgs->St_count++;
786 }
787
788 /* Reordering rwork1, as STY is in history order, while S is in recycled order */
789 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
790 PetscCall(MatUpperTriangularSolveInPlace(H, lbfgs->StY_triu, rwork1, PETSC_FALSE, lbfgs->num_updates, lbfgs->strategy));
791 PetscCall(VecScale(rwork1, -1.0));
792 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
793
794 PetscCall(VecCopy(F, lbfgs->column_work));
795 PetscCall(MatMultAddColumnRange(Yfull, rwork1, lbfgs->column_work, lbfgs->column_work, 0, h));
796 lbfgs->Y_count++;
797
798 PetscCall(VecPointwiseMult(rwork1, lbfgs->diag_vec_recycle_order, rwork1));
799 PetscCall(MatDQNApplyJ0Inv(H, lbfgs->column_work, dX));
800
801 PetscCall(MatMultHermitianTransposeAddColumnRange(Yfull, dX, rwork1, rwork1, 0, h));
802 lbfgs->Yt_count++;
803
804 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
805 PetscCall(MatUpperTriangularSolveInPlace(H, lbfgs->StY_triu, rwork1, PETSC_TRUE, lbfgs->num_updates, lbfgs->strategy));
806 PetscCall(VecScale(rwork1, -1.0));
807 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
808
809 PetscCall(MatMultAddColumnRange(Sfull, rwork1, dX, dX, 0, h));
810 lbfgs->S_count++;
811 PetscFunctionReturn(PETSC_SUCCESS);
812 }
813
/* Computes the product with

     B = B_0 - [ Y | B_0 S ] [ -D |    L^T    ]^{-1} [   Y^T   ]
                             [----+-----------]      [---------]
                             [  L | S^T B_0 S ]      [ S^T B_0 ]

   Above is equivalent to

     B_0 - [ Y | B_0 S ] [[     I     | 0 ][ -D | 0 ][ I | -D^{-1} L^T ]]^{-1} [   Y^T   ]
                         [[-----------+---][----+---][---+-------------]]      [---------]
                         [[ -L D^{-1} | I ][  0 | J ][ 0 |      I      ]]      [ S^T B_0 ]

   where J = S^T B_0 S + L D^{-1} L^T

   becomes

     B_0 - [ Y | B_0 S ] [ I | D^{-1} L^T ][ -D^{-1} |   0    ][    I     | 0 ] [   Y^T   ]
                         [---+------------][---------+--------][----------+---] [---------]
                         [ 0 |      I     ][    0    | J^{-1} ][ L D^{-1} | I ] [ S^T B_0 ]

   =

     B_0 + [ Y | B_0 S ] [ D^{-1} | 0 ][ I | L^T ][ I |    0    ][    I     | 0 ] [   Y^T   ]
                         [--------+---][---+-----][---+---------][----------+---] [---------]
                         [    0   | I ][ 0 |  I  ][ 0 | -J^{-1} ][ L D^{-1} | I ] [ S^T B_0 ]

   (Note that YtS_triu_strict is L^T)
   Byrd, Nocedal, Schnabel 1994

   Alternative approach: since DFP is dual to BFGS, use the MatMult of DFP
   (see MatMult_LMVMDDFP() in this file).

*/
MatMult_LMVMDBFGS(Mat B,Vec X,Vec Z)846 static PetscErrorCode MatMult_LMVMDBFGS(Mat B, Vec X, Vec Z)
847 {
848 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
849 Mat_DQN *lbfgs = (Mat_DQN *)lmvm->ctx;
850 Mat J_local;
851 PetscInt idx, i, j, m_local, local_n;
852 PetscInt m = lmvm->m;
853 PetscInt k = lmvm->k;
854 PetscInt h = k - oldest_update(m, k);
855 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
856 Mat Yfull = lmvm->basis[LMBASIS_Y]->vecs;
857
858 PetscFunctionBegin;
859 VecCheckSameSize(X, 2, Z, 3);
860 VecCheckMatCompatible(B, X, 2, Z, 3);
861
862 /* Cholesky Version */
863 /* Start with the B0 term */
864 PetscCall(MatDQNApplyJ0Fwd(B, X, Z));
865 if (!lbfgs->num_updates) PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
866
867 if (lbfgs->use_recursive) {
868 PetscDeviceContext dctx;
869 PetscMemType memtype;
870 PetscScalar stz, ytx, stp, sjtpi, yjtsi, *workscalar;
871 PetscInt oldest = oldest_update(m, k);
872
873 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
874 /* Recursive formulation to avoid Cholesky. Not a dense formulation */
875 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, lbfgs->rwork1, 0, h));
876 lbfgs->Yt_count++;
877
878 PetscCall(VecGetLocalSize(lbfgs->rwork1, &local_n));
879
880 if (lbfgs->needPQ) {
881 PetscInt oldest = oldest_update(m, k);
882 for (i = oldest; i < k; ++i) {
883 idx = recycle_index(m, i);
884 /* column_work = S[idx] */
885 PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work, idx));
886 PetscCall(MatDQNApplyJ0Fwd(B, lbfgs->column_work, lbfgs->PQ[idx]));
887 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, lbfgs->column_work, lbfgs->rwork3, 0, h));
888 PetscCall(VecGetArrayAndMemType(lbfgs->rwork3, &workscalar, &memtype));
889 for (j = oldest; j < i; ++j) {
890 PetscInt idx_j = recycle_index(m, j);
891 /* Copy yjtsi in device-aware manner */
892 if (local_n) {
893 if (PetscMemTypeHost(memtype)) {
894 yjtsi = workscalar[idx_j];
895 } else {
896 PetscCall(PetscDeviceRegisterMemory(&yjtsi, PETSC_MEMTYPE_HOST, sizeof(yjtsi)));
897 PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
898 PetscCall(PetscDeviceArrayCopy(dctx, &yjtsi, &workscalar[idx_j], 1));
899 }
900 }
901 PetscCallMPI(MPI_Bcast(&yjtsi, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)B)));
902 /* column_work2 = S[j] */
903 PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work2, idx_j));
904 PetscCall(VecDot(lbfgs->PQ[idx], lbfgs->column_work2, &sjtpi));
905 /* column_work2 = Y[j] */
906 PetscCall(MatGetColumnVector(Yfull, lbfgs->column_work2, idx_j));
907 /* Compute the pure BFGS component of the forward product */
908 PetscCall(VecAXPBYPCZ(lbfgs->PQ[idx], -sjtpi / lbfgs->stp[idx_j], yjtsi / lbfgs->yts[idx_j], 1.0, lbfgs->PQ[idx_j], lbfgs->column_work2));
909 }
910 PetscCall(VecDot(lbfgs->PQ[idx], lbfgs->column_work, &stp));
911 lbfgs->stp[idx] = PetscRealPart(stp);
912 }
913 lbfgs->needPQ = PETSC_FALSE;
914 }
915
916 PetscCall(VecGetArrayAndMemType(lbfgs->rwork1, &workscalar, &memtype));
917 for (i = oldest; i < k; ++i) {
918 idx = recycle_index(m, i);
919 /* Copy stz[i], ytx[i] in device-aware manner */
920 if (local_n) {
921 if (PetscMemTypeHost(memtype)) {
922 ytx = workscalar[idx];
923 } else {
924 PetscCall(PetscDeviceRegisterMemory(&ytx, PETSC_MEMTYPE_HOST, 1 * sizeof(ytx)));
925 PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
926 PetscCall(PetscDeviceArrayCopy(dctx, &ytx, &workscalar[idx], 1));
927 }
928 }
929 PetscCallMPI(MPI_Bcast(&ytx, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)B)));
930 /* column_work : S[i], column_work2 : Y[i] */
931 PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work, idx));
932 PetscCall(MatGetColumnVector(Yfull, lbfgs->column_work2, idx));
933 PetscCall(VecDot(Z, lbfgs->column_work, &stz));
934 PetscCall(VecAXPBYPCZ(Z, -stz / lbfgs->stp[idx], ytx / lbfgs->yts[idx], 1.0, lbfgs->PQ[idx], lbfgs->column_work2));
935 }
936 PetscCall(VecRestoreArrayAndMemType(lbfgs->rwork1, &workscalar));
937 } else {
938 PetscCall(MatLMVMDBFGSUpdateMultData(B));
939 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, lbfgs->rwork1, 0, h));
940 lbfgs->Yt_count++;
941 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, Z, lbfgs->rwork2, 0, h));
942 lbfgs->St_count++;
943 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) {
944 PetscCall(VecRecycleOrderToHistoryOrder(B, lbfgs->rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
945 PetscCall(VecRecycleOrderToHistoryOrder(B, lbfgs->rwork2, lbfgs->num_updates, lbfgs->cyclic_work_vec));
946 }
947
948 PetscCall(VecPointwiseMult(lbfgs->rwork1, lbfgs->rwork1, lbfgs->inv_diag_vec));
949 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(lbfgs->YtS_triu_strict));
950 PetscCall(MatMultTransposeAdd(lbfgs->YtS_triu_strict, lbfgs->rwork1, lbfgs->rwork2, lbfgs->rwork2));
951 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(lbfgs->YtS_triu_strict));
952
953 if (!lbfgs->rwork2_local) PetscCall(VecCreateLocalVector(lbfgs->rwork2, &lbfgs->rwork2_local));
954 if (!lbfgs->rwork3_local) PetscCall(VecCreateLocalVector(lbfgs->rwork3, &lbfgs->rwork3_local));
955 PetscCall(VecGetLocalVectorRead(lbfgs->rwork2, lbfgs->rwork2_local));
956 PetscCall(VecGetLocalVector(lbfgs->rwork3, lbfgs->rwork3_local));
957 PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
958 PetscCall(VecGetSize(lbfgs->rwork2_local, &m_local));
959 if (m_local) {
960 PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
961 PetscCall(MatSolve(J_local, lbfgs->rwork2_local, lbfgs->rwork3_local));
962 }
963 PetscCall(VecRestoreLocalVector(lbfgs->rwork3, lbfgs->rwork3_local));
964 PetscCall(VecRestoreLocalVectorRead(lbfgs->rwork2, lbfgs->rwork2_local));
965 PetscCall(VecScale(lbfgs->rwork3, -1.0));
966
967 PetscCall(MatMult(lbfgs->YtS_triu_strict, lbfgs->rwork3, lbfgs->rwork2));
968 PetscCall(VecPointwiseMult(lbfgs->rwork2, lbfgs->rwork2, lbfgs->inv_diag_vec));
969 PetscCall(VecAXPY(lbfgs->rwork1, 1.0, lbfgs->rwork2));
970
971 if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) {
972 PetscCall(VecHistoryOrderToRecycleOrder(B, lbfgs->rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
973 PetscCall(VecHistoryOrderToRecycleOrder(B, lbfgs->rwork3, lbfgs->num_updates, lbfgs->cyclic_work_vec));
974 }
975
976 PetscCall(MatMultAddColumnRange(Yfull, lbfgs->rwork1, Z, Z, 0, h));
977 lbfgs->Y_count++;
978 PetscCall(MatMultAddColumnRange(lbfgs->BS, lbfgs->rwork3, Z, Z, 0, h));
979 lbfgs->S_count++;
980 }
981 PetscFunctionReturn(PETSC_SUCCESS);
982 }
983
984 /*
985 This dense representation reduces the L-BFGS update to a series of
986 matrix-vector products with dense matrices in lieu of the conventional matrix-free
987 two-loop algorithm.
988 */
MatCreate_LMVMDBFGS(Mat B)989 PetscErrorCode MatCreate_LMVMDBFGS(Mat B)
990 {
991 Mat_LMVM *lmvm;
992 Mat_DQN *lbfgs;
993
994 PetscFunctionBegin;
995 PetscCall(MatCreate_LMVM(B));
996 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDBFGS));
997 PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
998 PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
999 PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
1000 B->ops->view = MatView_LMVMDQN;
1001 B->ops->setup = MatSetUp_LMVMDQN;
1002 B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
1003 B->ops->destroy = MatDestroy_LMVMDQN;
1004
1005 lmvm = (Mat_LMVM *)B->data;
1006 lmvm->ops->reset = MatReset_LMVMDQN;
1007 lmvm->ops->update = MatUpdate_LMVMDQN;
1008 lmvm->ops->mult = MatMult_LMVMDBFGS;
1009 lmvm->ops->solve = MatSolve_LMVMDBFGS;
1010 lmvm->ops->copy = MatCopy_LMVMDQN;
1011
1012 lmvm->ops->multht = lmvm->ops->mult;
1013 lmvm->ops->solveht = lmvm->ops->solve;
1014
1015 PetscCall(PetscNew(&lbfgs));
1016 lmvm->ctx = (void *)lbfgs;
1017 lbfgs->allocated = PETSC_FALSE;
1018 lbfgs->use_recursive = PETSC_TRUE;
1019 lbfgs->needPQ = PETSC_TRUE;
1020 lbfgs->watchdog = 0;
1021 lbfgs->max_seq_rejects = lmvm->m / 2;
1022 lbfgs->strategy = MAT_LMVM_DENSE_INPLACE;
1023
1024 PetscCall(SymBroydenRescaleCreate(&lbfgs->rescale));
1025 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
1026 PetscFunctionReturn(PETSC_SUCCESS);
1027 }
1028
1029 /*@
1030 MatCreateLMVMDBFGS - Creates a dense representation of the limited-memory
1031 Broyden-Fletcher-Goldfarb-Shanno (BFGS) approximation to a Hessian.
1032
1033 Collective
1034
1035 Input Parameters:
1036 + comm - MPI communicator
1037 . n - number of local rows for storage vectors
1038 - N - global size of the storage vectors
1039
1040 Output Parameter:
1041 . B - the matrix
1042
1043 Level: advanced
1044
1045 Note:
1046 It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions()
1047 paradigm instead of this routine directly.
1048
1049 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDBFGS`, `MatCreateLMVMBFGS()`
1050 @*/
MatCreateLMVMDBFGS(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)1051 PetscErrorCode MatCreateLMVMDBFGS(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
1052 {
1053 PetscFunctionBegin;
1054 PetscCall(KSPInitializePackage());
1055 PetscCall(MatCreate(comm, B));
1056 PetscCall(MatSetSizes(*B, n, n, N, N));
1057 PetscCall(MatSetType(*B, MATLMVMDBFGS));
1058 PetscCall(MatSetUp(*B));
1059 PetscFunctionReturn(PETSC_SUCCESS);
1060 }
1061
1062 /* here R is strictly upper triangular part of STY */
MatGetRTDR(Mat B,Mat result)1063 static PetscErrorCode MatGetRTDR(Mat B, Mat result)
1064 {
1065 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1066 Mat_DQN *ldfp = (Mat_DQN *)lmvm->ctx;
1067 PetscInt m_local;
1068
1069 PetscFunctionBegin;
1070 if (!ldfp->temp_mat) PetscCall(MatDuplicate(ldfp->StY_triu_strict, MAT_SHARE_NONZERO_PATTERN, &ldfp->temp_mat));
1071 PetscCall(MatCopy(ldfp->StY_triu_strict, ldfp->temp_mat, SAME_NONZERO_PATTERN));
1072 PetscCall(MatDiagonalScale(ldfp->temp_mat, ldfp->inv_diag_vec, NULL));
1073 PetscCall(MatGetLocalSize(result, &m_local, NULL));
1074 // need to conjugate and conjugate again because we have MatTransposeMatMult but not MatHermitianTransposeMatMult()
1075 PetscCall(MatConjugate(ldfp->temp_mat));
1076 if (m_local) {
1077 Mat temp_local, StY_local, result_local;
1078 PetscCall(MatDenseGetLocalMatrix(ldfp->StY_triu_strict, &StY_local));
1079 PetscCall(MatDenseGetLocalMatrix(ldfp->temp_mat, &temp_local));
1080 PetscCall(MatDenseGetLocalMatrix(result, &result_local));
1081 PetscCall(MatTransposeMatMult(StY_local, temp_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &result_local));
1082 }
1083 PetscCall(MatConjugate(result));
1084 PetscFunctionReturn(PETSC_SUCCESS);
1085 }
1086
MatLMVMDDFPUpdateSolveData(Mat B)1087 static PetscErrorCode MatLMVMDDFPUpdateSolveData(Mat B)
1088 {
1089 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1090 Mat_DQN *ldfp = (Mat_DQN *)lmvm->ctx;
1091 PetscInt m = lmvm->m, m_local;
1092 PetscInt k = lmvm->k;
1093 PetscInt h = k - oldest_update(m, k);
1094 PetscInt j_0;
1095 PetscInt prev_oldest;
1096 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
1097 Mat Yfull = lmvm->basis[LMBASIS_Y]->vecs;
1098 Mat J_local;
1099
1100 PetscFunctionBegin;
1101 if (!ldfp->StY_triu_strict) {
1102 PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->StY_triu_strict));
1103 PetscCall(MatDestroy(&ldfp->YtHY));
1104 PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->YtHY));
1105 PetscCall(MatDestroy(&ldfp->J));
1106 PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->J));
1107 PetscCall(MatDestroy(&ldfp->HY));
1108 PetscCall(MatDuplicate(Yfull, MAT_SHARE_NONZERO_PATTERN, &ldfp->HY));
1109 PetscCall(MatShift(ldfp->YtHY, 1.0));
1110 ldfp->num_mult_updates = oldest_update(m, k);
1111 }
1112 if (ldfp->num_mult_updates == k) PetscFunctionReturn(PETSC_SUCCESS);
1113
1114 /* H_0 may have been updated, we must recompute H_0 Y and Y^T H_0 Y */
1115 for (PetscInt j = oldest_update(m, k); j < k; j++) {
1116 Vec y_j;
1117 Vec Hy_j;
1118 Vec YtHy_j;
1119 PetscInt Y_idx = recycle_index(m, j);
1120 PetscInt YtHY_idx = ldfp->strategy == MAT_LMVM_DENSE_INPLACE ? Y_idx : history_index(m, k, j);
1121
1122 PetscCall(MatDenseGetColumnVecWrite(ldfp->HY, Y_idx, &Hy_j));
1123 PetscCall(MatDenseGetColumnVecRead(Yfull, Y_idx, &y_j));
1124 PetscCall(MatDQNApplyJ0Inv(B, y_j, Hy_j));
1125 PetscCall(MatDenseRestoreColumnVecRead(Yfull, Y_idx, &y_j));
1126 PetscCall(MatDenseGetColumnVecWrite(ldfp->YtHY, YtHY_idx, &YtHy_j));
1127 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, Hy_j, YtHy_j, 0, h));
1128 ldfp->Yt_count++;
1129 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, YtHy_j, ldfp->num_updates, ldfp->cyclic_work_vec));
1130 PetscCall(MatDenseRestoreColumnVecWrite(ldfp->YtHY, YtHY_idx, &YtHy_j));
1131 PetscCall(MatDenseRestoreColumnVecWrite(ldfp->HY, Y_idx, &Hy_j));
1132 }
1133 prev_oldest = oldest_update(m, ldfp->num_mult_updates);
1134 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER && prev_oldest < oldest_update(m, k)) {
1135 /* move the YtS entries that have been computed and need to be kept back up */
1136 PetscInt m_keep = m - (oldest_update(m, k) - prev_oldest);
1137
1138 PetscCall(MatMove_LR3(B, ldfp->StY_triu_strict, m_keep));
1139 }
1140 PetscCall(MatGetLocalSize(ldfp->StY_triu_strict, &m_local, NULL));
1141 j_0 = PetscMax(ldfp->num_mult_updates, oldest_update(m, k));
1142 for (PetscInt j = j_0; j < k; j++) {
1143 PetscInt Y_idx = recycle_index(m, j);
1144 PetscInt StY_idx = ldfp->strategy == MAT_LMVM_DENSE_INPLACE ? Y_idx : history_index(m, k, j);
1145 Vec y_j, Sty_j;
1146
1147 PetscCall(MatDenseGetColumnVecRead(Yfull, Y_idx, &y_j));
1148 PetscCall(MatDenseGetColumnVecWrite(ldfp->StY_triu_strict, StY_idx, &Sty_j));
1149 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, y_j, Sty_j, 0, h));
1150 ldfp->St_count++;
1151 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, Sty_j, ldfp->num_updates, ldfp->cyclic_work_vec));
1152 PetscCall(MatDenseRestoreColumnVecWrite(ldfp->StY_triu_strict, StY_idx, &Sty_j));
1153 PetscCall(MatDenseRestoreColumnVecRead(Yfull, Y_idx, &y_j));
1154 /* zero the corresponding row */
1155 if (m_local > 0) {
1156 Mat StY_local, StY_row;
1157
1158 PetscCall(MatDenseGetLocalMatrix(ldfp->StY_triu_strict, &StY_local));
1159 PetscCall(MatDenseGetSubMatrix(StY_local, StY_idx, StY_idx + 1, PETSC_DECIDE, PETSC_DECIDE, &StY_row));
1160 PetscCall(MatZeroEntries(StY_row));
1161 PetscCall(MatDenseRestoreSubMatrix(StY_local, &StY_row));
1162 }
1163 }
1164 if (!ldfp->inv_diag_vec) PetscCall(VecDuplicate(ldfp->diag_vec, &ldfp->inv_diag_vec));
1165 PetscCall(VecCopy(ldfp->diag_vec, ldfp->inv_diag_vec));
1166 PetscCall(VecReciprocal(ldfp->inv_diag_vec));
1167 PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1168 PetscCall(MatSetFactorType(J_local, MAT_FACTOR_NONE));
1169 PetscCall(MatGetRTDR(B, ldfp->J));
1170 PetscCall(MatAXPY(ldfp->J, 1.0, ldfp->YtHY, SAME_NONZERO_PATTERN));
1171 if (m_local) {
1172 PetscCall(MatSetOption(J_local, MAT_SPD, PETSC_TRUE));
1173 PetscCall(MatCholeskyFactor(J_local, NULL, NULL));
1174 }
1175 ldfp->num_mult_updates = ldfp->num_updates;
1176 PetscFunctionReturn(PETSC_SUCCESS);
1177 }
1178
/* Solves for

     H = H_0 - [ S | H_0 Y ] [ -D |    R^T    ]^{-1} [   S^T   ]
                             [----+-----------]      [---------]
                             [  R | Y^T H_0 Y ]      [ Y^T H_0 ]

   Above is equivalent to

     H_0 - [ S | H_0 Y ] [[     I     | 0 ][ -D | 0 ][ I | -D^{-1} R^T ]]^{-1} [   S^T   ]
                         [[-----------+---][----+---][---+-------------]]      [---------]
                         [[ -R D^{-1} | I ][  0 | J ][ 0 |      I      ]]      [ Y^T H_0 ]

   where J = Y^T H_0 Y + R D^{-1} R^T

   becomes

     H_0 - [ S | H_0 Y ] [ I | D^{-1} R^T ][ -D^{-1} |   0    ][    I     | 0 ] [   S^T   ]
                         [---+------------][---------+--------][----------+---] [---------]
                         [ 0 |      I     ][    0    | J^{-1} ][ R D^{-1} | I ] [ Y^T H_0 ]

   =

     H_0 + [ S | H_0 Y ] [ D^{-1} | 0 ][ I | R^T ][ I |    0    ][    I     | 0 ] [   S^T   ]
                         [--------+---][---+-----][---+---------][----------+---] [---------]
                         [    0   | I ][ 0 |  I  ][ 0 | -J^{-1} ][ R D^{-1} | I ] [ Y^T H_0 ]

   (Note that StY_triu_strict is R^T: the code applies its transpose for the R D^{-1} block
   and the matrix itself for the R^T block)
   Byrd, Nocedal, Schnabel 1994

*/
MatSolve_LMVMDDFP(Mat H,Vec F,Vec dX)1209 static PetscErrorCode MatSolve_LMVMDDFP(Mat H, Vec F, Vec dX)
1210 {
1211 Mat_LMVM *lmvm = (Mat_LMVM *)H->data;
1212 Mat_DQN *ldfp = (Mat_DQN *)lmvm->ctx;
1213 PetscInt m = lmvm->m;
1214 PetscInt k = lmvm->k;
1215 PetscInt h = k - oldest_update(m, k);
1216 PetscInt idx, i, j, local_n;
1217 PetscInt m_local;
1218 Mat J_local;
1219 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
1220 Mat Yfull = lmvm->basis[LMBASIS_Y]->vecs;
1221
1222 PetscFunctionBegin;
1223 VecCheckSameSize(F, 2, dX, 3);
1224 VecCheckMatCompatible(H, dX, 3, F, 2);
1225
1226 /* Cholesky Version */
1227 /* Start with the B0 term */
1228 PetscCall(MatDQNApplyJ0Inv(H, F, dX));
1229 if (!ldfp->num_updates) PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
1230
1231 if (ldfp->use_recursive) {
1232 PetscDeviceContext dctx;
1233 PetscMemType memtype;
1234 PetscScalar stf, ytx, ytq, yjtqi, sjtyi, *workscalar;
1235
1236 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
1237 /* Recursive formulation to avoid Cholesky. Not a dense formulation */
1238 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, ldfp->rwork1, 0, h));
1239 ldfp->Yt_count++;
1240
1241 PetscCall(VecGetLocalSize(ldfp->rwork1, &local_n));
1242
1243 PetscInt oldest = oldest_update(m, k);
1244
1245 if (ldfp->needPQ) {
1246 PetscInt oldest = oldest_update(m, k);
1247 for (i = oldest; i < k; ++i) {
1248 idx = recycle_index(m, i);
1249 /* column_work = S[idx] */
1250 PetscCall(MatGetColumnVector(Yfull, ldfp->column_work, idx));
1251 PetscCall(MatDQNApplyJ0Inv(H, ldfp->column_work, ldfp->PQ[idx]));
1252 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, ldfp->column_work, ldfp->rwork3, 0, h));
1253 PetscCall(VecGetArrayAndMemType(ldfp->rwork3, &workscalar, &memtype));
1254 for (j = oldest; j < i; ++j) {
1255 PetscInt idx_j = recycle_index(m, j);
1256 /* Copy sjtyi in device-aware manner */
1257 if (local_n) {
1258 if (PetscMemTypeHost(memtype)) {
1259 sjtyi = workscalar[idx_j];
1260 } else {
1261 PetscCall(PetscDeviceRegisterMemory(&sjtyi, PETSC_MEMTYPE_HOST, 1 * sizeof(sjtyi)));
1262 PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
1263 PetscCall(PetscDeviceArrayCopy(dctx, &sjtyi, &workscalar[idx_j], 1));
1264 }
1265 }
1266 PetscCallMPI(MPI_Bcast(&sjtyi, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)H)));
1267 /* column_work2 = Y[j] */
1268 PetscCall(MatGetColumnVector(Yfull, ldfp->column_work2, idx_j));
1269 PetscCall(VecDot(ldfp->PQ[idx], ldfp->column_work2, &yjtqi));
1270 /* column_work2 = Y[j] */
1271 PetscCall(MatGetColumnVector(Sfull, ldfp->column_work2, idx_j));
1272 /* Compute the pure BFGS component of the forward product */
1273 PetscCall(VecAXPBYPCZ(ldfp->PQ[idx], -yjtqi / ldfp->ytq[idx_j], sjtyi / ldfp->yts[idx_j], 1.0, ldfp->PQ[idx_j], ldfp->column_work2));
1274 }
1275 PetscCall(VecDot(ldfp->PQ[idx], ldfp->column_work, &ytq));
1276 ldfp->ytq[idx] = PetscRealPart(ytq);
1277 }
1278 ldfp->needPQ = PETSC_FALSE;
1279 }
1280
1281 PetscCall(VecGetArrayAndMemType(ldfp->rwork1, &workscalar, &memtype));
1282 for (i = oldest; i < k; ++i) {
1283 idx = recycle_index(m, i);
1284 /* Copy stz[i], ytx[i] in device-aware manner */
1285 if (local_n) {
1286 if (PetscMemTypeHost(memtype)) {
1287 stf = workscalar[idx];
1288 } else {
1289 PetscCall(PetscDeviceRegisterMemory(&stf, PETSC_MEMTYPE_HOST, sizeof(stf)));
1290 PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
1291 PetscCall(PetscDeviceArrayCopy(dctx, &stf, &workscalar[idx], 1));
1292 }
1293 }
1294 PetscCallMPI(MPI_Bcast(&stf, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)H)));
1295 /* column_work : S[i], column_work2 : Y[i] */
1296 PetscCall(MatGetColumnVector(Sfull, ldfp->column_work, idx));
1297 PetscCall(MatGetColumnVector(Yfull, ldfp->column_work2, idx));
1298 PetscCall(VecDot(dX, ldfp->column_work2, &ytx));
1299 PetscCall(VecAXPBYPCZ(dX, -ytx / ldfp->ytq[idx], stf / ldfp->yts[idx], 1.0, ldfp->PQ[idx], ldfp->column_work));
1300 }
1301 PetscCall(VecRestoreArrayAndMemType(ldfp->rwork1, &workscalar));
1302 } else {
1303 PetscCall(MatLMVMDDFPUpdateSolveData(H));
1304 PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, ldfp->rwork1, 0, h));
1305 ldfp->St_count++;
1306 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, dX, ldfp->rwork2, 0, h));
1307 ldfp->Yt_count++;
1308 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) {
1309 PetscCall(VecRecycleOrderToHistoryOrder(H, ldfp->rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1310 PetscCall(VecRecycleOrderToHistoryOrder(H, ldfp->rwork2, ldfp->num_updates, ldfp->cyclic_work_vec));
1311 }
1312
1313 PetscCall(VecPointwiseMult(ldfp->rwork3, ldfp->rwork1, ldfp->inv_diag_vec));
1314 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(ldfp->StY_triu_strict));
1315 PetscCall(MatMultTransposeAdd(ldfp->StY_triu_strict, ldfp->rwork3, ldfp->rwork2, ldfp->rwork2));
1316 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(ldfp->StY_triu_strict));
1317
1318 if (!ldfp->rwork2_local) PetscCall(VecCreateLocalVector(ldfp->rwork2, &ldfp->rwork2_local));
1319 if (!ldfp->rwork3_local) PetscCall(VecCreateLocalVector(ldfp->rwork3, &ldfp->rwork3_local));
1320 PetscCall(VecGetLocalVectorRead(ldfp->rwork2, ldfp->rwork2_local));
1321 PetscCall(VecGetLocalVector(ldfp->rwork3, ldfp->rwork3_local));
1322 PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1323 PetscCall(VecGetSize(ldfp->rwork2_local, &m_local));
1324 if (m_local) {
1325 Mat J_local;
1326
1327 PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1328 PetscCall(MatSolve(J_local, ldfp->rwork2_local, ldfp->rwork3_local));
1329 }
1330 PetscCall(VecRestoreLocalVector(ldfp->rwork3, ldfp->rwork3_local));
1331 PetscCall(VecRestoreLocalVectorRead(ldfp->rwork2, ldfp->rwork2_local));
1332 PetscCall(VecScale(ldfp->rwork3, -1.0));
1333
1334 PetscCall(MatMultAdd(ldfp->StY_triu_strict, ldfp->rwork3, ldfp->rwork1, ldfp->rwork1));
1335 PetscCall(VecPointwiseMult(ldfp->rwork1, ldfp->rwork1, ldfp->inv_diag_vec));
1336
1337 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) {
1338 PetscCall(VecHistoryOrderToRecycleOrder(H, ldfp->rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1339 PetscCall(VecHistoryOrderToRecycleOrder(H, ldfp->rwork3, ldfp->num_updates, ldfp->cyclic_work_vec));
1340 }
1341
1342 PetscCall(MatMultAddColumnRange(Sfull, ldfp->rwork1, dX, dX, 0, h));
1343 ldfp->S_count++;
1344 PetscCall(MatMultAddColumnRange(ldfp->HY, ldfp->rwork3, dX, dX, 0, h));
1345 ldfp->Y_count++;
1346 }
1347 PetscFunctionReturn(PETSC_SUCCESS);
1348 }
1349
/* Computes the product with
   (Theorem 1, Erway, Jain, and Marcia, 2013)

     B = B_0 - [ Y | B_0 S ] [ -R^{-T} (D + S^T B_0 S) R^{-1} | R^{-T} ] [   Y^T   ]
                             [--------------------------------+--------] [---------]
                             [             R^{-1}             |   0    ] [ S^T B_0 ]

   (Note: R above is the upper triangular part of Y^T S)
   which becomes,

     B = [ I | -Y L^{-T} ] [  I  | 0 ] [ B_0 | 0 ] [ I | S ] [      I      ]
                           [-----+---] [-----+---] [---+---] [-------------]
                           [ S^T | I ] [  0  | D ] [ 0 | I ] [ -L^{-1} Y^T ]

   (Note: L above is the upper triangular part of Y^T S, stored in YtS_triu;
   it is the same matrix as R)

*/
MatMult_LMVMDDFP(Mat B,Vec X,Vec Z)1367 static PetscErrorCode MatMult_LMVMDDFP(Mat B, Vec X, Vec Z)
1368 {
1369 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1370 Mat_DQN *ldfp = (Mat_DQN *)lmvm->ctx;
1371 Vec rwork1 = ldfp->rwork1;
1372 PetscInt m = lmvm->m;
1373 PetscInt k = lmvm->k;
1374 PetscInt h = k - oldest_update(m, k);
1375 Mat Sfull = lmvm->basis[LMBASIS_S]->vecs;
1376 Mat Yfull = lmvm->basis[LMBASIS_Y]->vecs;
1377 PetscObjectState Xstate;
1378
1379 PetscFunctionBegin;
1380 VecCheckSameSize(X, 2, Z, 3);
1381 VecCheckMatCompatible(B, X, 2, Z, 3);
1382
1383 /* DFP Version. Erway, Jain, Marcia, 2013, Theorem 1 */
1384 /* Block Version */
1385 if (!ldfp->num_updates) {
1386 PetscCall(MatDQNApplyJ0Fwd(B, X, Z));
1387 PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
1388 }
1389
1390 PetscCall(PetscObjectStateGet((PetscObject)X, &Xstate));
1391 PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, rwork1, 0, h));
1392
1393 /* Reordering rwork1, as STY is in history order, while Y is in recycled order */
1394 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1395 PetscCall(MatUpperTriangularSolveInPlace(B, ldfp->YtS_triu, rwork1, PETSC_FALSE, ldfp->num_updates, ldfp->strategy));
1396 PetscCall(VecScale(rwork1, -1.0));
1397 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1398
1399 PetscCall(VecCopy(X, ldfp->column_work));
1400 PetscCall(MatMultAddColumnRange(Sfull, rwork1, ldfp->column_work, ldfp->column_work, 0, h));
1401 ldfp->S_count++;
1402
1403 PetscCall(VecPointwiseMult(rwork1, ldfp->diag_vec_recycle_order, rwork1));
1404 PetscCall(MatDQNApplyJ0Fwd(B, ldfp->column_work, Z));
1405
1406 PetscCall(MatMultHermitianTransposeAddColumnRange(Sfull, Z, rwork1, rwork1, 0, h));
1407 ldfp->St_count++;
1408
1409 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1410 PetscCall(MatUpperTriangularSolveInPlace(B, ldfp->YtS_triu, rwork1, PETSC_TRUE, ldfp->num_updates, ldfp->strategy));
1411 PetscCall(VecScale(rwork1, -1.0));
1412 if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1413
1414 PetscCall(MatMultAddColumnRange(Yfull, rwork1, Z, Z, 0, h));
1415 ldfp->Y_count++;
1416 PetscFunctionReturn(PETSC_SUCCESS);
1417 }
1418
1419 /*
1420 This dense representation reduces the L-DFP update to a series of
1421 matrix-vector products with dense matrices in lieu of the conventional
1422 matrix-free two-loop algorithm.
1423 */
MatCreate_LMVMDDFP(Mat B)1424 PetscErrorCode MatCreate_LMVMDDFP(Mat B)
1425 {
1426 Mat_LMVM *lmvm;
1427 Mat_DQN *ldfp;
1428
1429 PetscFunctionBegin;
1430 PetscCall(MatCreate_LMVM(B));
1431 PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDDFP));
1432 PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
1433 PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
1434 PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
1435 B->ops->view = MatView_LMVMDQN;
1436 B->ops->setup = MatSetUp_LMVMDQN;
1437 B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
1438 B->ops->destroy = MatDestroy_LMVMDQN;
1439
1440 lmvm = (Mat_LMVM *)B->data;
1441 lmvm->ops->reset = MatReset_LMVMDQN;
1442 lmvm->ops->update = MatUpdate_LMVMDQN;
1443 lmvm->ops->mult = MatMult_LMVMDDFP;
1444 lmvm->ops->solve = MatSolve_LMVMDDFP;
1445 lmvm->ops->copy = MatCopy_LMVMDQN;
1446
1447 lmvm->ops->multht = lmvm->ops->mult;
1448 lmvm->ops->solveht = lmvm->ops->solve;
1449
1450 PetscCall(PetscNew(&ldfp));
1451 lmvm->ctx = (void *)ldfp;
1452 ldfp->allocated = PETSC_FALSE;
1453 ldfp->watchdog = 0;
1454 ldfp->max_seq_rejects = lmvm->m / 2;
1455 ldfp->strategy = MAT_LMVM_DENSE_INPLACE;
1456 ldfp->use_recursive = PETSC_TRUE;
1457 ldfp->needPQ = PETSC_TRUE;
1458
1459 PetscCall(SymBroydenRescaleCreate(&ldfp->rescale));
1460 PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
1461 PetscFunctionReturn(PETSC_SUCCESS);
1462 }
1463
1464 /*@
1465 MatCreateLMVMDDFP - Creates a dense representation of the limited-memory
1466 Davidon-Fletcher-Powell (DFP) approximation to a Hessian.
1467
1468 Collective
1469
1470 Input Parameters:
1471 + comm - MPI communicator
1472 . n - number of local rows for storage vectors
1473 - N - global size of the storage vectors
1474
1475 Output Parameter:
1476 . B - the matrix
1477
1478 Level: advanced
1479
1480 Note:
1481 It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions()
1482 paradigm instead of this routine directly.
1483
1484 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDDFP`, `MatCreateLMVMDFP()`
1485 @*/
MatCreateLMVMDDFP(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)1486 PetscErrorCode MatCreateLMVMDDFP(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
1487 {
1488 PetscFunctionBegin;
1489 PetscCall(KSPInitializePackage());
1490 PetscCall(MatCreate(comm, B));
1491 PetscCall(MatSetSizes(*B, n, n, N, N));
1492 PetscCall(MatSetType(*B, MATLMVMDDFP));
1493 PetscCall(MatSetUp(*B));
1494 PetscFunctionReturn(PETSC_SUCCESS);
1495 }
1496
1497 /*@
1498 MatLMVMDenseSetType - Sets the memory storage type for dense `MATLMVM`
1499
1500 Input Parameters:
1501 + B - the `MATLMVM` matrix
1502 - type - scale type, see `MatLMVMDenseSetType`
1503
1504 Options Database Keys:
1505 + -mat_lqn_type <reorder,inplace> - set the strategy
1506 . -mat_lbfgs_type <reorder,inplace> - set the strategy
1507 - -mat_ldfp_type <reorder,inplace> - set the strategy
1508
1509 Level: intermediate
1510
1511 MatLMVMDenseTypes\:
1512 + `MAT_LMVM_DENSE_REORDER` - reorders memory to minimize kernel launch
1513 - `MAT_LMVM_DENSE_INPLACE` - launches kernel inplace to minimize memory movement
1514
1515 .seealso: [](ch_ksp), `MATLMVMDQN`, `MATLMVMDBFGS`, `MATLMVMDDFP`, `MatLMVMDenseType`
1516 @*/
MatLMVMDenseSetType(Mat B,MatLMVMDenseType type)1517 PetscErrorCode MatLMVMDenseSetType(Mat B, MatLMVMDenseType type)
1518 {
1519 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1520 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
1521
1522 PetscFunctionBegin;
1523 PetscValidHeaderSpecific(B, MAT_CLASSID, 1);
1524 lqn->strategy = type;
1525 PetscFunctionReturn(PETSC_SUCCESS);
1526 }
1527