xref: /petsc/src/ksp/ksp/utils/lmvm/dense/denseqn.c (revision 834855d6effb0d027771461c8e947ee1ce5a1e17)
1 #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h> /*I "petscksp.h" I*/
2 #include <../src/ksp/ksp/utils/lmvm/blas_cyclic/blas_cyclic.h>
3 #include <petscblaslapack.h>
4 #include <petscmat.h>
5 #include <petscsys.h>
6 #include <petscsystypes.h>
7 #include <petscis.h>
8 #include <petscoptions.h>
9 #include <petscdevice.h>
10 #include <petsc/private/deviceimpl.h>
11 
12 static PetscErrorCode MatMult_LMVMDQN(Mat, Vec, Vec);
13 static PetscErrorCode MatMult_LMVMDBFGS(Mat, Vec, Vec);
14 static PetscErrorCode MatMult_LMVMDDFP(Mat, Vec, Vec);
15 static PetscErrorCode MatSolve_LMVMDQN(Mat, Vec, Vec);
16 static PetscErrorCode MatSolve_LMVMDBFGS(Mat, Vec, Vec);
17 static PetscErrorCode MatSolve_LMVMDDFP(Mat, Vec, Vec);
18 
/* Storage slot (column of the m-column cyclic S/Y buffers) that holds absolute update number idx. */
static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
{
  const PetscInt slot = idx % m;

  return slot;
}
23 
/*
  Position of absolute update idx when the retained updates are listed oldest-first.
  At most min(m, num_updates) updates are retained; the newest one
  (idx == num_updates - 1) maps to the last retained position.
*/
static inline PetscInt history_index(PetscInt m, PetscInt num_updates, PetscInt idx)
{
  const PetscInt n_kept = PetscMin(m, num_updates);

  return n_kept - (num_updates - idx);
}
28 
/* Absolute index of the oldest update still stored once idx updates were pushed into an m-slot buffer (never negative). */
static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
{
  if (idx > m) return idx - m;
  return 0;
}
33 
MatView_LMVMDQN(Mat B,PetscViewer pv)34 static PetscErrorCode MatView_LMVMDQN(Mat B, PetscViewer pv)
35 {
36   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
37   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
38 
39   PetscBool isascii;
40 
41   PetscFunctionBegin;
42   PetscCall(PetscObjectTypeCompare((PetscObject)pv, PETSCVIEWERASCII, &isascii));
43   PetscCall(MatView_LMVM(B, pv));
44   PetscCall(SymBroydenRescaleView(lqn->rescale, pv));
45   if (isascii) PetscCall(PetscViewerASCIIPrintf(pv, "Counts: S x : %" PetscInt_FMT ", S^T x : %" PetscInt_FMT ", Y x : %" PetscInt_FMT ",  Y^T x: %" PetscInt_FMT "\n", lqn->S_count, lqn->St_count, lqn->Y_count, lqn->Yt_count));
46   PetscFunctionReturn(PETSC_SUCCESS);
47 }
48 
MatLMVMDQNResetDestructive(Mat B)49 static PetscErrorCode MatLMVMDQNResetDestructive(Mat B)
50 {
51   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
52   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
53 
54   PetscFunctionBegin;
55   PetscCall(MatDestroy(&lqn->HY));
56   PetscCall(MatDestroy(&lqn->BS));
57   PetscCall(MatDestroy(&lqn->StY_triu));
58   PetscCall(MatDestroy(&lqn->YtS_triu));
59   PetscCall(VecDestroy(&lqn->StFprev));
60   PetscCall(VecDestroy(&lqn->Fprev_ref));
61   lqn->Fprev_state = 0;
62   PetscCall(MatDestroy(&lqn->YtS_triu_strict));
63   PetscCall(MatDestroy(&lqn->StY_triu_strict));
64   PetscCall(MatDestroy(&lqn->StBS));
65   PetscCall(MatDestroy(&lqn->YtHY));
66   PetscCall(MatDestroy(&lqn->J));
67   PetscCall(MatDestroy(&lqn->temp_mat));
68   PetscCall(VecDestroy(&lqn->diag_vec));
69   PetscCall(VecDestroy(&lqn->diag_vec_recycle_order));
70   PetscCall(VecDestroy(&lqn->inv_diag_vec));
71   PetscCall(VecDestroy(&lqn->column_work));
72   PetscCall(VecDestroy(&lqn->column_work2));
73   PetscCall(VecDestroy(&lqn->rwork1));
74   PetscCall(VecDestroy(&lqn->rwork2));
75   PetscCall(VecDestroy(&lqn->rwork3));
76   PetscCall(VecDestroy(&lqn->rwork2_local));
77   PetscCall(VecDestroy(&lqn->rwork3_local));
78   PetscCall(VecDestroy(&lqn->cyclic_work_vec));
79   PetscCall(VecDestroyVecs(lmvm->m, &lqn->PQ));
80   PetscCall(PetscFree(lqn->stp));
81   PetscCall(PetscFree(lqn->yts));
82   PetscCall(PetscFree(lqn->ytq));
83   lqn->allocated = PETSC_FALSE;
84   PetscFunctionReturn(PETSC_SUCCESS);
85 }
86 
MatReset_LMVMDQN_Internal(Mat B,MatLMVMResetMode mode)87 static PetscErrorCode MatReset_LMVMDQN_Internal(Mat B, MatLMVMResetMode mode)
88 {
89   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
90   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
91 
92   PetscFunctionBegin;
93   lqn->watchdog         = 0;
94   lqn->needPQ           = PETSC_TRUE;
95   lqn->num_updates      = 0;
96   lqn->num_mult_updates = 0;
97   if (MatLMVMResetClearsBases(mode)) PetscCall(MatLMVMDQNResetDestructive(B));
98   else {
99     if (lqn->BS) PetscCall(MatZeroEntries(lqn->BS));
100     if (lqn->HY) PetscCall(MatZeroEntries(lqn->HY));
101     if (lqn->StY_triu) { /* Set to identity by default so it is invertible */
102       PetscCall(MatZeroEntries(lqn->StY_triu));
103       PetscCall(MatShift(lqn->StY_triu, 1.0));
104     }
105     if (lqn->YtS_triu) {
106       PetscCall(MatZeroEntries(lqn->YtS_triu));
107       PetscCall(MatShift(lqn->YtS_triu, 1.0));
108     }
109     if (lqn->YtS_triu_strict) PetscCall(MatZeroEntries(lqn->YtS_triu_strict));
110     if (lqn->StY_triu_strict) PetscCall(MatZeroEntries(lqn->StY_triu_strict));
111     if (lqn->StBS) {
112       PetscCall(MatZeroEntries(lqn->StBS));
113       PetscCall(MatShift(lqn->StBS, 1.0));
114     }
115     if (lqn->YtHY) {
116       PetscCall(MatZeroEntries(lqn->YtHY));
117       PetscCall(MatShift(lqn->YtHY, 1.0));
118     }
119     if (lqn->Fprev_ref) PetscCall(VecDestroy(&lqn->Fprev_ref));
120     lqn->Fprev_state = 0;
121     if (lqn->StFprev) PetscCall(VecZeroEntries(lqn->StFprev));
122   }
123   PetscFunctionReturn(PETSC_SUCCESS);
124 }
125 
MatReset_LMVMDQN(Mat B,MatLMVMResetMode mode)126 static PetscErrorCode MatReset_LMVMDQN(Mat B, MatLMVMResetMode mode)
127 {
128   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
129   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
130 
131   PetscFunctionBegin;
132   PetscCall(SymBroydenRescaleReset(B, lqn->rescale, mode));
133   PetscCall(MatReset_LMVMDQN_Internal(B, mode));
134   PetscFunctionReturn(PETSC_SUCCESS);
135 }
136 
MatAllocate_LMVMDQN_Internal(Mat B)137 static PetscErrorCode MatAllocate_LMVMDQN_Internal(Mat B)
138 {
139   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
140   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
141 
142   PetscFunctionBegin;
143   if (!lqn->allocated) {
144     if (lmvm->m > 0) {
145       PetscMPIInt rank;
146       PetscInt    n, N, m, M;
147       PetscBool   is_dbfgs, is_ddfp, is_dqn;
148       VecType     vec_type;
149       MPI_Comm    comm  = PetscObjectComm((PetscObject)B);
150       Mat         Sfull = lmvm->basis[LMBASIS_S]->vecs;
151 
152       PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
153       PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
154       PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
155 
156       PetscCallMPI(MPI_Comm_rank(comm, &rank));
157       PetscCall(MatGetSize(B, &N, NULL));
158       PetscCall(MatGetLocalSize(B, &n, NULL));
159       M = lmvm->m;
160       m = (rank == 0) ? M : 0;
161 
162       /* For DBFGS: Create data needed for MatSolve() eagerly; data needed for MatMult() will be created on demand
163        * For DDFP : Create data needed for MatMult() eagerly; data needed for MatSolve() will be created on demand
164        * For DQN  : Create all data eagerly */
165       PetscCall(VecGetType(lmvm->Xprev, &vec_type));
166       if (is_dqn) {
167         PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->StY_triu));
168         PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->YtS_triu));
169         PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->diag_vec, &lqn->rwork1));
170         PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->rwork2, &lqn->rwork3));
171       } else if (is_ddfp) {
172         PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->YtS_triu));
173         PetscCall(MatDuplicate(Sfull, MAT_SHARE_NONZERO_PATTERN, &lqn->HY));
174         PetscCall(MatCreateVecs(lqn->YtS_triu, &lqn->diag_vec, &lqn->rwork1));
175         PetscCall(MatCreateVecs(lqn->YtS_triu, &lqn->rwork2, &lqn->rwork3));
176       } else if (is_dbfgs) {
177         PetscCall(MatCreateDenseFromVecType(comm, vec_type, m, m, M, M, -1, NULL, &lqn->StY_triu));
178         PetscCall(MatDuplicate(Sfull, MAT_SHARE_NONZERO_PATTERN, &lqn->BS));
179         PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->diag_vec, &lqn->rwork1));
180         PetscCall(MatCreateVecs(lqn->StY_triu, &lqn->rwork2, &lqn->rwork3));
181       } else {
182         SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatAllocate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
183       }
184       /* initialize StY_triu and YtS_triu to identity, if they exist, so it is invertible */
185       if (lqn->StY_triu) {
186         PetscCall(MatZeroEntries(lqn->StY_triu));
187         PetscCall(MatShift(lqn->StY_triu, 1.0));
188       }
189       if (lqn->YtS_triu) {
190         PetscCall(MatZeroEntries(lqn->YtS_triu));
191         PetscCall(MatShift(lqn->YtS_triu, 1.0));
192       }
193       if (lqn->use_recursive && (is_dbfgs || is_ddfp)) {
194         PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lqn->PQ));
195         PetscCall(VecDuplicate(lmvm->Xprev, &lqn->column_work2));
196         PetscCall(PetscMalloc1(lmvm->m, &lqn->yts));
197         if (is_dbfgs) {
198           PetscCall(PetscMalloc1(lmvm->m, &lqn->stp));
199         } else if (is_ddfp) {
200           PetscCall(PetscMalloc1(lmvm->m, &lqn->ytq));
201         }
202       }
203       PetscCall(VecDuplicate(lqn->rwork2, &lqn->cyclic_work_vec));
204       PetscCall(VecZeroEntries(lqn->rwork1));
205       PetscCall(VecZeroEntries(lqn->rwork2));
206       PetscCall(VecZeroEntries(lqn->rwork3));
207       PetscCall(VecZeroEntries(lqn->diag_vec));
208     }
209     PetscCall(VecDuplicate(lmvm->Xprev, &lqn->column_work));
210     lqn->allocated = PETSC_TRUE;
211   }
212   PetscFunctionReturn(PETSC_SUCCESS);
213 }
214 
MatSetUp_LMVMDQN(Mat B)215 static PetscErrorCode MatSetUp_LMVMDQN(Mat B)
216 {
217   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
218   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
219 
220   PetscFunctionBegin;
221   PetscCall(MatSetUp_LMVM(B));
222   PetscCall(SymBroydenRescaleInitializeJ0(B, lqn->rescale));
223   PetscCall(MatAllocate_LMVMDQN_Internal(B));
224   PetscFunctionReturn(PETSC_SUCCESS);
225 }
226 
MatSetFromOptions_LMVMDQN(Mat B,PetscOptionItems PetscOptionsObject)227 static PetscErrorCode MatSetFromOptions_LMVMDQN(Mat B, PetscOptionItems PetscOptionsObject)
228 {
229   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
230   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
231   PetscBool is_dbfgs, is_ddfp, is_dqn;
232 
233   PetscFunctionBegin;
234   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
235   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
236   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
237   PetscCall(MatSetFromOptions_LMVM(B, PetscOptionsObject));
238   PetscOptionsHeadBegin(PetscOptionsObject, "Dense symmetric Broyden method for approximating SPD Jacobian actions");
239   if (is_dqn) {
240     PetscCall(PetscOptionsEnum("-mat_lqn_type", "Implementation options for L-QN", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
241   } else if (is_dbfgs) {
242     PetscCall(PetscOptionsBool("-mat_lbfgs_recursive", "Use recursive formulation for MatMult_LMVMDBFGS, instead of Cholesky", "", lqn->use_recursive, &lqn->use_recursive, NULL));
243     PetscCall(PetscOptionsEnum("-mat_lbfgs_type", "Implementation options for L-BFGS", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
244   } else if (is_ddfp) {
245     PetscCall(PetscOptionsBool("-mat_ldfp_recursive", "Use recursive formulation for MatSolve_LMVMDDFP, instead of Cholesky", "", lqn->use_recursive, &lqn->use_recursive, NULL));
246     PetscCall(PetscOptionsEnum("-mat_ldfp_type", "Implementation options for L-DFP", "MatLMVMDenseType", MatLMVMDenseTypes, (PetscEnum)lqn->strategy, (PetscEnum *)&lqn->strategy, NULL));
247   } else {
248     SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatSetFromOptions_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
249   }
250   PetscCall(SymBroydenRescaleSetFromOptions(B, lqn->rescale, PetscOptionsObject));
251   PetscOptionsHeadEnd();
252   PetscFunctionReturn(PETSC_SUCCESS);
253 }
254 
MatDestroy_LMVMDQN(Mat B)255 static PetscErrorCode MatDestroy_LMVMDQN(Mat B)
256 {
257   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
258   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
259 
260   PetscFunctionBegin;
261   PetscCall(SymBroydenRescaleDestroy(&lqn->rescale));
262   PetscCall(MatReset_LMVMDQN_Internal(B, MAT_LMVM_RESET_ALL));
263   PetscCall(PetscFree(lqn->workscalar));
264   PetscCall(PetscFree(lmvm->ctx));
265   PetscCall(MatDestroy_LMVM(B));
266   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", NULL));
267   PetscFunctionReturn(PETSC_SUCCESS);
268 }
269 
/*
  Try to add the secant pair built from (Xprev, Fprev) and (X, F) to the dense
  quasi-Newton approximation.

  With s = X - Xprev and y = F - Fprev, the pair is accepted when Re(s^T y) > curvtol.
  Here curvtol = eps * Re(y^T y), or 0 when Re(y^T y) < eps.

  On acceptance:
    - the pair is pushed into the S/Y bases by MatUpdateKernel_LMVM();
    - DQN/DBFGS: the new column of StY_triu is formed as S^T F - S^T Fprev. This reuses
      the cached S^T Fprev (StFprev), so only one product with S is needed;
    - DDFP/DQN: the new column of YtS_triu is formed as Y^T s;
    - diag_vec (and its recycle-order copy) is refreshed and the J0 rescaling is updated.
  On rejection, the rejection counters are incremented and StFprev = S^T F is recomputed.
  After more than max_seq_rejects consecutive rejections, the matrix is reset.
  In every case, X and F are saved as the new previous iterate and function value.
*/
static PetscErrorCode MatUpdate_LMVMDQN(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
  Mat_DQN  *lqn   = (Mat_DQN *)lmvm->ctx;
  Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
  Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;

  PetscBool          is_ddfp, is_dbfgs, is_dqn;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  /* no history storage: nothing to do */
  if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
  PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
  PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
  if (lmvm->prev_set) {
    Vec         FX[2];
    PetscScalar dotFX[2];
    PetscScalar stFprev;
    PetscScalar curvature, yTy;
    PetscReal   curvtol;

    /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors, in place in Xprev and Fprev */
    PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
    /* Test if the updates can be accepted */
    FX[0] = lmvm->Fprev; /* dotFX[0] = s^T Fprev */
    FX[1] = F;           /* dotFX[1] = s^T F     */
    PetscCall(VecMDot(lmvm->Xprev, 2, FX, dotFX));
    PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));
    PetscCall(VecDot(lmvm->Fprev, lmvm->Fprev, &yTy));
    stFprev   = PetscConj(dotFX[0]);
    curvature = PetscConj(dotFX[1] - dotFX[0]); /* s^T y */
    if (PetscRealPart(yTy) < lmvm->eps) {
      curvtol = 0.0;
    } else {
      curvtol = lmvm->eps * PetscRealPart(yTy);
    }
    if (PetscRealPart(curvature) > curvtol) {
      PetscInt m     = lmvm->m;
      PetscInt k     = lmvm->k;
      PetscInt h_old = k - oldest_update(m, k);         /* number of stored pairs before this update */
      PetscInt h_new = k + 1 - oldest_update(m, k + 1); /* number of stored pairs after this update */
      PetscInt idx   = recycle_index(m, k);             /* storage column of the new pair */

      /* Update is good, accept it */
      PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
      lqn->num_updates++;
      lqn->watchdog = 0;
      lqn->needPQ   = PETSC_TRUE;

      /* In REORDER mode the triangular products are kept in history (oldest-first) order.
         When the buffer was already full, the oldest pair drops out and the retained m - 1
         entries are moved; NOTE(review): exact MatMove_LR3 semantics are defined elsewhere */
      if (h_old == m && lqn->strategy == MAT_LMVM_DENSE_REORDER) {
        if (is_dqn) {
          PetscCall(MatMove_LR3(B, lqn->StY_triu, m - 1));
          PetscCall(MatMove_LR3(B, lqn->YtS_triu, m - 1));
        } else if (is_dbfgs) {
          PetscCall(MatMove_LR3(B, lqn->StY_triu, m - 1));
        } else if (is_ddfp) {
          PetscCall(MatMove_LR3(B, lqn->YtS_triu, m - 1));
        } else {
          SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatUpdate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
        }
      }

      /* the recursive formulation keeps s^T y per stored pair, indexed by storage column */
      if (lqn->use_recursive && (is_dbfgs || is_ddfp)) lqn->yts[idx] = PetscRealPart(curvature);

      /* Scheme of Byrd, Nocedal, and Schnabel: save a MatMultTranspose call in the common case where
         H_k is applied to F right after being updated. S^T y is split up as S^T F - S^T Fprev,
         and S^T F becomes the cached StFprev for the next update. */
      if (is_dqn || is_dbfgs) {
        PetscInt     local_n;
        PetscScalar *StFprev;
        PetscMemType memtype;
        PetscInt     StYidx;

        StYidx = (lqn->strategy == MAT_LMVM_DENSE_REORDER) ? history_index(m, lqn->num_updates, k) : idx;
        if (!lqn->StFprev) PetscCall(VecDuplicate(lqn->rwork1, &lqn->StFprev));
        /* Fill in the entry of S^T Fprev for the new s (the cached vector predates it). Only ranks owning
           entries (rank 0 for these m-length vectors) write, on host or device depending on where the array lives */
        PetscCall(VecGetLocalSize(lqn->StFprev, &local_n));
        PetscCall(VecGetArrayAndMemType(lqn->StFprev, &StFprev, &memtype));
        if (local_n) {
          if (PetscMemTypeHost(memtype)) {
            StFprev[idx] = stFprev;
          } else {
            PetscCall(PetscDeviceRegisterMemory(&stFprev, PETSC_MEMTYPE_HOST, 1 * sizeof(stFprev)));
            PetscCall(PetscDeviceRegisterMemory(StFprev, memtype, local_n * sizeof(*StFprev)));
            PetscCall(PetscDeviceArrayCopy(dctx, &StFprev[idx], &stFprev, 1));
          }
        }
        PetscCall(VecRestoreArrayAndMemType(lqn->StFprev, &StFprev));

        {
          Vec this_sy_col;
          /* Now StFprev is updated for the new S vector.  Write -StFprev into the appropriate column */
          PetscCall(MatDenseGetColumnVecWrite(lqn->StY_triu, StYidx, &this_sy_col));
          PetscCall(VecAXPBY(this_sy_col, -1.0, 0.0, lqn->StFprev));

          /* Now compute the new StFprev = S^T F over the h_new stored columns */
          PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, lqn->StFprev, 0, h_new));
          lqn->St_count++;

          /* Now add StFprev: this_sy_col == S^T (F - Fprev) == S^T y */
          PetscCall(VecAXPY(this_sy_col, 1.0, lqn->StFprev));

          /* products with S come out in storage (recycle) order; permute to history order if required */
          if (lqn->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, this_sy_col, lqn->num_updates, lqn->cyclic_work_vec));
          PetscCall(MatDenseRestoreColumnVecWrite(lqn->StY_triu, StYidx, &this_sy_col));
        }
      }

      /* DDFP/DQN: the new column of YtS_triu is Y^T s (s is still held in Xprev) */
      if (is_ddfp || is_dqn) {
        PetscInt YtSidx;

        YtSidx = (lqn->strategy == MAT_LMVM_DENSE_REORDER) ? history_index(m, lqn->num_updates, k) : idx;

        {
          Vec this_ys_col;

          PetscCall(MatDenseGetColumnVecWrite(lqn->YtS_triu, YtSidx, &this_ys_col));
          PetscCall(MatMultHermitianTransposeColumnRange(Yfull, lmvm->Xprev, this_ys_col, 0, h_new));
          lqn->Yt_count++;

          if (lqn->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, this_ys_col, lqn->num_updates, lqn->cyclic_work_vec));
          PetscCall(MatDenseRestoreColumnVecWrite(lqn->YtS_triu, YtSidx, &this_ys_col));
        }
      }

      /* refresh the diagonal of the triangular matrix this type factors with */
      if (is_dbfgs || is_dqn) {
        PetscCall(MatGetDiagonal(lqn->StY_triu, lqn->diag_vec));
      } else if (is_ddfp) {
        PetscCall(MatGetDiagonal(lqn->YtS_triu, lqn->diag_vec));
      } else {
        SETERRQ(PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_INCOMP, "MatUpdate_LMVMDQN is only available for dense derived types. (DBFGS, DDFP, DQN");
      }

      if (lqn->strategy == MAT_LMVM_DENSE_REORDER) {
        /* diag_vec is in history order: keep a permuted copy in recycle order */
        if (!lqn->diag_vec_recycle_order) PetscCall(VecDuplicate(lqn->diag_vec, &lqn->diag_vec_recycle_order));
        PetscCall(VecCopy(lqn->diag_vec, lqn->diag_vec_recycle_order));
        PetscCall(VecHistoryOrderToRecycleOrder(B, lqn->diag_vec_recycle_order, lqn->num_updates, lqn->cyclic_work_vec));
      } else {
        /* in-place storage is already in recycle order, so simply alias diag_vec (with a reference) */
        if (!lqn->diag_vec_recycle_order) {
          PetscCall(PetscObjectReference((PetscObject)lqn->diag_vec));
          lqn->diag_vec_recycle_order = lqn->diag_vec;
        }
      }

      PetscCall(SymBroydenRescaleUpdate(B, lqn->rescale));
    } else {
      /* Update is bad, skip it */
      ++lmvm->nrejects;
      ++lqn->watchdog;
      PetscInt m = lmvm->m;
      PetscInt k = lmvm->k;
      PetscInt h = k - oldest_update(m, k);

      /* we still have to maintain StFprev = S^T F over the h stored columns */
      if (!lqn->StFprev) PetscCall(VecDuplicate(lqn->rwork1, &lqn->StFprev));
      PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, lqn->StFprev, 0, h));
      lqn->St_count++;
    }
  }

  /* too many consecutive rejections: restart the approximation (non-destructive reset) */
  if (lqn->watchdog > lqn->max_seq_rejects) PetscCall(MatLMVMReset(B, PETSC_FALSE));

  /* Save the solution and function to be used in the next update */
  PetscCall(VecCopy(X, lmvm->Xprev));
  PetscCall(VecCopy(F, lmvm->Fprev));
  /* Keep a reference to F and its object state, recording which vector StFprev was computed from.
     NOTE(review): presumably lets a later solve detect that it is applied to this same, unmodified F — confirm */
  PetscCall(PetscObjectReference((PetscObject)F));
  PetscCall(VecDestroy(&lqn->Fprev_ref));
  lqn->Fprev_ref = F;
  PetscCall(PetscObjectStateGet((PetscObject)F, &lqn->Fprev_state));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
