xref: /petsc/src/ksp/ksp/utils/lmvm/lmbasis.c (revision 8577b683712d1cca1e9b8fdaa9ae028364224dad)
1 #include <petsc/private/petscimpl.h>
2 #include "lmbasis.h"
3 #include "blas_cyclic/blas_cyclic.h"
4 
5 PetscLogEvent LMBASIS_GEMM, LMBASIS_GEMV, LMBASIS_GEMVH;
6 
LMBasisCreate(Vec v,PetscInt m,LMBasis * basis_p)7 PetscErrorCode LMBasisCreate(Vec v, PetscInt m, LMBasis *basis_p)
8 {
9   PetscInt    n, N;
10   PetscMPIInt rank;
11   Mat         backing;
12   VecType     type;
13   LMBasis     basis;
14 
15   PetscFunctionBegin;
16   PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
17   PetscValidLogicalCollectiveInt(v, m, 2);
18   PetscCheck(m >= 0, PetscObjectComm((PetscObject)v), PETSC_ERR_ARG_OUTOFRANGE, "Requested window size %" PetscInt_FMT " is not >= 0", m);
19   PetscCall(VecGetLocalSize(v, &n));
20   PetscCall(VecGetSize(v, &N));
21   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)v), &rank));
22   PetscCall(VecGetType(v, &type));
23   PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)v), type, n, rank == 0 ? m : 0, N, m, n, NULL, &backing));
24   PetscCall(PetscNew(&basis));
25   *basis_p    = basis;
26   basis->m    = m;
27   basis->k    = 0;
28   basis->vecs = backing;
29   PetscFunctionReturn(PETSC_SUCCESS);
30 }
31 
LMBasisGetVec_Internal(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single,PetscBool check_idx)32 static PetscErrorCode LMBasisGetVec_Internal(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single, PetscBool check_idx)
33 {
34   PetscFunctionBegin;
35   PetscAssertPointer(basis, 1);
36   if (check_idx) {
37     PetscValidLogicalCollectiveInt(basis->vecs, idx, 2);
38     PetscCheck(idx < basis->k, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " >= number of inserted vecs %" PetscInt_FMT, idx, basis->k);
39     PetscInt earliest = PetscMax(0, basis->k - basis->m);
40     PetscCheck(idx >= earliest, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " < the earliest retained index % " PetscInt_FMT, idx, earliest);
41   }
42   PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
43   if (mode == PETSC_MEMORY_ACCESS_READ) {
44     PetscCall(MatDenseGetColumnVecRead(basis->vecs, idx % basis->m, single));
45   } else {
46     PetscCall(MatDenseGetColumnVecWrite(basis->vecs, idx % basis->m, single));
47   }
48   PetscFunctionReturn(PETSC_SUCCESS);
49 }
50 
LMBasisGetVec(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single)51 PETSC_INTERN PetscErrorCode LMBasisGetVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
52 {
53   PetscFunctionBegin;
54   PetscCall(LMBasisGetVec_Internal(basis, idx, mode, single, PETSC_TRUE));
55   PetscFunctionReturn(PETSC_SUCCESS);
56 }
57 
LMBasisRestoreVec(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single)58 PETSC_INTERN PetscErrorCode LMBasisRestoreVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
59 {
60   PetscFunctionBegin;
61   PetscAssertPointer(basis, 1);
62   PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
63   if (mode == PETSC_MEMORY_ACCESS_READ) {
64     PetscCall(MatDenseRestoreColumnVecRead(basis->vecs, idx % basis->m, single));
65   } else {
66     PetscCall(MatDenseRestoreColumnVecWrite(basis->vecs, idx % basis->m, single));
67   }
68   *single = NULL;
69   PetscFunctionReturn(PETSC_SUCCESS);
70 }
71 
LMBasisGetVecRead(LMBasis B,PetscInt i,Vec * b)72 PETSC_INTERN PetscErrorCode LMBasisGetVecRead(LMBasis B, PetscInt i, Vec *b)
73 {
74   return LMBasisGetVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
75 }
LMBasisRestoreVecRead(LMBasis B,PetscInt i,Vec * b)76 PETSC_INTERN PetscErrorCode LMBasisRestoreVecRead(LMBasis B, PetscInt i, Vec *b)
77 {
78   return LMBasisRestoreVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
79 }
80 
LMBasisGetNextVec(LMBasis basis,Vec * single)81 PETSC_INTERN PetscErrorCode LMBasisGetNextVec(LMBasis basis, Vec *single)
82 {
83   PetscFunctionBegin;
84   PetscCall(LMBasisGetVec_Internal(basis, basis->k, PETSC_MEMORY_ACCESS_WRITE, single, PETSC_FALSE));
85   PetscFunctionReturn(PETSC_SUCCESS);
86 }
87 
LMBasisRestoreNextVec(LMBasis basis,Vec * single)88 PETSC_INTERN PetscErrorCode LMBasisRestoreNextVec(LMBasis basis, Vec *single)
89 {
90   PetscFunctionBegin;
91   PetscAssertPointer(basis, 1);
92   PetscCall(LMBasisRestoreVec(basis, basis->k++, PETSC_MEMORY_ACCESS_WRITE, single));
93   // basis is updated, invalidate cached product
94   basis->cached_vec_id    = 0;
95   basis->cached_vec_state = 0;
96   PetscFunctionReturn(PETSC_SUCCESS);
97 }
98 
LMBasisSetNextVec(LMBasis basis,Vec single)99 PETSC_INTERN PetscErrorCode LMBasisSetNextVec(LMBasis basis, Vec single)
100 {
101   Vec next;
102 
103   PetscFunctionBegin;
104   PetscCall(LMBasisGetNextVec(basis, &next));
105   PetscCall(VecCopy(single, next));
106   PetscCall(LMBasisRestoreNextVec(basis, &next));
107   PetscFunctionReturn(PETSC_SUCCESS);
108 }
109 
LMBasisDestroy(LMBasis * basis_p)110 PETSC_INTERN PetscErrorCode LMBasisDestroy(LMBasis *basis_p)
111 {
112   LMBasis basis = *basis_p;
113 
114   PetscFunctionBegin;
115   *basis_p = NULL;
116   if (basis == NULL) PetscFunctionReturn(PETSC_SUCCESS);
117   PetscCall(LMBasisReset(basis));
118   PetscCheck(basis->work_vecs_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work vecs are still checked out at destruction");
119   {
120     VecLink head = basis->work_vecs_available;
121 
122     while (head) {
123       VecLink next = head->next;
124 
125       PetscCall(VecDestroy(&head->vec));
126       PetscCall(PetscFree(head));
127       head = next;
128     }
129   }
130   PetscCheck(basis->work_rows_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work rows are still checked out at destruction");
131   {
132     VecLink head = basis->work_rows_available;
133 
134     while (head) {
135       VecLink next = head->next;
136 
137       PetscCall(VecDestroy(&head->vec));
138       PetscCall(PetscFree(head));
139       head = next;
140     }
141   }
142   PetscCall(MatDestroy(&basis->vecs));
143   PetscCall(PetscFree(basis));
144   PetscFunctionReturn(PETSC_SUCCESS);
145 }
146 
LMBasisGetWorkVec(LMBasis basis,Vec * vec_p)147 PETSC_INTERN PetscErrorCode LMBasisGetWorkVec(LMBasis basis, Vec *vec_p)
148 {
149   VecLink link;
150 
151   PetscFunctionBegin;
152   if (!basis->work_vecs_available) {
153     PetscCall(PetscNew(&basis->work_vecs_available));
154     PetscCall(MatCreateVecs(basis->vecs, NULL, &basis->work_vecs_available->vec));
155   }
156   link                       = basis->work_vecs_available;
157   basis->work_vecs_available = link->next;
158   link->next                 = basis->work_vecs_in_use;
159   basis->work_vecs_in_use    = link;
160 
161   *vec_p    = link->vec;
162   link->vec = NULL;
163   PetscFunctionReturn(PETSC_SUCCESS);
164 }
165 
LMBasisRestoreWorkVec(LMBasis basis,Vec * vec_p)166 PETSC_INTERN PetscErrorCode LMBasisRestoreWorkVec(LMBasis basis, Vec *vec_p)
167 {
168   Vec     v    = *vec_p;
169   VecLink link = NULL;
170 
171   PetscFunctionBegin;
172   *vec_p = NULL;
173   PetscCheck(basis->work_vecs_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a vec that wasn't checked out");
174   link                       = basis->work_vecs_in_use;
175   basis->work_vecs_in_use    = link->next;
176   link->next                 = basis->work_vecs_available;
177   basis->work_vecs_available = link;
178 
179   PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
180   link->vec = v;
181   PetscFunctionReturn(PETSC_SUCCESS);
182 }
183 
LMBasisCreateRow(LMBasis basis,Vec * row_p)184 PETSC_INTERN PetscErrorCode LMBasisCreateRow(LMBasis basis, Vec *row_p)
185 {
186   PetscFunctionBegin;
187   PetscCall(MatCreateVecs(basis->vecs, row_p, NULL));
188   PetscFunctionReturn(PETSC_SUCCESS);
189 }
190 
LMBasisGetWorkRow(LMBasis basis,Vec * row_p)191 PETSC_INTERN PetscErrorCode LMBasisGetWorkRow(LMBasis basis, Vec *row_p)
192 {
193   VecLink link;
194 
195   PetscFunctionBegin;
196   if (!basis->work_rows_available) {
197     PetscCall(PetscNew(&basis->work_rows_available));
198     PetscCall(MatCreateVecs(basis->vecs, &basis->work_rows_available->vec, NULL));
199   }
200   link                       = basis->work_rows_available;
201   basis->work_rows_available = link->next;
202   link->next                 = basis->work_rows_in_use;
203   basis->work_rows_in_use    = link;
204 
205   PetscCall(VecZeroEntries(link->vec));
206   *row_p    = link->vec;
207   link->vec = NULL;
208   PetscFunctionReturn(PETSC_SUCCESS);
209 }
210 
LMBasisRestoreWorkRow(LMBasis basis,Vec * row_p)211 PETSC_INTERN PetscErrorCode LMBasisRestoreWorkRow(LMBasis basis, Vec *row_p)
212 {
213   Vec     v    = *row_p;
214   VecLink link = NULL;
215 
216   PetscFunctionBegin;
217   *row_p = NULL;
218   PetscCheck(basis->work_rows_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a row that wasn't checked out");
219   link                       = basis->work_rows_in_use;
220   basis->work_rows_in_use    = link->next;
221   link->next                 = basis->work_rows_available;
222   basis->work_rows_available = link;
223 
224   PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
225   link->vec = v;
226   PetscFunctionReturn(PETSC_SUCCESS);
227 }
228 
LMBasisCopy(LMBasis basis_a,LMBasis basis_b)229 PETSC_INTERN PetscErrorCode LMBasisCopy(LMBasis basis_a, LMBasis basis_b)
230 {
231   PetscFunctionBegin;
232   PetscCheck(basis_a->m == basis_b->m, PetscObjectComm((PetscObject)basis_a), PETSC_ERR_ARG_SIZ, "Copy target has different number of vecs, %" PetscInt_FMT " != %" PetscInt_FMT, basis_b->m, basis_a->m);
233   basis_b->k = basis_a->k;
234   PetscCall(MatCopy(basis_a->vecs, basis_b->vecs, SAME_NONZERO_PATTERN));
235   basis_b->cached_vec_id    = basis_a->cached_vec_id;
236   basis_b->cached_vec_state = basis_a->cached_vec_state;
237   if (basis_a->cached_product) {
238     if (!basis_b->cached_product) PetscCall(VecDuplicate(basis_a->cached_product, &basis_b->cached_product));
239     PetscCall(VecCopy(basis_a->cached_product, basis_b->cached_product));
240   }
241   PetscFunctionReturn(PETSC_SUCCESS);
242 }
243 
LMBasisGetRange(LMBasis basis,PetscInt * oldest,PetscInt * next)244 PETSC_INTERN PetscErrorCode LMBasisGetRange(LMBasis basis, PetscInt *oldest, PetscInt *next)
245 {
246   PetscFunctionBegin;
247   *next   = basis->k;
248   *oldest = PetscMax(0, basis->k - basis->m);
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
LMBasisMultCheck(LMBasis A,PetscInt oldest,PetscInt next)252 static PetscErrorCode LMBasisMultCheck(LMBasis A, PetscInt oldest, PetscInt next)
253 {
254   PetscInt basis_oldest, basis_next;
255 
256   PetscFunctionBegin;
257   PetscCall(LMBasisGetRange(A, &basis_oldest, &basis_next));
258   PetscCheck(oldest >= basis_oldest && next <= basis_next, PetscObjectComm((PetscObject)A->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for vec that hasn't been computed or is no longer stored");
259   PetscFunctionReturn(PETSC_SUCCESS);
260 }
261 
LMBasisGEMV(LMBasis A,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)262 PETSC_INTERN PetscErrorCode LMBasisGEMV(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
263 {
264   PetscInt lim        = next - oldest;
265   PetscInt next_idx   = ((next - 1) % A->m) + 1;
266   PetscInt oldest_idx = oldest % A->m;
267   Vec      x_work     = NULL;
268   Vec      x_         = x;
269 
270   PetscFunctionBegin;
271   if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
272   PetscCall(PetscLogEventBegin(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
273   PetscCall(LMBasisMultCheck(A, oldest, next));
274   if (alpha != 1.0) {
275     PetscCall(LMBasisGetWorkRow(A, &x_work));
276     PetscCall(VecAXPBYCyclic(oldest, next, alpha, x, 0.0, x_work));
277     x_ = x_work;
278   }
279   if (beta != 1.0 && beta != 0.0) PetscCall(VecScale(y, beta));
280   if (lim == A->m) {
281     // all vectors are used
282     if (beta == 0.0) PetscCall(MatMult(A->vecs, x_, y));
283     else PetscCall(MatMultAdd(A->vecs, x_, y, y));
284   } else if (oldest_idx < next_idx) {
285     // contiguous vectors are used
286     if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, oldest_idx, next_idx));
287     else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, next_idx));
288   } else {
289     if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, 0, next_idx));
290     else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, 0, next_idx));
291     PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, A->m));
292   }
293   if (alpha != 1.0) PetscCall(LMBasisRestoreWorkRow(A, &x_work));
294   PetscCall(PetscLogEventEnd(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
295   PetscFunctionReturn(PETSC_SUCCESS);
296 }
297 
LMBasisGEMVH(LMBasis A,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)298 PETSC_INTERN PetscErrorCode LMBasisGEMVH(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
299 {
300   PetscInt lim        = next - oldest;
301   PetscInt next_idx   = ((next - 1) % A->m) + 1;
302   PetscInt oldest_idx = oldest % A->m;
303   Vec      y_         = y;
304 
305   PetscFunctionBegin;
306   if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
307   PetscCall(LMBasisMultCheck(A, oldest, next));
308   if (A->cached_product && A->cached_vec_id != 0 && A->cached_vec_state != 0) {
309     // see if x is the cached input vector
310     PetscObjectId    x_id;
311     PetscObjectState x_state;
312 
313     PetscCall(PetscObjectGetId((PetscObject)x, &x_id));
314     PetscCall(PetscObjectStateGet((PetscObject)x, &x_state));
315     if (x_id == A->cached_vec_id && x_state == A->cached_vec_state) {
316       PetscCall(VecAXPBYCyclic(oldest, next, alpha, A->cached_product, beta, y));
317       PetscFunctionReturn(PETSC_SUCCESS);
318     }
319   }
320   PetscCall(PetscLogEventBegin(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
321   if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) PetscCall(LMBasisGetWorkRow(A, &y_));
322   if (lim == A->m) {
323     // all vectors are used
324     if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAdd(A->vecs, x, y_, y_));
325     else PetscCall(MatMultHermitianTranspose(A->vecs, x, y_));
326   } else if (oldest_idx < next_idx) {
327     // contiguous vectors are used
328     if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, next_idx));
329     else PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, next_idx));
330   } else {
331     if (alpha == 1.0 && beta == 1.0) {
332       PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, 0, next_idx));
333       PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, A->m));
334     } else {
335       PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, 0, next_idx));
336       PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, A->m));
337     }
338   }
339   if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) {
340     PetscCall(VecAXPBYCyclic(oldest, next, alpha, y_, beta, y));
341     PetscCall(LMBasisRestoreWorkRow(A, &y_));
342   }
343   PetscCall(PetscLogEventEnd(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
344   PetscFunctionReturn(PETSC_SUCCESS);
345 }
346 
LMBasisGEMMH_Internal(Mat A,Mat B,PetscScalar alpha,PetscScalar beta,Mat G)347 static PetscErrorCode LMBasisGEMMH_Internal(Mat A, Mat B, PetscScalar alpha, PetscScalar beta, Mat G)
348 {
349   PetscFunctionBegin;
350   if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
351   if (beta != 0.0) {
352     Mat G_alloc;
353 
354     if (beta != 1.0) PetscCall(MatScale(G, beta));
355     PetscCall(MatTransposeMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DECIDE, &G_alloc));
356     PetscCall(MatAXPY(G, alpha, G_alloc, DIFFERENT_NONZERO_PATTERN));
357     PetscCall(MatDestroy(&G_alloc));
358   } else {
359     PetscCall(MatProductClear(G));
360     PetscCall(MatProductCreateWithMat(A, B, NULL, G));
361     PetscCall(MatProductSetType(G, MATPRODUCT_AtB));
362     PetscCall(MatProductSetFromOptions(G));
363     PetscCall(MatProductSymbolic(G));
364     PetscCall(MatProductNumeric(G));
365     if (alpha != 1.0) PetscCall(MatScale(G, alpha));
366   }
367   if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
368   PetscFunctionReturn(PETSC_SUCCESS);
369 }
370 
LMBasisGEMMH(LMBasis A,PetscInt a_oldest,PetscInt a_next,LMBasis B,PetscInt b_oldest,PetscInt b_next,PetscScalar alpha,PetscScalar beta,Mat G)371 PETSC_INTERN PetscErrorCode LMBasisGEMMH(LMBasis A, PetscInt a_oldest, PetscInt a_next, LMBasis B, PetscInt b_oldest, PetscInt b_next, PetscScalar alpha, PetscScalar beta, Mat G)
372 {
373   PetscInt a_lim = a_next - a_oldest;
374   PetscInt b_lim = b_next - b_oldest;
375 
376   PetscFunctionBegin;
377   if (a_lim <= 0 || b_lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
378   PetscCall(PetscLogEventBegin(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
379   PetscCall(LMBasisMultCheck(A, a_oldest, a_next));
380   PetscCall(LMBasisMultCheck(B, b_oldest, b_next));
381   if (b_lim == 1) {
382     Vec b;
383     Vec g;
384 
385     PetscCall(LMBasisGetVecRead(B, b_oldest, &b));
386     PetscCall(MatDenseGetColumnVec(G, b_oldest % B->m, &g));
387     PetscCall(LMBasisGEMVH(A, a_oldest, a_next, alpha, b, beta, g));
388     PetscCall(MatDenseRestoreColumnVec(G, b_oldest % B->m, &g));
389     PetscCall(LMBasisRestoreVecRead(B, b_oldest, &b));
390   } else if (a_lim == 1) {
391     Vec a;
392     Vec g;
393 
394     PetscCall(LMBasisGetVecRead(A, a_oldest, &a));
395     PetscCall(LMBasisGetWorkRow(B, &g));
396     PetscCall(LMBasisGEMVH(B, b_oldest, b_next, 1.0, a, 0.0, g));
397     if (PetscDefined(USE_COMPLEX)) PetscCall(VecConjugate(g));
398     PetscCall(MatSeqDenseRowAXPBYCyclic(b_oldest, b_next, alpha, g, beta, G, a_oldest));
399     PetscCall(LMBasisRestoreWorkRow(B, &g));
400     PetscCall(LMBasisRestoreVecRead(A, a_oldest, &a));
401   } else {
402     PetscInt a_next_idx        = ((a_next - 1) % A->m) + 1;
403     PetscInt a_oldest_idx      = a_oldest % A->m;
404     PetscInt b_next_idx        = ((b_next - 1) % B->m) + 1;
405     PetscInt b_oldest_idx      = b_oldest % B->m;
406     PetscInt a_intervals[2][2] = {
407       {0,            a_next_idx},
408       {a_oldest_idx, A->m      }
409     };
410     PetscInt b_intervals[2][2] = {
411       {0,            b_next_idx},
412       {b_oldest_idx, B->m      }
413     };
414     PetscInt a_num_intervals = 2;
415     PetscInt b_num_intervals = 2;
416 
417     if (a_lim == A->m || a_oldest_idx < a_next_idx) {
418       a_num_intervals = 1;
419       if (a_lim == A->m) {
420         a_intervals[0][0] = 0;
421         a_intervals[0][1] = A->m;
422       } else {
423         a_intervals[0][0] = a_oldest_idx;
424         a_intervals[0][1] = a_next_idx;
425       }
426     }
427     if (b_lim == B->m || b_oldest_idx < b_next_idx) {
428       b_num_intervals = 1;
429       if (b_lim == B->m) {
430         b_intervals[0][0] = 0;
431         b_intervals[0][1] = B->m;
432       } else {
433         b_intervals[0][0] = b_oldest_idx;
434         b_intervals[0][1] = b_next_idx;
435       }
436     }
437     for (PetscInt i = 0; i < a_num_intervals; i++) {
438       Mat sub_A = A->vecs;
439       Mat sub_A_;
440 
441       if (a_intervals[i][0] != 0 || a_intervals[i][1] != A->m) PetscCall(MatDenseGetSubMatrix(A->vecs, PETSC_DECIDE, PETSC_DECIDE, a_intervals[i][0], a_intervals[i][1], &sub_A));
442       sub_A_ = sub_A;
443 
444       for (PetscInt j = 0; j < b_num_intervals; j++) {
445         Mat sub_B = B->vecs;
446         Mat sub_G = G;
447 
448         if (b_intervals[j][0] != 0 || b_intervals[j][1] != B->m) {
449           if (sub_A_ == sub_A && sub_A != A->vecs && B->vecs == A->vecs) {
450             /* We're hampered by the fact that you can only get one submatrix from a MatDense at a time.  This case
451              * should not happen often, copying here is acceptable */
452             PetscCall(MatDuplicate(sub_A, MAT_COPY_VALUES, &sub_A_));
453             PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
454             sub_A = A->vecs;
455           }
456           PetscCall(MatDenseGetSubMatrix(B->vecs, PETSC_DECIDE, PETSC_DECIDE, b_intervals[j][0], b_intervals[j][1], &sub_B));
457         }
458 
459         if (sub_A_ != A->vecs || sub_B != B->vecs) PetscCall(MatDenseGetSubMatrix(G, a_intervals[i][0], a_intervals[i][1], b_intervals[j][0], b_intervals[j][1], &sub_G));
460 
461         PetscCall(LMBasisGEMMH_Internal(sub_A_, sub_B, alpha, beta, sub_G));
462 
463         if (sub_G != G) PetscCall(MatDenseRestoreSubMatrix(G, &sub_G));
464         if (sub_B != B->vecs) PetscCall(MatDenseRestoreSubMatrix(B->vecs, &sub_B));
465       }
466 
467       if (sub_A_ != sub_A) PetscCall(MatDestroy(&sub_A_));
468       if (sub_A != A->vecs) PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
469     }
470   }
471   PetscCall(PetscLogEventEnd(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
472   PetscFunctionReturn(PETSC_SUCCESS);
473 }
474 
LMBasisReset(LMBasis basis)475 PETSC_INTERN PetscErrorCode LMBasisReset(LMBasis basis)
476 {
477   PetscFunctionBegin;
478   if (basis) {
479     basis->k = 0;
480     PetscCall(VecDestroy(&basis->cached_product));
481     basis->cached_vec_id    = 0;
482     basis->cached_vec_state = 0;
483     basis->operator_id      = 0;
484     basis->operator_state   = 0;
485   }
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488 
LMBasisSetCachedProduct(LMBasis A,Vec x,Vec Ax)489 PETSC_INTERN PetscErrorCode LMBasisSetCachedProduct(LMBasis A, Vec x, Vec Ax)
490 {
491   PetscFunctionBegin;
492   if (x == NULL) {
493     A->cached_vec_id    = 0;
494     A->cached_vec_state = 0;
495   } else {
496     PetscCall(PetscObjectGetId((PetscObject)x, &A->cached_vec_id));
497     PetscCall(PetscObjectStateGet((PetscObject)x, &A->cached_vec_state));
498   }
499   PetscCall(PetscObjectReference((PetscObject)Ax));
500   PetscCall(VecDestroy(&A->cached_product));
501   A->cached_product = Ax;
502   PetscFunctionReturn(PETSC_SUCCESS);
503 }
504