xref: /petsc/src/ksp/ksp/utils/lmvm/lmvmimpl.c (revision 834855d6effb0d027771461c8e947ee1ce5a1e17)
1 #include <petscdevice.h>
2 #include <../src/ksp/ksp/utils/lmvm/lmvm.h> /*I "petscksp.h" I*/
3 #include <petsc/private/deviceimpl.h>
4 #include "blas_cyclic/blas_cyclic.h"
5 #include "rescale/symbrdnrescale.h"
6 
7 PetscLogEvent MATLMVM_Update;
8 
9 static PetscBool MatLMVMPackageInitialized = PETSC_FALSE;
10 
MatLMVMPackageInitialize(void)11 static PetscErrorCode MatLMVMPackageInitialize(void)
12 {
13   PetscFunctionBegin;
14   if (MatLMVMPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
15   MatLMVMPackageInitialized = PETSC_TRUE;
16   PetscCall(PetscLogEventRegister("AXPBYCyclic", MAT_CLASSID, &AXPBY_Cyc));
17   PetscCall(PetscLogEventRegister("DMVCyclic", MAT_CLASSID, &DMV_Cyc));
18   PetscCall(PetscLogEventRegister("DSVCyclic", MAT_CLASSID, &DSV_Cyc));
19   PetscCall(PetscLogEventRegister("TRSVCyclic", MAT_CLASSID, &TRSV_Cyc));
20   PetscCall(PetscLogEventRegister("GEMVCyclic", MAT_CLASSID, &GEMV_Cyc));
21   PetscCall(PetscLogEventRegister("HEMVCyclic", MAT_CLASSID, &HEMV_Cyc));
22   PetscCall(PetscLogEventRegister("LMBasisGEMM", MAT_CLASSID, &LMBASIS_GEMM));
23   PetscCall(PetscLogEventRegister("LMBasisGEMV", MAT_CLASSID, &LMBASIS_GEMV));
24   PetscCall(PetscLogEventRegister("LMBasisGEMVH", MAT_CLASSID, &LMBASIS_GEMVH));
25   PetscCall(PetscLogEventRegister("LMProdsMult", MAT_CLASSID, &LMPROD_Mult));
26   PetscCall(PetscLogEventRegister("LMProdsSolve", MAT_CLASSID, &LMPROD_Solve));
27   PetscCall(PetscLogEventRegister("LMProdsUpdate", MAT_CLASSID, &LMPROD_Update));
28   PetscCall(PetscLogEventRegister("MatLMVMUpdate", MAT_CLASSID, &MATLMVM_Update));
29   PetscCall(PetscLogEventRegister("SymBrdnRescale", MAT_CLASSID, &SBRDN_Rescale));
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
/*
  String table for the MatLMVMMultAlgorithm enum, laid out the way PetscOptionsEnum() expects:
  the value names in enum order, then the enum type name, then the common prefix of the enum
  constants, then a terminating NULL.

  The type name and prefix entries must match the enum itself (MatLMVMMultAlgorithm,
  MAT_LMVM_MULT_RECURSIVE, MAT_LMVM_MULT_DENSE, MAT_LMVM_MULT_COMPACT_DENSE).
*/
const char *const MatLMVMMultAlgorithms[] = {
  "recursive", "dense", "compact_dense", "MatLMVMMultAlgorithm", "MAT_LMVM_MULT_", NULL,
};
36 
37 PetscBool  ByrdNocedalSchnabelCite       = PETSC_FALSE;
38 const char ByrdNocedalSchnabelCitation[] = "@article{Byrd1994,"
39                                            "  title = {Representations of quasi-Newton matrices and their use in limited memory methods},"
40                                            "  volume = {63},"
41                                            "  ISSN = {1436-4646},"
42                                            "  url = {http://dx.doi.org/10.1007/BF01582063},"
43                                            "  DOI = {10.1007/bf01582063},"
44                                            "  number = {1-3},"
45                                            "  journal = {Mathematical Programming},"
46                                            "  publisher = {Springer Science and Business Media LLC},"
47                                            "  author = {Byrd,  Richard H. and Nocedal,  Jorge and Schnabel,  Robert B.},"
48                                            "  year = {1994},"
49                                            "  month = jan,"
50                                            "  pages = {129-156}"
51                                            "}\n";
52 
MatReset_LMVM(Mat B,MatLMVMResetMode mode)53 PETSC_INTERN PetscErrorCode MatReset_LMVM(Mat B, MatLMVMResetMode mode)
54 {
55   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
56 
57   PetscFunctionBegin;
58   lmvm->k        = 0;
59   lmvm->prev_set = PETSC_FALSE;
60   lmvm->shift    = 0.0;
61   if (MatLMVMResetClearsBases(mode)) {
62     for (PetscInt i = 0; i < LMBASIS_END; i++) PetscCall(LMBasisDestroy(&lmvm->basis[i]));
63     for (PetscInt k = 0; k < LMBLOCK_END; k++) {
64       for (PetscInt i = 0; i < LMBASIS_END; i++) {
65         for (PetscInt j = 0; j < LMBASIS_END; j++) PetscCall(LMProductsDestroy(&lmvm->products[k][i][j]));
66       }
67     }
68     B->preallocated = PETSC_FALSE; // MatSetUp() needs to be run to create at least the S and Y bases
69   } else {
70     for (PetscInt i = 0; i < LMBASIS_END; i++) PetscCall(LMBasisReset(lmvm->basis[i]));
71     for (PetscInt k = 0; k < LMBLOCK_END; k++) {
72       for (PetscInt i = 0; i < LMBASIS_END; i++) {
73         for (PetscInt j = 0; j < LMBASIS_END; j++) PetscCall(LMProductsReset(lmvm->products[k][i][j]));
74       }
75     }
76   }
77   if (MatLMVMResetClearsJ0(mode)) PetscCall(MatLMVMClearJ0(B));
78   if (MatLMVMResetClearsVecs(mode)) {
79     PetscCall(VecDestroy(&lmvm->Xprev));
80     PetscCall(VecDestroy(&lmvm->Fprev));
81     B->preallocated = PETSC_FALSE; // MatSetUp() needs to be run to create these vecs
82   }
83   if (MatLMVMResetClearsAll(mode)) {
84     lmvm->nupdates = 0;
85     lmvm->nrejects = 0;
86   }
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
MatLMVMAllocateBases(Mat B)90 PETSC_INTERN PetscErrorCode MatLMVMAllocateBases(Mat B)
91 {
92   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
93 
94   PetscFunctionBegin;
95   PetscCheck(lmvm->Xprev != NULL && lmvm->Fprev != NULL, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Must allocate Xprev and Fprev before allocating bases");
96   if (!lmvm->basis[LMBASIS_S]) PetscCall(LMBasisCreate(lmvm->Xprev, lmvm->m, &lmvm->basis[LMBASIS_S]));
97   if (!lmvm->basis[LMBASIS_Y]) PetscCall(LMBasisCreate(lmvm->Fprev, lmvm->m, &lmvm->basis[LMBASIS_Y]));
98   PetscFunctionReturn(PETSC_SUCCESS);
99 }
100 
MatLMVMAllocateVecs(Mat B)101 PETSC_INTERN PetscErrorCode MatLMVMAllocateVecs(Mat B)
102 {
103   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
104 
105   PetscFunctionBegin;
106   if (!lmvm->Xprev) PetscCall(MatCreateVecs(B, &lmvm->Xprev, NULL));
107   if (!lmvm->Fprev) PetscCall(MatCreateVecs(B, NULL, &lmvm->Fprev));
108   PetscFunctionReturn(PETSC_SUCCESS);
109 }
110 
MatAllocate_LMVM(Mat B,Vec X,Vec F)111 PETSC_INTERN PetscErrorCode MatAllocate_LMVM(Mat B, Vec X, Vec F)
112 {
113   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
114   PetscBool same;
115   VecType   vtype, Bvtype;
116 
117   PetscFunctionBegin;
118   PetscCall(MatLMVMUseVecLayoutsIfCompatible(B, X, F));
119   PetscCall(VecGetType(X, &vtype));
120   PetscCall(MatGetVecType(B, &Bvtype));
121   PetscCall(PetscStrcmp(vtype, Bvtype, &same));
122   if (!same) {
123     /* Given X vector has a different type than allocated X-type data structures.
124        We need to destroy all of this and duplicate again out of the given vector. */
125     PetscCall(MatLMVMReset_Internal(B, MAT_LMVM_RESET_BASES | MAT_LMVM_RESET_VECS));
126     PetscCall(MatSetVecType(B, vtype));
127     if (lmvm->created_J0) PetscCall(MatSetVecType(lmvm->J0, vtype));
128   }
129   PetscCall(MatLMVMAllocateVecs(B));
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
MatUpdateKernel_LMVM(Mat B,Vec S,Vec Y)133 PetscErrorCode MatUpdateKernel_LMVM(Mat B, Vec S, Vec Y)
134 {
135   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
136   Vec       s_k, y_k;
137 
138   PetscFunctionBegin;
139   PetscCall(LMBasisGetNextVec(lmvm->basis[LMBASIS_S], &s_k));
140   PetscCall(VecCopy(S, s_k));
141   PetscCall(LMBasisRestoreNextVec(lmvm->basis[LMBASIS_S], &s_k));
142 
143   PetscCall(LMBasisGetNextVec(lmvm->basis[LMBASIS_Y], &y_k));
144   PetscCall(VecCopy(Y, y_k));
145   PetscCall(LMBasisRestoreNextVec(lmvm->basis[LMBASIS_Y], &y_k));
146   lmvm->nupdates++;
147   lmvm->k++;
148   PetscAssert(lmvm->k == lmvm->basis[LMBASIS_S]->k, PetscObjectComm((PetscObject)B), PETSC_ERR_PLIB, "Basis S and Mat B out of sync");
149   PetscAssert(lmvm->k == lmvm->basis[LMBASIS_Y]->k, PetscObjectComm((PetscObject)B), PETSC_ERR_PLIB, "Basis Y and Mat B out of sync");
150   PetscFunctionReturn(PETSC_SUCCESS);
151 }
152 
MatUpdate_LMVM(Mat B,Vec X,Vec F)153 PetscErrorCode MatUpdate_LMVM(Mat B, Vec X, Vec F)
154 {
155   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
156 
157   PetscFunctionBegin;
158   if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
159   if (lmvm->prev_set) {
160     /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
161     PetscCall(VecAXPBY(lmvm->Xprev, 1.0, -1.0, X));
162     PetscCall(VecAXPBY(lmvm->Fprev, 1.0, -1.0, F));
163     /* Update S and Y */
164     PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
165   }
166 
167   /* Save the solution and function to be used in the next update */
168   PetscCall(VecCopy(X, lmvm->Xprev));
169   PetscCall(VecCopy(F, lmvm->Fprev));
170   lmvm->prev_set = PETSC_TRUE;
171   PetscFunctionReturn(PETSC_SUCCESS);
172 }
173 
MatMultAdd_LMVM(Mat B,Vec X,Vec Y,Vec Z)174 static PetscErrorCode MatMultAdd_LMVM(Mat B, Vec X, Vec Y, Vec Z)
175 {
176   PetscFunctionBegin;
177   PetscCall(MatMult(B, X, Z));
178   PetscCall(VecAXPY(Z, 1.0, Y));
179   PetscFunctionReturn(PETSC_SUCCESS);
180 }
181 
MatMult_LMVM(Mat B,Vec X,Vec Y)182 static PetscErrorCode MatMult_LMVM(Mat B, Vec X, Vec Y)
183 {
184   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
185 
186   PetscFunctionBegin;
187   PetscCall((*lmvm->ops->mult)(B, X, Y));
188   if (lmvm->shift != 0.0) PetscCall(VecAXPY(Y, lmvm->shift, X));
189   PetscFunctionReturn(PETSC_SUCCESS);
190 }
191 
MatMultHermitianTranspose_LMVM(Mat B,Vec X,Vec Y)192 static PetscErrorCode MatMultHermitianTranspose_LMVM(Mat B, Vec X, Vec Y)
193 {
194   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
195 
196   PetscFunctionBegin;
197   PetscCall((*lmvm->ops->multht)(B, X, Y));
198   if (lmvm->shift != 0.0) PetscCall(VecAXPY(Y, PetscConj(lmvm->shift), X));
199   PetscFunctionReturn(PETSC_SUCCESS);
200 }
201 
MatSolve_LMVM(Mat B,Vec x,Vec y)202 static PetscErrorCode MatSolve_LMVM(Mat B, Vec x, Vec y)
203 {
204   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
205 
206   PetscFunctionBegin;
207   PetscCheck(lmvm->shift == 0.0, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Cannot solve a MatLMVM when it has a nonzero shift");
208   PetscCall((*lmvm->ops->solve)(B, x, y));
209   PetscFunctionReturn(PETSC_SUCCESS);
210 }
211 
MatSolveHermitianTranspose_LMVM(Mat B,Vec x,Vec y)212 static PetscErrorCode MatSolveHermitianTranspose_LMVM(Mat B, Vec x, Vec y)
213 {
214   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
215 
216   PetscFunctionBegin;
217   PetscCheck(lmvm->shift == 0.0, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Cannot solve a MatLMVM when it has a nonzero shift");
218   PetscCall((*lmvm->ops->solveht)(B, x, y));
219   PetscFunctionReturn(PETSC_SUCCESS);
220 }
221 
MatSolveTranspose_LMVM(Mat B,Vec x,Vec y)222 static PetscErrorCode MatSolveTranspose_LMVM(Mat B, Vec x, Vec y)
223 {
224   PetscFunctionBegin;
225   if (!PetscDefined(USE_COMPLEX)) {
226     PetscCall(MatSolveHermitianTranspose_LMVM(B, x, y));
227   } else {
228     Vec x_conj;
229     PetscCall(VecDuplicate(x, &x_conj));
230     PetscCall(VecCopy(x, x_conj));
231     PetscCall(VecConjugate(x_conj));
232     PetscCall(MatSolveHermitianTranspose_LMVM(B, x_conj, y));
233     PetscCall(VecDestroy(&x_conj));
234     PetscCall(VecConjugate(y));
235   }
236   PetscFunctionReturn(PETSC_SUCCESS);
237 }
238 
239 // MatCopy() calls MatCheckPreallocated(), so B will have Xprev, Fprev, LMBASIS_S, and LMBASIS_Y
MatCopy_LMVM(Mat B,Mat M,MatStructure str)240 static PetscErrorCode MatCopy_LMVM(Mat B, Mat M, MatStructure str)
241 {
242   Mat_LMVM *bctx = (Mat_LMVM *)B->data;
243   Mat_LMVM *mctx;
244   Mat       J0_copy;
245 
246   PetscFunctionBegin;
247   if (str == DIFFERENT_NONZERO_PATTERN) {
248     PetscCall(MatLMVMReset(M, PETSC_TRUE));
249     PetscCall(MatLMVMAllocate(M, bctx->Xprev, bctx->Fprev));
250   } else MatCheckSameSize(B, 1, M, 2);
251 
252   mctx = (Mat_LMVM *)M->data;
253   PetscCall(MatDuplicate(bctx->J0, MAT_COPY_VALUES, &J0_copy));
254   PetscCall(MatLMVMSetJ0(M, J0_copy));
255   PetscCall(MatDestroy(&J0_copy));
256   mctx->nupdates = bctx->nupdates;
257   mctx->nrejects = bctx->nrejects;
258   mctx->k        = bctx->k;
259   PetscCall(MatLMVMAllocateVecs(M));
260   PetscCall(VecCopy(bctx->Xprev, mctx->Xprev));
261   PetscCall(VecCopy(bctx->Fprev, mctx->Fprev));
262   PetscCall(MatLMVMAllocateBases(M));
263   PetscCall(LMBasisCopy(bctx->basis[LMBASIS_S], mctx->basis[LMBASIS_S]));
264   PetscCall(LMBasisCopy(bctx->basis[LMBASIS_Y], mctx->basis[LMBASIS_Y]));
265   mctx->do_not_cache_J0_products = bctx->do_not_cache_J0_products;
266   mctx->cache_gradient_products  = bctx->cache_gradient_products;
267   mctx->mult_alg                 = bctx->mult_alg;
268   if (mctx->ops->setmultalgorithm) PetscCall((*mctx->ops->setmultalgorithm)(M));
269   if (bctx->ops->copy) PetscCall((*bctx->ops->copy)(B, M, str));
270   PetscFunctionReturn(PETSC_SUCCESS);
271 }
272 
MatDuplicate_LMVM(Mat B,MatDuplicateOption op,Mat * mat)273 static PetscErrorCode MatDuplicate_LMVM(Mat B, MatDuplicateOption op, Mat *mat)
274 {
275   Mat_LMVM *bctx = (Mat_LMVM *)B->data;
276   Mat_LMVM *mctx;
277   MatType   lmvmType;
278   Mat       A;
279 
280   PetscFunctionBegin;
281   PetscCall(MatGetType(B, &lmvmType));
282   PetscCall(MatCreate(PetscObjectComm((PetscObject)B), mat));
283   PetscCall(MatSetType(*mat, lmvmType));
284 
285   A       = *mat;
286   mctx    = (Mat_LMVM *)A->data;
287   mctx->m = bctx->m;
288   if (bctx->J0ksp) {
289     PetscReal rtol, atol, dtol;
290     PetscInt  max_it;
291 
292     PetscCall(KSPGetTolerances(bctx->J0ksp, &rtol, &atol, &dtol, &max_it));
293     PetscCall(KSPSetTolerances(mctx->J0ksp, rtol, atol, dtol, max_it));
294   }
295   mctx->shift = bctx->shift;
296 
297   PetscCall(MatLMVMAllocate(*mat, bctx->Xprev, bctx->Fprev));
298   if (op == MAT_COPY_VALUES) PetscCall(MatCopy(B, *mat, SAME_NONZERO_PATTERN));
299   PetscFunctionReturn(PETSC_SUCCESS);
300 }
301 
MatShift_LMVM(Mat B,PetscScalar a)302 static PetscErrorCode MatShift_LMVM(Mat B, PetscScalar a)
303 {
304   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
305 
306   PetscFunctionBegin;
307   lmvm->shift += PetscRealPart(a);
308   PetscFunctionReturn(PETSC_SUCCESS);
309 }
310 
MatView_LMVM(Mat B,PetscViewer pv)311 PetscErrorCode MatView_LMVM(Mat B, PetscViewer pv)
312 {
313   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
314   PetscBool isascii;
315   MatType   type;
316 
317   PetscFunctionBegin;
318   PetscCall(PetscObjectTypeCompare((PetscObject)pv, PETSCVIEWERASCII, &isascii));
319   if (isascii) {
320     PetscBool         is_exact;
321     PetscViewerFormat format;
322 
323     PetscCall(MatGetType(B, &type));
324     PetscCall(PetscViewerASCIIPrintf(pv, "Max. storage: %" PetscInt_FMT "\n", lmvm->m));
325     PetscCall(PetscViewerASCIIPrintf(pv, "Used storage: %" PetscInt_FMT "\n", PetscMin(lmvm->k, lmvm->m)));
326     PetscCall(PetscViewerASCIIPrintf(pv, "Number of updates: %" PetscInt_FMT "\n", lmvm->nupdates));
327     PetscCall(PetscViewerASCIIPrintf(pv, "Number of rejects: %" PetscInt_FMT "\n", lmvm->nrejects));
328     PetscCall(PetscViewerASCIIPrintf(pv, "Number of resets: %" PetscInt_FMT "\n", lmvm->nresets));
329     PetscCall(PetscViewerGetFormat(pv, &format));
330     if (format == PETSC_VIEWER_ASCII_INFO_DETAIL) {
331       PetscCall(PetscViewerASCIIPrintf(pv, "Mult algorithm: %s\n", MatLMVMMultAlgorithms[lmvm->mult_alg]));
332       PetscCall(PetscViewerASCIIPrintf(pv, "Cache J0 products: %s\n", lmvm->do_not_cache_J0_products ? "false" : "true"));
333       PetscCall(PetscViewerASCIIPrintf(pv, "Cache gradient products: %s\n", lmvm->cache_gradient_products ? "true" : "false"));
334     }
335     PetscCall(MatLMVMJ0KSPIsExact(B, &is_exact));
336     if (is_exact) {
337       PetscBool is_scalar;
338 
339       PetscCall(PetscObjectTypeCompare((PetscObject)lmvm->J0, MATCONSTANTDIAGONAL, &is_scalar));
340       PetscCall(PetscViewerASCIIPrintf(pv, "J0:\n"));
341       PetscCall(PetscViewerASCIIPushTab(pv));
342       PetscCall(PetscViewerPushFormat(pv, is_scalar ? PETSC_VIEWER_DEFAULT : PETSC_VIEWER_ASCII_INFO));
343       PetscCall(MatView(lmvm->J0, pv));
344       PetscCall(PetscViewerPopFormat(pv));
345       PetscCall(PetscViewerASCIIPopTab(pv));
346     } else {
347       PetscCall(PetscViewerASCIIPrintf(pv, "J0 KSP:\n"));
348       PetscCall(PetscViewerASCIIPushTab(pv));
349       PetscCall(PetscViewerPushFormat(pv, PETSC_VIEWER_ASCII_INFO));
350       PetscCall(KSPView(lmvm->J0ksp, pv));
351       PetscCall(PetscViewerPopFormat(pv));
352       PetscCall(PetscViewerASCIIPopTab(pv));
353     }
354   }
355   PetscFunctionReturn(PETSC_SUCCESS);
356 }
357 
MatSetFromOptions_LMVM(Mat B,PetscOptionItems PetscOptionsObject)358 PetscErrorCode MatSetFromOptions_LMVM(Mat B, PetscOptionItems PetscOptionsObject)
359 {
360   Mat_LMVM            *lmvm     = (Mat_LMVM *)B->data;
361   PetscBool            cache_J0 = lmvm->do_not_cache_J0_products ? PETSC_FALSE : PETSC_TRUE; // Default is false, but flipping double negative so that the command line option make sense
362   PetscBool            set;
363   PetscInt             hist_size = lmvm->m;
364   MatLMVMMultAlgorithm mult_alg;
365 
366   PetscFunctionBegin;
367   PetscCall(MatLMVMGetMultAlgorithm(B, &mult_alg));
368   PetscOptionsHeadBegin(PetscOptionsObject, "Limited-memory Variable Metric matrix for approximating Jacobians");
369   PetscCall(PetscOptionsInt("-mat_lmvm_hist_size", "number of past updates kept in memory for the approximation", "", hist_size, &hist_size, NULL));
370   PetscCall(PetscOptionsEnum("-mat_lmvm_mult_algorithm", "Algorithm used to matrix-vector products", "", MatLMVMMultAlgorithms, (PetscEnum)mult_alg, (PetscEnum *)&mult_alg, &set));
371   PetscCall(PetscOptionsReal("-mat_lmvm_eps", "(developer) machine zero definition", "", lmvm->eps, &lmvm->eps, NULL));
372   PetscCall(PetscOptionsBool("-mat_lmvm_cache_J0_products", "Cache applications of the kernel J0 or its inverse", "", cache_J0, &cache_J0, NULL));
373   PetscCall(PetscOptionsBool("-mat_lmvm_cache_gradient_products", "Cache data used to apply the inverse Hessian to a gradient vector to accelerate the quasi-Newton update", "", lmvm->cache_gradient_products, &lmvm->cache_gradient_products, NULL));
374   PetscCall(PetscOptionsBool("-mat_lmvm_debug", "(developer) Perform internal debugging checks", "", lmvm->debug, &lmvm->debug, NULL));
375   PetscOptionsHeadEnd();
376   lmvm->do_not_cache_J0_products = cache_J0 ? PETSC_FALSE : PETSC_TRUE;
377   if (hist_size != lmvm->m) PetscCall(MatLMVMSetHistorySize(B, hist_size));
378   if (set) PetscCall(MatLMVMSetMultAlgorithm(B, mult_alg));
379   if (lmvm->created_J0) PetscCall(MatSetFromOptions(lmvm->J0));
380   if (lmvm->created_J0ksp) PetscCall(KSPSetFromOptions(lmvm->J0ksp));
381   PetscFunctionReturn(PETSC_SUCCESS);
382 }
383 
MatSetUp_LMVM(Mat B)384 PetscErrorCode MatSetUp_LMVM(Mat B)
385 {
386   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
387 
388   PetscFunctionBegin;
389   PetscCall(PetscLayoutSetUp(B->rmap));
390   PetscCall(PetscLayoutSetUp(B->cmap));
391   if (lmvm->created_J0) {
392     PetscCall(PetscLayoutReference(B->rmap, &lmvm->J0->rmap));
393     PetscCall(PetscLayoutReference(B->cmap, &lmvm->J0->cmap));
394     PetscCall(MatSetUp(lmvm->J0));
395   }
396   PetscCall(MatLMVMAllocateVecs(B));
397   PetscCall(MatLMVMAllocateBases(B));
398   PetscFunctionReturn(PETSC_SUCCESS);
399 }
400 
401 /*@
402   MatLMVMSetMultAlgorithm - Set the algorithm used by a `MatLMVM` for products
403 
404   Logically collective
405 
406   Input Parameters:
407 + B   - a `MatLMVM` matrix
408 - alg - one of the algorithm classes (`MAT_LMVM_MULT_RECURSIVE`, `MAT_LMVM_MULT_DENSE`, `MAT_LMVM_MULT_COMPACT_DENSE`)
409 
410   Level: advanced
411 
412 .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMMultAlgorithm`, `MatLMVMGetMultAlgorithm()`
413 @*/
MatLMVMSetMultAlgorithm(Mat B,MatLMVMMultAlgorithm alg)414 PetscErrorCode MatLMVMSetMultAlgorithm(Mat B, MatLMVMMultAlgorithm alg)
415 {
416   PetscFunctionBegin;
417   PetscValidHeaderSpecific(B, MAT_CLASSID, 1);
418   PetscTryMethod(B, "MatLMVMSetMultAlgorithm_C", (Mat, MatLMVMMultAlgorithm), (B, alg));
419   PetscFunctionReturn(PETSC_SUCCESS);
420 }
421 
MatLMVMSetMultAlgorithm_LMVM(Mat B,MatLMVMMultAlgorithm alg)422 static PetscErrorCode MatLMVMSetMultAlgorithm_LMVM(Mat B, MatLMVMMultAlgorithm alg)
423 {
424   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
425 
426   PetscFunctionBegin;
427   lmvm->mult_alg = alg;
428   if (lmvm->ops->setmultalgorithm) PetscCall((*lmvm->ops->setmultalgorithm)(B));
429   PetscFunctionReturn(PETSC_SUCCESS);
430 }
431 
432 /*@
433   MatLMVMGetMultAlgorithm - Get the algorithm used by a `MatLMVM` for products
434 
435   Not collective
436 
437   Input Parameter:
438 . B - a `MatLMVM` matrix
439 
440   Output Parameter:
441 . alg - one of the algorithm classes (`MAT_LMVM_MULT_RECURSIVE`, `MAT_LMVM_MULT_DENSE`, `MAT_LMVM_MULT_COMPACT_DENSE`)
442 
443   Level: advanced
444 
445 .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMMultAlgorithm`, `MatLMVMSetMultAlgorithm()`
446 @*/
MatLMVMGetMultAlgorithm(Mat B,MatLMVMMultAlgorithm * alg)447 PetscErrorCode MatLMVMGetMultAlgorithm(Mat B, MatLMVMMultAlgorithm *alg)
448 {
449   PetscFunctionBegin;
450   PetscValidHeaderSpecific(B, MAT_CLASSID, 1);
451   PetscAssertPointer(alg, 2);
452   PetscUseMethod(B, "MatLMVMGetMultAlgorithm_C", (Mat, MatLMVMMultAlgorithm *), (B, alg));
453   PetscFunctionReturn(PETSC_SUCCESS);
454 }
455 
MatLMVMGetMultAlgorithm_LMVM(Mat B,MatLMVMMultAlgorithm * alg)456 static PetscErrorCode MatLMVMGetMultAlgorithm_LMVM(Mat B, MatLMVMMultAlgorithm *alg)
457 {
458   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
459 
460   PetscFunctionBegin;
461   *alg = lmvm->mult_alg;
462   PetscFunctionReturn(PETSC_SUCCESS);
463 }
464 
MatDestroy_LMVM(Mat B)465 PetscErrorCode MatDestroy_LMVM(Mat B)
466 {
467   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
468 
469   PetscFunctionBegin;
470   PetscCall(MatReset_LMVM(B, MAT_LMVM_RESET_ALL));
471   PetscCall(KSPDestroy(&lmvm->J0ksp));
472   PetscCall(MatDestroy(&lmvm->J0));
473   PetscCall(PetscFree(B->data));
474   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetLastUpdate_C", NULL));
475   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSetMultAlgorithm_C", NULL));
476   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetMultAlgorithm_C", NULL));
477   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetOptionsPrefix_C", NULL));
478   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatAppendOptionsPrefix_C", NULL));
479   PetscFunctionReturn(PETSC_SUCCESS);
480 }
481 
482 /*@
483   MatLMVMGetLastUpdate - Get the last vectors passed to `MatLMVMUpdate()`
484 
485   Not collective
486 
487   Input Parameter:
488 . B - a `MatLMVM` matrix
489 
490   Output Parameters:
491 + x_prev - the last solution vector
492 - f_prev - the last function vector
493 
494   Level: intermediate
495 
496 .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMUpdate()`
497 @*/
MatLMVMGetLastUpdate(Mat B,Vec * x_prev,Vec * f_prev)498 PetscErrorCode MatLMVMGetLastUpdate(Mat B, Vec *x_prev, Vec *f_prev)
499 {
500   PetscFunctionBegin;
501   PetscValidHeaderSpecific(B, MAT_CLASSID, 1);
502   PetscTryMethod(B, "MatLMVMGetLastUpdate_C", (Mat, Vec *, Vec *), (B, x_prev, f_prev));
503   PetscFunctionReturn(PETSC_SUCCESS);
504 }
505 
MatLMVMGetLastUpdate_LMVM(Mat B,Vec * x_prev,Vec * f_prev)506 static PetscErrorCode MatLMVMGetLastUpdate_LMVM(Mat B, Vec *x_prev, Vec *f_prev)
507 {
508   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
509 
510   PetscFunctionBegin;
511   if (x_prev) *x_prev = (lmvm->prev_set) ? lmvm->Xprev : NULL;
512   if (f_prev) *f_prev = (lmvm->prev_set) ? lmvm->Fprev : NULL;
513   PetscFunctionReturn(PETSC_SUCCESS);
514 }
515 
516 /* in both MatSetOptionsPrefix() and MatAppendOptionsPrefix(), this is called after
517    the prefix of B has been changed, so we just query the prefix of B rather than
518    using the passed prefix */
MatSetOptionsPrefix_LMVM(Mat B,const char unused[])519 static PetscErrorCode MatSetOptionsPrefix_LMVM(Mat B, const char unused[])
520 {
521   Mat_LMVM   *lmvm = (Mat_LMVM *)B->data;
522   const char *prefix;
523 
524   PetscFunctionBegin;
525   PetscCall(MatGetOptionsPrefix(B, &prefix));
526   if (lmvm->created_J0) {
527     PetscCall(MatSetOptionsPrefix(lmvm->J0, prefix));
528     PetscCall(MatAppendOptionsPrefix(lmvm->J0, "mat_lmvm_J0_"));
529   }
530   if (lmvm->created_J0ksp) {
531     PetscCall(KSPSetOptionsPrefix(lmvm->J0ksp, prefix));
532     PetscCall(KSPAppendOptionsPrefix(lmvm->J0ksp, "mat_lmvm_J0_"));
533   }
534   PetscFunctionReturn(PETSC_SUCCESS);
535 }
536 
537 /*MC
538    MATLMVM - MATLMVM = "lmvm" - A matrix type used for Limited-Memory Variable Metric (LMVM) matrices.
539 
540    Level: intermediate
541 
542    Developer notes:
543    Improve this manual page as well as many others in the MATLMVM family.
544 
545 .seealso: [](sec_matlmvm), `Mat`
546 M*/
MatCreate_LMVM(Mat B)547 PetscErrorCode MatCreate_LMVM(Mat B)
548 {
549   Mat_LMVM *lmvm;
550 
551   PetscFunctionBegin;
552   PetscCall(MatLMVMPackageInitialize());
553   PetscCall(PetscNew(&lmvm));
554   B->data = (void *)lmvm;
555 
556   lmvm->m   = 5;
557   lmvm->eps = PetscPowReal(PETSC_MACHINE_EPSILON, 2.0 / 3.0);
558 
559   B->ops->destroy                = MatDestroy_LMVM;
560   B->ops->setfromoptions         = MatSetFromOptions_LMVM;
561   B->ops->view                   = MatView_LMVM;
562   B->ops->setup                  = MatSetUp_LMVM;
563   B->ops->shift                  = MatShift_LMVM;
564   B->ops->duplicate              = MatDuplicate_LMVM;
565   B->ops->mult                   = MatMult_LMVM;
566   B->ops->multhermitiantranspose = MatMultHermitianTranspose_LMVM;
567   B->ops->multadd                = MatMultAdd_LMVM;
568   B->ops->copy                   = MatCopy_LMVM;
569   B->ops->solve                  = MatSolve_LMVM;
570   B->ops->solvetranspose         = MatSolveTranspose_LMVM;
571   if (!PetscDefined(USE_COMPLEX)) B->ops->multtranspose = MatMultHermitianTranspose_LMVM;
572 
573   /*
574     There is no assembly phase, Mat_LMVM relies on B->preallocated to ensure that
575     necessary setup happens in MatSetUp(), which is called in MatCheckPreallocated()
576     in all major operations (MatLMVMUpdate(), MatMult(), MatSolve(), etc.)
577    */
578   B->assembled = PETSC_TRUE;
579 
580   lmvm->ops->update = MatUpdate_LMVM;
581 
582   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVM));
583   // J0 should be present at all times, calling ClearJ0() here initializes it to the identity
584   PetscCall(MatLMVMClearJ0(B));
585 
586   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetLastUpdate_C", MatLMVMGetLastUpdate_LMVM));
587   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSetMultAlgorithm_C", MatLMVMSetMultAlgorithm_LMVM));
588   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetMultAlgorithm_C", MatLMVMGetMultAlgorithm_LMVM));
589   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetOptionsPrefix_C", MatSetOptionsPrefix_LMVM));
590   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatAppendOptionsPrefix_C", MatSetOptionsPrefix_LMVM));
591   PetscFunctionReturn(PETSC_SUCCESS);
592 }
593