440 
MatDestroyThenCopy(Mat src,Mat * dst)441 static PetscErrorCode MatDestroyThenCopy(Mat src, Mat *dst)
442 {
443   PetscFunctionBegin;
444   PetscCall(MatDestroy(dst));
445   if (src) PetscCall(MatDuplicate(src, MAT_COPY_VALUES, dst));
446   PetscFunctionReturn(PETSC_SUCCESS);
447 }
448 
VecDestroyThenCopy(Vec src,Vec * dst)449 static PetscErrorCode VecDestroyThenCopy(Vec src, Vec *dst)
450 {
451   PetscFunctionBegin;
452   PetscCall(VecDestroy(dst));
453   if (src) {
454     PetscCall(VecDuplicate(src, dst));
455     PetscCall(VecCopy(src, *dst));
456   }
457   PetscFunctionReturn(PETSC_SUCCESS);
458 }
459 
MatCopy_LMVMDQN(Mat B,Mat M,MatStructure str)460 static PetscErrorCode MatCopy_LMVMDQN(Mat B, Mat M, MatStructure str)
461 {
462   Mat_LMVM *bdata = (Mat_LMVM *)B->data;
463   Mat_DQN  *blqn  = (Mat_DQN *)bdata->ctx;
464   Mat_LMVM *mdata = (Mat_LMVM *)M->data;
465   Mat_DQN  *mlqn  = (Mat_DQN *)mdata->ctx;
466   PetscInt  i;
467   PetscBool is_dbfgs, is_ddfp, is_dqn;
468 
469   PetscFunctionBegin;
470   PetscCall(SymBroydenRescaleCopy(blqn->rescale, mlqn->rescale));
471   mlqn->num_updates      = blqn->num_updates;
472   mlqn->num_mult_updates = blqn->num_mult_updates;
473   mlqn->dense_type       = blqn->dense_type;
474   mlqn->strategy         = blqn->strategy;
475   mlqn->S_count          = 0;
476   mlqn->St_count         = 0;
477   mlqn->Y_count          = 0;
478   mlqn->Yt_count         = 0;
479   mlqn->watchdog         = blqn->watchdog;
480   mlqn->max_seq_rejects  = blqn->max_seq_rejects;
481   mlqn->use_recursive    = blqn->use_recursive;
482   mlqn->needPQ           = blqn->needPQ;
483   if (blqn->allocated) {
484     PetscCall(MatAllocate_LMVMDQN_Internal(M));
485     PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDBFGS, &is_dbfgs));
486     PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDDFP, &is_ddfp));
487     PetscCall(PetscObjectTypeCompare((PetscObject)B, MATLMVMDQN, &is_dqn));
488     PetscCall(MatDestroyThenCopy(blqn->HY, &mlqn->BS));
489     PetscCall(VecDestroyThenCopy(blqn->StFprev, &mlqn->StFprev));
490     PetscCall(MatDestroyThenCopy(blqn->StY_triu, &mlqn->StY_triu));
491     PetscCall(MatDestroyThenCopy(blqn->StY_triu_strict, &mlqn->StY_triu_strict));
492     PetscCall(MatDestroyThenCopy(blqn->YtS_triu, &mlqn->YtS_triu));
493     PetscCall(MatDestroyThenCopy(blqn->YtS_triu_strict, &mlqn->YtS_triu_strict));
494     PetscCall(MatDestroyThenCopy(blqn->YtHY, &mlqn->YtHY));
495     PetscCall(MatDestroyThenCopy(blqn->StBS, &mlqn->StBS));
496     PetscCall(MatDestroyThenCopy(blqn->J, &mlqn->J));
497     PetscCall(VecDestroyThenCopy(blqn->diag_vec, &mlqn->diag_vec));
498     PetscCall(VecDestroyThenCopy(blqn->diag_vec_recycle_order, &mlqn->diag_vec_recycle_order));
499     PetscCall(VecDestroyThenCopy(blqn->inv_diag_vec, &mlqn->inv_diag_vec));
500     if (blqn->use_recursive && (is_dbfgs || is_ddfp)) {
501       for (i = 0; i < bdata->m; i++) {
502         PetscCall(VecDestroyThenCopy(blqn->PQ[i], &mlqn->PQ[i]));
503         mlqn->yts[i] = blqn->yts[i];
504         if (is_dbfgs) {
505           mlqn->stp[i] = blqn->stp[i];
506         } else if (is_ddfp) {
507           mlqn->ytq[i] = blqn->ytq[i];
508         }
509       }
510     }
511   }
512   PetscCall(PetscObjectReference((PetscObject)blqn->Fprev_ref));
513   PetscCall(VecDestroy(&mlqn->Fprev_ref));
514   mlqn->Fprev_ref   = blqn->Fprev_ref;
515   mlqn->Fprev_state = blqn->Fprev_state;
516   PetscFunctionReturn(PETSC_SUCCESS);
517 }
518 
MatMult_LMVMDQN(Mat B,Vec X,Vec Z)519 static PetscErrorCode MatMult_LMVMDQN(Mat B, Vec X, Vec Z)
520 {
521   PetscFunctionBegin;
522   PetscCall(MatMult_LMVMDDFP(B, X, Z));
523   PetscFunctionReturn(PETSC_SUCCESS);
524 }
525 
MatSolve_LMVMDQN(Mat H,Vec F,Vec dX)526 static PetscErrorCode MatSolve_LMVMDQN(Mat H, Vec F, Vec dX)
527 {
528   PetscFunctionBegin;
529   PetscCall(MatSolve_LMVMDBFGS(H, F, dX));
530   PetscFunctionReturn(PETSC_SUCCESS);
531 }
532 
MatLMVMSymBroydenSetDelta_LMVMDQN(Mat B,PetscScalar delta)533 static PetscErrorCode MatLMVMSymBroydenSetDelta_LMVMDQN(Mat B, PetscScalar delta)
534 {
535   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
536   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
537 
538   PetscFunctionBegin;
539   PetscCall(SymBroydenRescaleSetDelta(B, lqn->rescale, PetscAbsReal(PetscRealPart(delta))));
540   PetscFunctionReturn(PETSC_SUCCESS);
541 }
542 
543 /*
544   This dense representation uses Davidon-Fletcher-Powell (DFP) for MatMult,
545   and Broyden-Fletcher-Goldfarb-Shanno (BFGS) for MatSolve. This implementation
546   results in avoiding costly Cholesky factorization, at the cost of duality cap.
547   Please refer to MatLMVMDDFP and MatLMVMDBFGS for more information.
548 */
MatCreate_LMVMDQN(Mat B)549 PetscErrorCode MatCreate_LMVMDQN(Mat B)
550 {
551   Mat_LMVM *lmvm;
552   Mat_DQN  *lqn;
553 
554   PetscFunctionBegin;
555   PetscCall(MatCreate_LMVM(B));
556   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDQN));
557   PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
558   PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
559   PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
560   B->ops->view           = MatView_LMVMDQN;
561   B->ops->setup          = MatSetUp_LMVMDQN;
562   B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
563   B->ops->destroy        = MatDestroy_LMVMDQN;
564 
565   lmvm              = (Mat_LMVM *)B->data;
566   lmvm->ops->reset  = MatReset_LMVMDQN;
567   lmvm->ops->update = MatUpdate_LMVMDQN;
568   lmvm->ops->mult   = MatMult_LMVMDQN;
569   lmvm->ops->solve  = MatSolve_LMVMDQN;
570   lmvm->ops->copy   = MatCopy_LMVMDQN;
571 
572   lmvm->ops->multht  = lmvm->ops->mult;
573   lmvm->ops->solveht = lmvm->ops->solve;
574 
575   PetscCall(PetscNew(&lqn));
576   lmvm->ctx            = (void *)lqn;
577   lqn->allocated       = PETSC_FALSE;
578   lqn->use_recursive   = PETSC_FALSE;
579   lqn->needPQ          = PETSC_FALSE;
580   lqn->watchdog        = 0;
581   lqn->max_seq_rejects = lmvm->m / 2;
582   lqn->strategy        = MAT_LMVM_DENSE_INPLACE;
583 
584   PetscCall(SymBroydenRescaleCreate(&lqn->rescale));
585   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
586   PetscFunctionReturn(PETSC_SUCCESS);
587 }
588 
589 /*@
590   MatCreateLMVMDQN - Creates a dense representation of the limited-memory
591   Quasi-Newton approximation to a Hessian.
592 
593   Collective
594 
595   Input Parameters:
596 + comm - MPI communicator
597 . n    - number of local rows for storage vectors
598 - N    - global size of the storage vectors
599 
600   Output Parameter:
601 . B - the matrix
602 
603   Level: advanced
604 
605   Note:
606   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
607   paradigm instead of this routine directly.
608 
609 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDBFGS`, `MATLMVMDDFP`, `MatCreateLMVMDDFP()`, `MatCreateLMVMDBFGS()`
610 @*/
MatCreateLMVMDQN(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)611 PetscErrorCode MatCreateLMVMDQN(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
612 {
613   PetscFunctionBegin;
614   PetscCall(KSPInitializePackage());
615   PetscCall(MatCreate(comm, B));
616   PetscCall(MatSetSizes(*B, n, n, N, N));
617   PetscCall(MatSetType(*B, MATLMVMDQN));
618   PetscCall(MatSetUp(*B));
619   PetscFunctionReturn(PETSC_SUCCESS);
620 }
621 
MatDQNApplyJ0Fwd(Mat B,Vec X,Vec Z)622 static PetscErrorCode MatDQNApplyJ0Fwd(Mat B, Vec X, Vec Z)
623 {
624   PetscFunctionBegin;
625   PetscCall(MatLMVMApplyJ0Fwd(B, X, Z));
626   PetscFunctionReturn(PETSC_SUCCESS);
627 }
628 
MatDQNApplyJ0Inv(Mat B,Vec F,Vec dX)629 static PetscErrorCode MatDQNApplyJ0Inv(Mat B, Vec F, Vec dX)
630 {
631   PetscFunctionBegin;
632   PetscCall(MatLMVMApplyJ0Inv(B, F, dX));
633   PetscFunctionReturn(PETSC_SUCCESS);
634 }
635 
636 /* This is not Bunch-Kaufman LDLT: here L is strictly lower triangular part of STY */
MatGetLDLT(Mat B,Mat result)637 static PetscErrorCode MatGetLDLT(Mat B, Mat result)
638 {
639   Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
640   Mat_DQN  *lbfgs = (Mat_DQN *)lmvm->ctx;
641   PetscInt  m_local;
642 
643   PetscFunctionBegin;
644   if (!lbfgs->temp_mat) PetscCall(MatDuplicate(lbfgs->YtS_triu_strict, MAT_SHARE_NONZERO_PATTERN, &lbfgs->temp_mat));
645   PetscCall(MatCopy(lbfgs->YtS_triu_strict, lbfgs->temp_mat, SAME_NONZERO_PATTERN));
646   PetscCall(MatDiagonalScale(lbfgs->temp_mat, lbfgs->inv_diag_vec, NULL));
647   PetscCall(MatGetLocalSize(result, &m_local, NULL));
648   // need to conjugate and conjugate again because we have MatTransposeMatMult but not MatHermitianTransposeMatMult()
649   PetscCall(MatConjugate(lbfgs->temp_mat));
650   if (m_local) {
651     Mat temp_local, YtS_local, result_local;
652     PetscCall(MatDenseGetLocalMatrix(lbfgs->YtS_triu_strict, &YtS_local));
653     PetscCall(MatDenseGetLocalMatrix(lbfgs->temp_mat, &temp_local));
654     PetscCall(MatDenseGetLocalMatrix(result, &result_local));
655     PetscCall(MatTransposeMatMult(YtS_local, temp_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &result_local));
656   }
657   PetscCall(MatConjugate(result));
658   PetscFunctionReturn(PETSC_SUCCESS);
659 }
660 
MatLMVMDBFGSUpdateMultData(Mat B)661 static PetscErrorCode MatLMVMDBFGSUpdateMultData(Mat B)
662 {
663   Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
664   Mat_DQN  *lbfgs = (Mat_DQN *)lmvm->ctx;
665   PetscInt  m     = lmvm->m, m_local;
666   PetscInt  k     = lmvm->k;
667   PetscInt  h     = k - oldest_update(m, k);
668   PetscInt  j_0;
669   PetscInt  prev_oldest;
670   Mat       J_local;
671   Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
672   Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;
673 
674   PetscFunctionBegin;
675   if (!lbfgs->YtS_triu_strict) {
676     PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->YtS_triu_strict));
677     PetscCall(MatDestroy(&lbfgs->StBS));
678     PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->StBS));
679     PetscCall(MatDestroy(&lbfgs->J));
680     PetscCall(MatDuplicate(lbfgs->StY_triu, MAT_SHARE_NONZERO_PATTERN, &lbfgs->J));
681     PetscCall(MatDestroy(&lbfgs->BS));
682     PetscCall(MatDuplicate(Yfull, MAT_SHARE_NONZERO_PATTERN, &lbfgs->BS));
683     PetscCall(MatShift(lbfgs->StBS, 1.0));
684     lbfgs->num_mult_updates = oldest_update(m, k);
685   }
686   if (lbfgs->num_mult_updates == k) PetscFunctionReturn(PETSC_SUCCESS);
687 
688   /* B_0 may have been updated, we must recompute B_0 S and S^T B_0 S */
689   for (PetscInt j = oldest_update(m, k); j < k; j++) {
690     Vec      s_j;
691     Vec      Bs_j;
692     Vec      StBs_j;
693     PetscInt S_idx    = recycle_index(m, j);
694     PetscInt StBS_idx = lbfgs->strategy == MAT_LMVM_DENSE_INPLACE ? S_idx : history_index(m, k, j);
695 
696     PetscCall(MatDenseGetColumnVecWrite(lbfgs->BS, S_idx, &Bs_j));
697     PetscCall(MatDenseGetColumnVecRead(Sfull, S_idx, &s_j));
698     PetscCall(MatDQNApplyJ0Fwd(B, s_j, Bs_j));
699     PetscCall(MatDenseRestoreColumnVecRead(Sfull, S_idx, &s_j));
700     PetscCall(MatDenseGetColumnVecWrite(lbfgs->StBS, StBS_idx, &StBs_j));
701     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, Bs_j, StBs_j, 0, h));
702     lbfgs->St_count++;
703     if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, StBs_j, lbfgs->num_updates, lbfgs->cyclic_work_vec));
704     PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->StBS, StBS_idx, &StBs_j));
705     PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->BS, S_idx, &Bs_j));
706   }
707   prev_oldest = oldest_update(m, lbfgs->num_mult_updates);
708   if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER && prev_oldest < oldest_update(m, k)) {
709     /* move the YtS entries that have been computed and need to be kept back up */
710     PetscInt m_keep = m - (oldest_update(m, k) - prev_oldest);
711 
712     PetscCall(MatMove_LR3(B, lbfgs->YtS_triu_strict, m_keep));
713   }
714   PetscCall(MatGetLocalSize(lbfgs->YtS_triu_strict, &m_local, NULL));
715   j_0 = PetscMax(lbfgs->num_mult_updates, oldest_update(m, k));
716   for (PetscInt j = j_0; j < k; j++) {
717     PetscInt S_idx   = recycle_index(m, j);
718     PetscInt YtS_idx = lbfgs->strategy == MAT_LMVM_DENSE_INPLACE ? S_idx : history_index(m, k, j);
719     Vec      s_j, Yts_j;
720 
721     PetscCall(MatDenseGetColumnVecRead(Sfull, S_idx, &s_j));
722     PetscCall(MatDenseGetColumnVecWrite(lbfgs->YtS_triu_strict, YtS_idx, &Yts_j));
723     PetscCall(MatMultHermitianTransposeColumnRange(Yfull, s_j, Yts_j, 0, h));
724     lbfgs->Yt_count++;
725     if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, Yts_j, lbfgs->num_updates, lbfgs->cyclic_work_vec));
726     PetscCall(MatDenseRestoreColumnVecWrite(lbfgs->YtS_triu_strict, YtS_idx, &Yts_j));
727     PetscCall(MatDenseRestoreColumnVecRead(Sfull, S_idx, &s_j));
728     /* zero the corresponding row */
729     if (m_local > 0) {
730       Mat YtS_local, YtS_row;
731 
732       PetscCall(MatDenseGetLocalMatrix(lbfgs->YtS_triu_strict, &YtS_local));
733       PetscCall(MatDenseGetSubMatrix(YtS_local, YtS_idx, YtS_idx + 1, PETSC_DECIDE, PETSC_DECIDE, &YtS_row));
734       PetscCall(MatZeroEntries(YtS_row));
735       PetscCall(MatDenseRestoreSubMatrix(YtS_local, &YtS_row));
736     }
737   }
738   if (!lbfgs->inv_diag_vec) PetscCall(VecDuplicate(lbfgs->diag_vec, &lbfgs->inv_diag_vec));
739   PetscCall(VecCopy(lbfgs->diag_vec, lbfgs->inv_diag_vec));
740   PetscCall(VecReciprocal(lbfgs->inv_diag_vec));
741   PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
742   PetscCall(MatSetFactorType(J_local, MAT_FACTOR_NONE));
743   PetscCall(MatGetLDLT(B, lbfgs->J));
744   PetscCall(MatAXPY(lbfgs->J, 1.0, lbfgs->StBS, SAME_NONZERO_PATTERN));
745   if (m_local) {
746     PetscCall(MatSetOption(J_local, MAT_SPD, PETSC_TRUE));
747     PetscCall(MatCholeskyFactor(J_local, NULL, NULL));
748   }
749   lbfgs->num_mult_updates = lbfgs->num_updates;
750   PetscFunctionReturn(PETSC_SUCCESS);
751 }
752 
753 /* Solves for
754  * [ I | -S R^{-T} ] [  I  | 0 ] [ H_0 | 0 ] [ I | Y ] [      I      ]
755  *                   [-----+---] [-----+---] [---+---] [-------------]
756  *                   [ Y^T | I ] [  0  | D ] [ 0 | I ] [ -R^{-1} S^T ]  */
757 
MatSolve_LMVMDBFGS(Mat H,Vec F,Vec dX)758 static PetscErrorCode MatSolve_LMVMDBFGS(Mat H, Vec F, Vec dX)
759 {
760   Mat_LMVM        *lmvm   = (Mat_LMVM *)H->data;
761   Mat_DQN         *lbfgs  = (Mat_DQN *)lmvm->ctx;
762   Vec              rwork1 = lbfgs->rwork1;
763   PetscInt         m      = lmvm->m;
764   PetscInt         k      = lmvm->k;
765   PetscInt         h      = k - oldest_update(m, k);
766   Mat              Sfull  = lmvm->basis[LMBASIS_S]->vecs;
767   Mat              Yfull  = lmvm->basis[LMBASIS_Y]->vecs;
768   PetscObjectState Fstate;
769 
770   PetscFunctionBegin;
771   VecCheckSameSize(F, 2, dX, 3);
772   VecCheckMatCompatible(H, dX, 3, F, 2);
773 
774   /* Block Version */
775   if (!lbfgs->num_updates) {
776     PetscCall(MatDQNApplyJ0Inv(H, F, dX));
777     PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
778   }
779 
780   PetscCall(PetscObjectStateGet((PetscObject)F, &Fstate));
781   if (F == lbfgs->Fprev_ref && Fstate == lbfgs->Fprev_state) {
782     PetscCall(VecCopy(lbfgs->StFprev, rwork1));
783   } else {
784     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, rwork1, 0, h));
785     lbfgs->St_count++;
786   }
787 
788   /* Reordering rwork1, as STY is in history order, while S is in recycled order */
789   if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
790   PetscCall(MatUpperTriangularSolveInPlace(H, lbfgs->StY_triu, rwork1, PETSC_FALSE, lbfgs->num_updates, lbfgs->strategy));
791   PetscCall(VecScale(rwork1, -1.0));
792   if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
793 
794   PetscCall(VecCopy(F, lbfgs->column_work));
795   PetscCall(MatMultAddColumnRange(Yfull, rwork1, lbfgs->column_work, lbfgs->column_work, 0, h));
796   lbfgs->Y_count++;
797 
798   PetscCall(VecPointwiseMult(rwork1, lbfgs->diag_vec_recycle_order, rwork1));
799   PetscCall(MatDQNApplyJ0Inv(H, lbfgs->column_work, dX));
800 
801   PetscCall(MatMultHermitianTransposeAddColumnRange(Yfull, dX, rwork1, rwork1, 0, h));
802   lbfgs->Yt_count++;
803 
804   if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
805   PetscCall(MatUpperTriangularSolveInPlace(H, lbfgs->StY_triu, rwork1, PETSC_TRUE, lbfgs->num_updates, lbfgs->strategy));
806   PetscCall(VecScale(rwork1, -1.0));
807   if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(H, rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
808 
809   PetscCall(MatMultAddColumnRange(Sfull, rwork1, dX, dX, 0, h));
810   lbfgs->S_count++;
811   PetscFunctionReturn(PETSC_SUCCESS);
812 }
813 
/* Computes the product with
815    B_0 - [ Y | B_0 S] [ -D  |    L^T    ]^-1 [   Y^T   ]
816                       [-----+-----------]    [---------]
817                       [  L  | S^T B_0 S ]    [ S^T B_0 ]
818 
819    Above is equivalent to
820 
821    B_0 - [ Y | B_0 S] [[     I     | 0 ][ -D  | 0 ][ I | -D^{-1} L^T ]]^-1 [   Y^T   ]
822                       [[-----------+---][-----+---][---+-------------]]    [---------]
823                       [[ -L D^{-1} | I ][  0  | J ][ 0 |       I     ]]    [ S^T B_0 ]
824 
825    where J = S^T B_0 S + L D^{-1} L^T
826 
827    becomes
828 
829    B_0 - [ Y | B_0 S] [ I | D^{-1} L^T ][ -D^{-1}  |   0    ][    I     | 0 ] [   Y^T   ]
830                       [---+------------][----------+--------][----------+---] [---------]
831                       [ 0 |     I      ][     0    | J^{-1} ][ L D^{-1} | I ] [ S^T B_0 ]
832 
833                       =
834 
835    B_0 + [ Y | B_0 S] [ D^{-1} | 0 ][ I | L^T ][ I |    0    ][     I    | 0 ] [   Y^T   ]
836                       [--------+---][---+-----][---+---------][----------+---] [---------]
837                       [ 0      | I ][ 0 |  I  ][ 0 | -J^{-1} ][ L D^{-1} | I ] [ S^T B_0 ]
838 
839                       (Note that YtS_triu_strict is L^T)
840    Byrd, Nocedal, Schnabel 1994
841 
   Alternative approach: since DFP is dual to BFGS, mirror the MatMult of DFP
   (see MatMult_LMVMDDFP below)
844 
845 */
MatMult_LMVMDBFGS(Mat B,Vec X,Vec Z)846 static PetscErrorCode MatMult_LMVMDBFGS(Mat B, Vec X, Vec Z)
847 {
848   Mat_LMVM *lmvm  = (Mat_LMVM *)B->data;
849   Mat_DQN  *lbfgs = (Mat_DQN *)lmvm->ctx;
850   Mat       J_local;
851   PetscInt  idx, i, j, m_local, local_n;
852   PetscInt  m     = lmvm->m;
853   PetscInt  k     = lmvm->k;
854   PetscInt  h     = k - oldest_update(m, k);
855   Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
856   Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;
857 
858   PetscFunctionBegin;
859   VecCheckSameSize(X, 2, Z, 3);
860   VecCheckMatCompatible(B, X, 2, Z, 3);
861 
862   /* Cholesky Version */
863   /* Start with the B0 term */
864   PetscCall(MatDQNApplyJ0Fwd(B, X, Z));
865   if (!lbfgs->num_updates) PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
866 
867   if (lbfgs->use_recursive) {
868     PetscDeviceContext dctx;
869     PetscMemType       memtype;
870     PetscScalar        stz, ytx, stp, sjtpi, yjtsi, *workscalar;
871     PetscInt           oldest = oldest_update(m, k);
872 
873     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
874     /* Recursive formulation to avoid Cholesky. Not a dense formulation */
875     PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, lbfgs->rwork1, 0, h));
876     lbfgs->Yt_count++;
877 
878     PetscCall(VecGetLocalSize(lbfgs->rwork1, &local_n));
879 
880     if (lbfgs->needPQ) {
881       PetscInt oldest = oldest_update(m, k);
882       for (i = oldest; i < k; ++i) {
883         idx = recycle_index(m, i);
884         /* column_work = S[idx] */
885         PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work, idx));
886         PetscCall(MatDQNApplyJ0Fwd(B, lbfgs->column_work, lbfgs->PQ[idx]));
887         PetscCall(MatMultHermitianTransposeColumnRange(Yfull, lbfgs->column_work, lbfgs->rwork3, 0, h));
888         PetscCall(VecGetArrayAndMemType(lbfgs->rwork3, &workscalar, &memtype));
889         for (j = oldest; j < i; ++j) {
890           PetscInt idx_j = recycle_index(m, j);
891           /* Copy yjtsi in device-aware manner */
892           if (local_n) {
893             if (PetscMemTypeHost(memtype)) {
894               yjtsi = workscalar[idx_j];
895             } else {
896               PetscCall(PetscDeviceRegisterMemory(&yjtsi, PETSC_MEMTYPE_HOST, sizeof(yjtsi)));
897               PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
898               PetscCall(PetscDeviceArrayCopy(dctx, &yjtsi, &workscalar[idx_j], 1));
899             }
900           }
901           PetscCallMPI(MPI_Bcast(&yjtsi, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)B)));
902           /* column_work2 = S[j] */
903           PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work2, idx_j));
904           PetscCall(VecDot(lbfgs->PQ[idx], lbfgs->column_work2, &sjtpi));
905           /* column_work2 = Y[j] */
906           PetscCall(MatGetColumnVector(Yfull, lbfgs->column_work2, idx_j));
907           /* Compute the pure BFGS component of the forward product */
908           PetscCall(VecAXPBYPCZ(lbfgs->PQ[idx], -sjtpi / lbfgs->stp[idx_j], yjtsi / lbfgs->yts[idx_j], 1.0, lbfgs->PQ[idx_j], lbfgs->column_work2));
909         }
910         PetscCall(VecDot(lbfgs->PQ[idx], lbfgs->column_work, &stp));
911         lbfgs->stp[idx] = PetscRealPart(stp);
912       }
913       lbfgs->needPQ = PETSC_FALSE;
914     }
915 
916     PetscCall(VecGetArrayAndMemType(lbfgs->rwork1, &workscalar, &memtype));
917     for (i = oldest; i < k; ++i) {
918       idx = recycle_index(m, i);
919       /* Copy stz[i], ytx[i] in device-aware manner */
920       if (local_n) {
921         if (PetscMemTypeHost(memtype)) {
922           ytx = workscalar[idx];
923         } else {
924           PetscCall(PetscDeviceRegisterMemory(&ytx, PETSC_MEMTYPE_HOST, 1 * sizeof(ytx)));
925           PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
926           PetscCall(PetscDeviceArrayCopy(dctx, &ytx, &workscalar[idx], 1));
927         }
928       }
929       PetscCallMPI(MPI_Bcast(&ytx, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)B)));
930       /* column_work : S[i], column_work2 : Y[i] */
931       PetscCall(MatGetColumnVector(Sfull, lbfgs->column_work, idx));
932       PetscCall(MatGetColumnVector(Yfull, lbfgs->column_work2, idx));
933       PetscCall(VecDot(Z, lbfgs->column_work, &stz));
934       PetscCall(VecAXPBYPCZ(Z, -stz / lbfgs->stp[idx], ytx / lbfgs->yts[idx], 1.0, lbfgs->PQ[idx], lbfgs->column_work2));
935     }
936     PetscCall(VecRestoreArrayAndMemType(lbfgs->rwork1, &workscalar));
937   } else {
938     PetscCall(MatLMVMDBFGSUpdateMultData(B));
939     PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, lbfgs->rwork1, 0, h));
940     lbfgs->Yt_count++;
941     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, Z, lbfgs->rwork2, 0, h));
942     lbfgs->St_count++;
943     if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) {
944       PetscCall(VecRecycleOrderToHistoryOrder(B, lbfgs->rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
945       PetscCall(VecRecycleOrderToHistoryOrder(B, lbfgs->rwork2, lbfgs->num_updates, lbfgs->cyclic_work_vec));
946     }
947 
948     PetscCall(VecPointwiseMult(lbfgs->rwork1, lbfgs->rwork1, lbfgs->inv_diag_vec));
949     if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(lbfgs->YtS_triu_strict));
950     PetscCall(MatMultTransposeAdd(lbfgs->YtS_triu_strict, lbfgs->rwork1, lbfgs->rwork2, lbfgs->rwork2));
951     if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(lbfgs->YtS_triu_strict));
952 
953     if (!lbfgs->rwork2_local) PetscCall(VecCreateLocalVector(lbfgs->rwork2, &lbfgs->rwork2_local));
954     if (!lbfgs->rwork3_local) PetscCall(VecCreateLocalVector(lbfgs->rwork3, &lbfgs->rwork3_local));
955     PetscCall(VecGetLocalVectorRead(lbfgs->rwork2, lbfgs->rwork2_local));
956     PetscCall(VecGetLocalVector(lbfgs->rwork3, lbfgs->rwork3_local));
957     PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
958     PetscCall(VecGetSize(lbfgs->rwork2_local, &m_local));
959     if (m_local) {
960       PetscCall(MatDenseGetLocalMatrix(lbfgs->J, &J_local));
961       PetscCall(MatSolve(J_local, lbfgs->rwork2_local, lbfgs->rwork3_local));
962     }
963     PetscCall(VecRestoreLocalVector(lbfgs->rwork3, lbfgs->rwork3_local));
964     PetscCall(VecRestoreLocalVectorRead(lbfgs->rwork2, lbfgs->rwork2_local));
965     PetscCall(VecScale(lbfgs->rwork3, -1.0));
966 
967     PetscCall(MatMult(lbfgs->YtS_triu_strict, lbfgs->rwork3, lbfgs->rwork2));
968     PetscCall(VecPointwiseMult(lbfgs->rwork2, lbfgs->rwork2, lbfgs->inv_diag_vec));
969     PetscCall(VecAXPY(lbfgs->rwork1, 1.0, lbfgs->rwork2));
970 
971     if (lbfgs->strategy == MAT_LMVM_DENSE_REORDER) {
972       PetscCall(VecHistoryOrderToRecycleOrder(B, lbfgs->rwork1, lbfgs->num_updates, lbfgs->cyclic_work_vec));
973       PetscCall(VecHistoryOrderToRecycleOrder(B, lbfgs->rwork3, lbfgs->num_updates, lbfgs->cyclic_work_vec));
974     }
975 
976     PetscCall(MatMultAddColumnRange(Yfull, lbfgs->rwork1, Z, Z, 0, h));
977     lbfgs->Y_count++;
978     PetscCall(MatMultAddColumnRange(lbfgs->BS, lbfgs->rwork3, Z, Z, 0, h));
979     lbfgs->S_count++;
980   }
981   PetscFunctionReturn(PETSC_SUCCESS);
982 }
983 
984 /*
985   This dense representation reduces the L-BFGS update to a series of
986   matrix-vector products with dense matrices in lieu of the conventional matrix-free
987   two-loop algorithm.
988 */
MatCreate_LMVMDBFGS(Mat B)989 PetscErrorCode MatCreate_LMVMDBFGS(Mat B)
990 {
991   Mat_LMVM *lmvm;
992   Mat_DQN  *lbfgs;
993 
994   PetscFunctionBegin;
995   PetscCall(MatCreate_LMVM(B));
996   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDBFGS));
997   PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
998   PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
999   PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
1000   B->ops->view           = MatView_LMVMDQN;
1001   B->ops->setup          = MatSetUp_LMVMDQN;
1002   B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
1003   B->ops->destroy        = MatDestroy_LMVMDQN;
1004 
1005   lmvm              = (Mat_LMVM *)B->data;
1006   lmvm->ops->reset  = MatReset_LMVMDQN;
1007   lmvm->ops->update = MatUpdate_LMVMDQN;
1008   lmvm->ops->mult   = MatMult_LMVMDBFGS;
1009   lmvm->ops->solve  = MatSolve_LMVMDBFGS;
1010   lmvm->ops->copy   = MatCopy_LMVMDQN;
1011 
1012   lmvm->ops->multht  = lmvm->ops->mult;
1013   lmvm->ops->solveht = lmvm->ops->solve;
1014 
1015   PetscCall(PetscNew(&lbfgs));
1016   lmvm->ctx              = (void *)lbfgs;
1017   lbfgs->allocated       = PETSC_FALSE;
1018   lbfgs->use_recursive   = PETSC_TRUE;
1019   lbfgs->needPQ          = PETSC_TRUE;
1020   lbfgs->watchdog        = 0;
1021   lbfgs->max_seq_rejects = lmvm->m / 2;
1022   lbfgs->strategy        = MAT_LMVM_DENSE_INPLACE;
1023 
1024   PetscCall(SymBroydenRescaleCreate(&lbfgs->rescale));
1025   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
1026   PetscFunctionReturn(PETSC_SUCCESS);
1027 }
1028 
1029 /*@
1030   MatCreateLMVMDBFGS - Creates a dense representation of the limited-memory
1031   Broyden-Fletcher-Goldfarb-Shanno (BFGS) approximation to a Hessian.
1032 
1033   Collective
1034 
1035   Input Parameters:
1036 + comm - MPI communicator
1037 . n    - number of local rows for storage vectors
1038 - N    - global size of the storage vectors
1039 
1040   Output Parameter:
1041 . B - the matrix
1042 
1043   Level: advanced
1044 
1045   Note:
1046   It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions()
1047   paradigm instead of this routine directly.
1048 
1049 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDBFGS`, `MatCreateLMVMBFGS()`
1050 @*/
MatCreateLMVMDBFGS(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)1051 PetscErrorCode MatCreateLMVMDBFGS(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
1052 {
1053   PetscFunctionBegin;
1054   PetscCall(KSPInitializePackage());
1055   PetscCall(MatCreate(comm, B));
1056   PetscCall(MatSetSizes(*B, n, n, N, N));
1057   PetscCall(MatSetType(*B, MATLMVMDBFGS));
1058   PetscCall(MatSetUp(*B));
1059   PetscFunctionReturn(PETSC_SUCCESS);
1060 }
1061 
1062 /* here R is strictly upper triangular part of STY */
MatGetRTDR(Mat B,Mat result)1063 static PetscErrorCode MatGetRTDR(Mat B, Mat result)
1064 {
1065   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1066   Mat_DQN  *ldfp = (Mat_DQN *)lmvm->ctx;
1067   PetscInt  m_local;
1068 
1069   PetscFunctionBegin;
1070   if (!ldfp->temp_mat) PetscCall(MatDuplicate(ldfp->StY_triu_strict, MAT_SHARE_NONZERO_PATTERN, &ldfp->temp_mat));
1071   PetscCall(MatCopy(ldfp->StY_triu_strict, ldfp->temp_mat, SAME_NONZERO_PATTERN));
1072   PetscCall(MatDiagonalScale(ldfp->temp_mat, ldfp->inv_diag_vec, NULL));
1073   PetscCall(MatGetLocalSize(result, &m_local, NULL));
1074   // need to conjugate and conjugate again because we have MatTransposeMatMult but not MatHermitianTransposeMatMult()
1075   PetscCall(MatConjugate(ldfp->temp_mat));
1076   if (m_local) {
1077     Mat temp_local, StY_local, result_local;
1078     PetscCall(MatDenseGetLocalMatrix(ldfp->StY_triu_strict, &StY_local));
1079     PetscCall(MatDenseGetLocalMatrix(ldfp->temp_mat, &temp_local));
1080     PetscCall(MatDenseGetLocalMatrix(result, &result_local));
1081     PetscCall(MatTransposeMatMult(StY_local, temp_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &result_local));
1082   }
1083   PetscCall(MatConjugate(result));
1084   PetscFunctionReturn(PETSC_SUCCESS);
1085 }
1086 
MatLMVMDDFPUpdateSolveData(Mat B)1087 static PetscErrorCode MatLMVMDDFPUpdateSolveData(Mat B)
1088 {
1089   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1090   Mat_DQN  *ldfp = (Mat_DQN *)lmvm->ctx;
1091   PetscInt  m    = lmvm->m, m_local;
1092   PetscInt  k    = lmvm->k;
1093   PetscInt  h    = k - oldest_update(m, k);
1094   PetscInt  j_0;
1095   PetscInt  prev_oldest;
1096   Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
1097   Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;
1098   Mat       J_local;
1099 
1100   PetscFunctionBegin;
1101   if (!ldfp->StY_triu_strict) {
1102     PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->StY_triu_strict));
1103     PetscCall(MatDestroy(&ldfp->YtHY));
1104     PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->YtHY));
1105     PetscCall(MatDestroy(&ldfp->J));
1106     PetscCall(MatDuplicate(ldfp->YtS_triu, MAT_SHARE_NONZERO_PATTERN, &ldfp->J));
1107     PetscCall(MatDestroy(&ldfp->HY));
1108     PetscCall(MatDuplicate(Yfull, MAT_SHARE_NONZERO_PATTERN, &ldfp->HY));
1109     PetscCall(MatShift(ldfp->YtHY, 1.0));
1110     ldfp->num_mult_updates = oldest_update(m, k);
1111   }
1112   if (ldfp->num_mult_updates == k) PetscFunctionReturn(PETSC_SUCCESS);
1113 
1114   /* H_0 may have been updated, we must recompute H_0 Y and Y^T H_0 Y */
1115   for (PetscInt j = oldest_update(m, k); j < k; j++) {
1116     Vec      y_j;
1117     Vec      Hy_j;
1118     Vec      YtHy_j;
1119     PetscInt Y_idx    = recycle_index(m, j);
1120     PetscInt YtHY_idx = ldfp->strategy == MAT_LMVM_DENSE_INPLACE ? Y_idx : history_index(m, k, j);
1121 
1122     PetscCall(MatDenseGetColumnVecWrite(ldfp->HY, Y_idx, &Hy_j));
1123     PetscCall(MatDenseGetColumnVecRead(Yfull, Y_idx, &y_j));
1124     PetscCall(MatDQNApplyJ0Inv(B, y_j, Hy_j));
1125     PetscCall(MatDenseRestoreColumnVecRead(Yfull, Y_idx, &y_j));
1126     PetscCall(MatDenseGetColumnVecWrite(ldfp->YtHY, YtHY_idx, &YtHy_j));
1127     PetscCall(MatMultHermitianTransposeColumnRange(Yfull, Hy_j, YtHy_j, 0, h));
1128     ldfp->Yt_count++;
1129     if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, YtHy_j, ldfp->num_updates, ldfp->cyclic_work_vec));
1130     PetscCall(MatDenseRestoreColumnVecWrite(ldfp->YtHY, YtHY_idx, &YtHy_j));
1131     PetscCall(MatDenseRestoreColumnVecWrite(ldfp->HY, Y_idx, &Hy_j));
1132   }
1133   prev_oldest = oldest_update(m, ldfp->num_mult_updates);
1134   if (ldfp->strategy == MAT_LMVM_DENSE_REORDER && prev_oldest < oldest_update(m, k)) {
1135     /* move the YtS entries that have been computed and need to be kept back up */
1136     PetscInt m_keep = m - (oldest_update(m, k) - prev_oldest);
1137 
1138     PetscCall(MatMove_LR3(B, ldfp->StY_triu_strict, m_keep));
1139   }
1140   PetscCall(MatGetLocalSize(ldfp->StY_triu_strict, &m_local, NULL));
1141   j_0 = PetscMax(ldfp->num_mult_updates, oldest_update(m, k));
1142   for (PetscInt j = j_0; j < k; j++) {
1143     PetscInt Y_idx   = recycle_index(m, j);
1144     PetscInt StY_idx = ldfp->strategy == MAT_LMVM_DENSE_INPLACE ? Y_idx : history_index(m, k, j);
1145     Vec      y_j, Sty_j;
1146 
1147     PetscCall(MatDenseGetColumnVecRead(Yfull, Y_idx, &y_j));
1148     PetscCall(MatDenseGetColumnVecWrite(ldfp->StY_triu_strict, StY_idx, &Sty_j));
1149     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, y_j, Sty_j, 0, h));
1150     ldfp->St_count++;
1151     if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, Sty_j, ldfp->num_updates, ldfp->cyclic_work_vec));
1152     PetscCall(MatDenseRestoreColumnVecWrite(ldfp->StY_triu_strict, StY_idx, &Sty_j));
1153     PetscCall(MatDenseRestoreColumnVecRead(Yfull, Y_idx, &y_j));
1154     /* zero the corresponding row */
1155     if (m_local > 0) {
1156       Mat StY_local, StY_row;
1157 
1158       PetscCall(MatDenseGetLocalMatrix(ldfp->StY_triu_strict, &StY_local));
1159       PetscCall(MatDenseGetSubMatrix(StY_local, StY_idx, StY_idx + 1, PETSC_DECIDE, PETSC_DECIDE, &StY_row));
1160       PetscCall(MatZeroEntries(StY_row));
1161       PetscCall(MatDenseRestoreSubMatrix(StY_local, &StY_row));
1162     }
1163   }
1164   if (!ldfp->inv_diag_vec) PetscCall(VecDuplicate(ldfp->diag_vec, &ldfp->inv_diag_vec));
1165   PetscCall(VecCopy(ldfp->diag_vec, ldfp->inv_diag_vec));
1166   PetscCall(VecReciprocal(ldfp->inv_diag_vec));
1167   PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1168   PetscCall(MatSetFactorType(J_local, MAT_FACTOR_NONE));
1169   PetscCall(MatGetRTDR(B, ldfp->J));
1170   PetscCall(MatAXPY(ldfp->J, 1.0, ldfp->YtHY, SAME_NONZERO_PATTERN));
1171   if (m_local) {
1172     PetscCall(MatSetOption(J_local, MAT_SPD, PETSC_TRUE));
1173     PetscCall(MatCholeskyFactor(J_local, NULL, NULL));
1174   }
1175   ldfp->num_mult_updates = ldfp->num_updates;
1176   PetscFunctionReturn(PETSC_SUCCESS);
1177 }
1178 
1179 /* Solves for
1180 
   H_0 - [ S | H_0 Y] [ -D  |    R^T    ]^-1 [   S^T   ]
1182                       [-----+-----------]    [---------]
1183                       [  R  | Y^T H_0 Y ]    [ Y^T H_0 ]
1184 
1185    Above is equivalent to
1186 
1187    H_0 - [ S | H_0 Y] [[     I     | 0 ][ -D | 0 ][ I | -D^{-1} R^T ]]^-1 [   S^T   ]
1188                       [[-----------+---][----+---][---+-------------]]    [---------]
1189                       [[ -R D^{-1} | I ][  0 | J ][ 0 |      I      ]]    [ Y^T H_0 ]
1190 
   where J = Y^T H_0 Y + R D^{-1} R^T
1192 
1193    becomes
1194 
1195    H_0 - [ S | H_0 Y] [ I | D^{-1} R^T ][ -D^{-1}  |   0    ][     I    | 0 ] [   S^T   ]
1196                       [---+------------][----------+--------][----------+---] [---------]
1197                       [ 0 |      I     ][     0    | J^{-1} ][ R D^{-1} | I ] [ Y^T H_0 ]
1198 
1199                       =
1200 
1201    H_0 + [ S | H_0 Y] [ D^{-1} | 0 ][ I | R^T ][ I |    0    ][     I    | 0 ] [   S^T   ]
1202                       [--------+---][---+-----][---+---------][----------+---] [---------]
1203                       [ 0      | I ][ 0 |  I  ][ 0 | -J^{-1} ][ R D^{-1} | I ] [ Y^T H_0 ]
1204 
                      (Note that StY_triu_strict is R^T)
1206    Byrd, Nocedal, Schnabel 1994
1207 
1208 */
MatSolve_LMVMDDFP(Mat H,Vec F,Vec dX)1209 static PetscErrorCode MatSolve_LMVMDDFP(Mat H, Vec F, Vec dX)
1210 {
1211   Mat_LMVM *lmvm = (Mat_LMVM *)H->data;
1212   Mat_DQN  *ldfp = (Mat_DQN *)lmvm->ctx;
1213   PetscInt  m    = lmvm->m;
1214   PetscInt  k    = lmvm->k;
1215   PetscInt  h    = k - oldest_update(m, k);
1216   PetscInt  idx, i, j, local_n;
1217   PetscInt  m_local;
1218   Mat       J_local;
1219   Mat       Sfull = lmvm->basis[LMBASIS_S]->vecs;
1220   Mat       Yfull = lmvm->basis[LMBASIS_Y]->vecs;
1221 
1222   PetscFunctionBegin;
1223   VecCheckSameSize(F, 2, dX, 3);
1224   VecCheckMatCompatible(H, dX, 3, F, 2);
1225 
1226   /* Cholesky Version */
1227   /* Start with the B0 term */
1228   PetscCall(MatDQNApplyJ0Inv(H, F, dX));
1229   if (!ldfp->num_updates) PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
1230 
1231   if (ldfp->use_recursive) {
1232     PetscDeviceContext dctx;
1233     PetscMemType       memtype;
1234     PetscScalar        stf, ytx, ytq, yjtqi, sjtyi, *workscalar;
1235 
1236     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
1237     /* Recursive formulation to avoid Cholesky. Not a dense formulation */
1238     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, ldfp->rwork1, 0, h));
1239     ldfp->Yt_count++;
1240 
1241     PetscCall(VecGetLocalSize(ldfp->rwork1, &local_n));
1242 
1243     PetscInt oldest = oldest_update(m, k);
1244 
1245     if (ldfp->needPQ) {
1246       PetscInt oldest = oldest_update(m, k);
1247       for (i = oldest; i < k; ++i) {
1248         idx = recycle_index(m, i);
1249         /* column_work = S[idx] */
1250         PetscCall(MatGetColumnVector(Yfull, ldfp->column_work, idx));
1251         PetscCall(MatDQNApplyJ0Inv(H, ldfp->column_work, ldfp->PQ[idx]));
1252         PetscCall(MatMultHermitianTransposeColumnRange(Sfull, ldfp->column_work, ldfp->rwork3, 0, h));
1253         PetscCall(VecGetArrayAndMemType(ldfp->rwork3, &workscalar, &memtype));
1254         for (j = oldest; j < i; ++j) {
1255           PetscInt idx_j = recycle_index(m, j);
1256           /* Copy sjtyi in device-aware manner */
1257           if (local_n) {
1258             if (PetscMemTypeHost(memtype)) {
1259               sjtyi = workscalar[idx_j];
1260             } else {
1261               PetscCall(PetscDeviceRegisterMemory(&sjtyi, PETSC_MEMTYPE_HOST, 1 * sizeof(sjtyi)));
1262               PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
1263               PetscCall(PetscDeviceArrayCopy(dctx, &sjtyi, &workscalar[idx_j], 1));
1264             }
1265           }
1266           PetscCallMPI(MPI_Bcast(&sjtyi, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)H)));
1267           /* column_work2 = Y[j] */
1268           PetscCall(MatGetColumnVector(Yfull, ldfp->column_work2, idx_j));
1269           PetscCall(VecDot(ldfp->PQ[idx], ldfp->column_work2, &yjtqi));
1270           /* column_work2 = Y[j] */
1271           PetscCall(MatGetColumnVector(Sfull, ldfp->column_work2, idx_j));
1272           /* Compute the pure BFGS component of the forward product */
1273           PetscCall(VecAXPBYPCZ(ldfp->PQ[idx], -yjtqi / ldfp->ytq[idx_j], sjtyi / ldfp->yts[idx_j], 1.0, ldfp->PQ[idx_j], ldfp->column_work2));
1274         }
1275         PetscCall(VecDot(ldfp->PQ[idx], ldfp->column_work, &ytq));
1276         ldfp->ytq[idx] = PetscRealPart(ytq);
1277       }
1278       ldfp->needPQ = PETSC_FALSE;
1279     }
1280 
1281     PetscCall(VecGetArrayAndMemType(ldfp->rwork1, &workscalar, &memtype));
1282     for (i = oldest; i < k; ++i) {
1283       idx = recycle_index(m, i);
1284       /* Copy stz[i], ytx[i] in device-aware manner */
1285       if (local_n) {
1286         if (PetscMemTypeHost(memtype)) {
1287           stf = workscalar[idx];
1288         } else {
1289           PetscCall(PetscDeviceRegisterMemory(&stf, PETSC_MEMTYPE_HOST, sizeof(stf)));
1290           PetscCall(PetscDeviceRegisterMemory(workscalar, memtype, local_n * sizeof(*workscalar)));
1291           PetscCall(PetscDeviceArrayCopy(dctx, &stf, &workscalar[idx], 1));
1292         }
1293       }
1294       PetscCallMPI(MPI_Bcast(&stf, 1, MPIU_SCALAR, 0, PetscObjectComm((PetscObject)H)));
1295       /* column_work : S[i], column_work2 : Y[i] */
1296       PetscCall(MatGetColumnVector(Sfull, ldfp->column_work, idx));
1297       PetscCall(MatGetColumnVector(Yfull, ldfp->column_work2, idx));
1298       PetscCall(VecDot(dX, ldfp->column_work2, &ytx));
1299       PetscCall(VecAXPBYPCZ(dX, -ytx / ldfp->ytq[idx], stf / ldfp->yts[idx], 1.0, ldfp->PQ[idx], ldfp->column_work));
1300     }
1301     PetscCall(VecRestoreArrayAndMemType(ldfp->rwork1, &workscalar));
1302   } else {
1303     PetscCall(MatLMVMDDFPUpdateSolveData(H));
1304     PetscCall(MatMultHermitianTransposeColumnRange(Sfull, F, ldfp->rwork1, 0, h));
1305     ldfp->St_count++;
1306     PetscCall(MatMultHermitianTransposeColumnRange(Yfull, dX, ldfp->rwork2, 0, h));
1307     ldfp->Yt_count++;
1308     if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) {
1309       PetscCall(VecRecycleOrderToHistoryOrder(H, ldfp->rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1310       PetscCall(VecRecycleOrderToHistoryOrder(H, ldfp->rwork2, ldfp->num_updates, ldfp->cyclic_work_vec));
1311     }
1312 
1313     PetscCall(VecPointwiseMult(ldfp->rwork3, ldfp->rwork1, ldfp->inv_diag_vec));
1314     if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(ldfp->StY_triu_strict));
1315     PetscCall(MatMultTransposeAdd(ldfp->StY_triu_strict, ldfp->rwork3, ldfp->rwork2, ldfp->rwork2));
1316     if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(ldfp->StY_triu_strict));
1317 
1318     if (!ldfp->rwork2_local) PetscCall(VecCreateLocalVector(ldfp->rwork2, &ldfp->rwork2_local));
1319     if (!ldfp->rwork3_local) PetscCall(VecCreateLocalVector(ldfp->rwork3, &ldfp->rwork3_local));
1320     PetscCall(VecGetLocalVectorRead(ldfp->rwork2, ldfp->rwork2_local));
1321     PetscCall(VecGetLocalVector(ldfp->rwork3, ldfp->rwork3_local));
1322     PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1323     PetscCall(VecGetSize(ldfp->rwork2_local, &m_local));
1324     if (m_local) {
1325       Mat J_local;
1326 
1327       PetscCall(MatDenseGetLocalMatrix(ldfp->J, &J_local));
1328       PetscCall(MatSolve(J_local, ldfp->rwork2_local, ldfp->rwork3_local));
1329     }
1330     PetscCall(VecRestoreLocalVector(ldfp->rwork3, ldfp->rwork3_local));
1331     PetscCall(VecRestoreLocalVectorRead(ldfp->rwork2, ldfp->rwork2_local));
1332     PetscCall(VecScale(ldfp->rwork3, -1.0));
1333 
1334     PetscCall(MatMultAdd(ldfp->StY_triu_strict, ldfp->rwork3, ldfp->rwork1, ldfp->rwork1));
1335     PetscCall(VecPointwiseMult(ldfp->rwork1, ldfp->rwork1, ldfp->inv_diag_vec));
1336 
1337     if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) {
1338       PetscCall(VecHistoryOrderToRecycleOrder(H, ldfp->rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1339       PetscCall(VecHistoryOrderToRecycleOrder(H, ldfp->rwork3, ldfp->num_updates, ldfp->cyclic_work_vec));
1340     }
1341 
1342     PetscCall(MatMultAddColumnRange(Sfull, ldfp->rwork1, dX, dX, 0, h));
1343     ldfp->S_count++;
1344     PetscCall(MatMultAddColumnRange(ldfp->HY, ldfp->rwork3, dX, dX, 0, h));
1345     ldfp->Y_count++;
1346   }
1347   PetscFunctionReturn(PETSC_SUCCESS);
1348 }
1349 
/* Applies (computes the product with) the dense L-DFP approximation
   (Theorem 1, Erway, Jain, and Marcia, 2013)

   B_0 - [ Y | B_0 S] [ -R^{-T} (D + S^T B_0 S) R^{-1} | R^{-T} ] [   Y^T   ]
                      [--------------------------------+--------] [---------]
                      [             R^{-1}             |   0    ] [ S^T B_0 ]

   (Note: R above is the upper triangular part of Y^T S.)
   This factors as

   [ I | -Y L^{-T} ] [  I  | 0 ] [ B_0 | 0 ] [ I | S ] [      I      ]
                     [-----+---] [-----+---] [---+---] [-------------]
                     [ S^T | I ] [  0  | D ] [ 0 | I ] [ -L^{-1} Y^T ]

   (Note: L above is the same upper triangular part of Y^T S, i.e. L = R;
   both triangular solves below use YtS_triu.)

*/
MatMult_LMVMDDFP(Mat B,Vec X,Vec Z)1367 static PetscErrorCode MatMult_LMVMDDFP(Mat B, Vec X, Vec Z)
1368 {
1369   Mat_LMVM        *lmvm   = (Mat_LMVM *)B->data;
1370   Mat_DQN         *ldfp   = (Mat_DQN *)lmvm->ctx;
1371   Vec              rwork1 = ldfp->rwork1;
1372   PetscInt         m      = lmvm->m;
1373   PetscInt         k      = lmvm->k;
1374   PetscInt         h      = k - oldest_update(m, k);
1375   Mat              Sfull  = lmvm->basis[LMBASIS_S]->vecs;
1376   Mat              Yfull  = lmvm->basis[LMBASIS_Y]->vecs;
1377   PetscObjectState Xstate;
1378 
1379   PetscFunctionBegin;
1380   VecCheckSameSize(X, 2, Z, 3);
1381   VecCheckMatCompatible(B, X, 2, Z, 3);
1382 
1383   /* DFP Version. Erway, Jain, Marcia, 2013, Theorem 1 */
1384   /* Block Version */
1385   if (!ldfp->num_updates) {
1386     PetscCall(MatDQNApplyJ0Fwd(B, X, Z));
1387     PetscFunctionReturn(PETSC_SUCCESS); /* No updates stored yet */
1388   }
1389 
1390   PetscCall(PetscObjectStateGet((PetscObject)X, &Xstate));
1391   PetscCall(MatMultHermitianTransposeColumnRange(Yfull, X, rwork1, 0, h));
1392 
1393   /* Reordering rwork1, as STY is in history order, while Y is in recycled order */
1394   if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1395   PetscCall(MatUpperTriangularSolveInPlace(B, ldfp->YtS_triu, rwork1, PETSC_FALSE, ldfp->num_updates, ldfp->strategy));
1396   PetscCall(VecScale(rwork1, -1.0));
1397   if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1398 
1399   PetscCall(VecCopy(X, ldfp->column_work));
1400   PetscCall(MatMultAddColumnRange(Sfull, rwork1, ldfp->column_work, ldfp->column_work, 0, h));
1401   ldfp->S_count++;
1402 
1403   PetscCall(VecPointwiseMult(rwork1, ldfp->diag_vec_recycle_order, rwork1));
1404   PetscCall(MatDQNApplyJ0Fwd(B, ldfp->column_work, Z));
1405 
1406   PetscCall(MatMultHermitianTransposeAddColumnRange(Sfull, Z, rwork1, rwork1, 0, h));
1407   ldfp->St_count++;
1408 
1409   if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecRecycleOrderToHistoryOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1410   PetscCall(MatUpperTriangularSolveInPlace(B, ldfp->YtS_triu, rwork1, PETSC_TRUE, ldfp->num_updates, ldfp->strategy));
1411   PetscCall(VecScale(rwork1, -1.0));
1412   if (ldfp->strategy == MAT_LMVM_DENSE_REORDER) PetscCall(VecHistoryOrderToRecycleOrder(B, rwork1, ldfp->num_updates, ldfp->cyclic_work_vec));
1413 
1414   PetscCall(MatMultAddColumnRange(Yfull, rwork1, Z, Z, 0, h));
1415   ldfp->Y_count++;
1416   PetscFunctionReturn(PETSC_SUCCESS);
1417 }
1418 
1419 /*
1420    This dense representation reduces the L-DFP update to a series of
1421    matrix-vector products with dense matrices in lieu of the conventional
1422    matrix-free two-loop algorithm.
1423 */
MatCreate_LMVMDDFP(Mat B)1424 PetscErrorCode MatCreate_LMVMDDFP(Mat B)
1425 {
1426   Mat_LMVM *lmvm;
1427   Mat_DQN  *ldfp;
1428 
1429   PetscFunctionBegin;
1430   PetscCall(MatCreate_LMVM(B));
1431   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMDDFP));
1432   PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
1433   PetscCall(MatSetOption(B, MAT_SPD, PETSC_TRUE));
1434   PetscCall(MatSetOption(B, MAT_SPD_ETERNAL, PETSC_TRUE));
1435   B->ops->view           = MatView_LMVMDQN;
1436   B->ops->setup          = MatSetUp_LMVMDQN;
1437   B->ops->setfromoptions = MatSetFromOptions_LMVMDQN;
1438   B->ops->destroy        = MatDestroy_LMVMDQN;
1439 
1440   lmvm              = (Mat_LMVM *)B->data;
1441   lmvm->ops->reset  = MatReset_LMVMDQN;
1442   lmvm->ops->update = MatUpdate_LMVMDQN;
1443   lmvm->ops->mult   = MatMult_LMVMDDFP;
1444   lmvm->ops->solve  = MatSolve_LMVMDDFP;
1445   lmvm->ops->copy   = MatCopy_LMVMDQN;
1446 
1447   lmvm->ops->multht  = lmvm->ops->mult;
1448   lmvm->ops->solveht = lmvm->ops->solve;
1449 
1450   PetscCall(PetscNew(&ldfp));
1451   lmvm->ctx             = (void *)ldfp;
1452   ldfp->allocated       = PETSC_FALSE;
1453   ldfp->watchdog        = 0;
1454   ldfp->max_seq_rejects = lmvm->m / 2;
1455   ldfp->strategy        = MAT_LMVM_DENSE_INPLACE;
1456   ldfp->use_recursive   = PETSC_TRUE;
1457   ldfp->needPQ          = PETSC_TRUE;
1458 
1459   PetscCall(SymBroydenRescaleCreate(&ldfp->rescale));
1460   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetDelta_C", MatLMVMSymBroydenSetDelta_LMVMDQN));
1461   PetscFunctionReturn(PETSC_SUCCESS);
1462 }
1463 
1464 /*@
1465   MatCreateLMVMDDFP - Creates a dense representation of the limited-memory
1466   Davidon-Fletcher-Powell (DFP) approximation to a Hessian.
1467 
1468   Collective
1469 
1470   Input Parameters:
1471 + comm - MPI communicator
1472 . n    - number of local rows for storage vectors
1473 - N    - global size of the storage vectors
1474 
1475   Output Parameter:
1476 . B - the matrix
1477 
1478   Level: advanced
1479 
1480   Note:
  It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
  paradigm instead of calling this routine directly.
1483 
1484 .seealso: `MatCreate()`, `MATLMVM`, `MATLMVMDDFP`, `MatCreateLMVMDFP()`
1485 @*/
MatCreateLMVMDDFP(MPI_Comm comm,PetscInt n,PetscInt N,Mat * B)1486 PetscErrorCode MatCreateLMVMDDFP(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
1487 {
1488   PetscFunctionBegin;
1489   PetscCall(KSPInitializePackage());
1490   PetscCall(MatCreate(comm, B));
1491   PetscCall(MatSetSizes(*B, n, n, N, N));
1492   PetscCall(MatSetType(*B, MATLMVMDDFP));
1493   PetscCall(MatSetUp(*B));
1494   PetscFunctionReturn(PETSC_SUCCESS);
1495 }
1496 
1497 /*@
1498   MatLMVMDenseSetType - Sets the memory storage type for dense `MATLMVM`
1499 
1500   Input Parameters:
1501 + B    - the `MATLMVM` matrix
- type - the memory storage strategy, see `MatLMVMDenseType`
1503 
1504   Options Database Keys:
1505 + -mat_lqn_type   <reorder,inplace> - set the strategy
1506 . -mat_lbfgs_type <reorder,inplace> - set the strategy
1507 - -mat_ldfp_type  <reorder,inplace> - set the strategy
1508 
1509   Level: intermediate
1510 
1511   MatLMVMDenseTypes\:
+   `MAT_LMVM_DENSE_REORDER` - reorders memory to minimize kernel launches
-   `MAT_LMVM_DENSE_INPLACE` - launches kernels in place to minimize memory movement
1514 
1515 .seealso: [](ch_ksp), `MATLMVMDQN`, `MATLMVMDBFGS`, `MATLMVMDDFP`, `MatLMVMDenseType`
1516 @*/
MatLMVMDenseSetType(Mat B,MatLMVMDenseType type)1517 PetscErrorCode MatLMVMDenseSetType(Mat B, MatLMVMDenseType type)
1518 {
1519   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
1520   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
1521 
1522   PetscFunctionBegin;
1523   PetscValidHeaderSpecific(B, MAT_CLASSID, 1);
1524   lqn->strategy = type;
1525   PetscFunctionReturn(PETSC_SUCCESS);
1526 }
1527