1 #include <petsc/private/petscimpl.h>
2 #include "lmbasis.h"
3 #include "blas_cyclic/blas_cyclic.h"
4
5 PetscLogEvent LMBASIS_GEMM, LMBASIS_GEMV, LMBASIS_GEMVH;
6
LMBasisCreate(Vec v,PetscInt m,LMBasis * basis_p)7 PetscErrorCode LMBasisCreate(Vec v, PetscInt m, LMBasis *basis_p)
8 {
9 PetscInt n, N;
10 PetscMPIInt rank;
11 Mat backing;
12 VecType type;
13 LMBasis basis;
14
15 PetscFunctionBegin;
16 PetscValidHeaderSpecific(v, VEC_CLASSID, 1);
17 PetscValidLogicalCollectiveInt(v, m, 2);
18 PetscCheck(m >= 0, PetscObjectComm((PetscObject)v), PETSC_ERR_ARG_OUTOFRANGE, "Requested window size %" PetscInt_FMT " is not >= 0", m);
19 PetscCall(VecGetLocalSize(v, &n));
20 PetscCall(VecGetSize(v, &N));
21 PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)v), &rank));
22 PetscCall(VecGetType(v, &type));
23 PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)v), type, n, rank == 0 ? m : 0, N, m, n, NULL, &backing));
24 PetscCall(PetscNew(&basis));
25 *basis_p = basis;
26 basis->m = m;
27 basis->k = 0;
28 basis->vecs = backing;
29 PetscFunctionReturn(PETSC_SUCCESS);
30 }
31
LMBasisGetVec_Internal(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single,PetscBool check_idx)32 static PetscErrorCode LMBasisGetVec_Internal(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single, PetscBool check_idx)
33 {
34 PetscFunctionBegin;
35 PetscAssertPointer(basis, 1);
36 if (check_idx) {
37 PetscValidLogicalCollectiveInt(basis->vecs, idx, 2);
38 PetscCheck(idx < basis->k, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " >= number of inserted vecs %" PetscInt_FMT, idx, basis->k);
39 PetscInt earliest = PetscMax(0, basis->k - basis->m);
40 PetscCheck(idx >= earliest, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " < the earliest retained index % " PetscInt_FMT, idx, earliest);
41 }
42 PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
43 if (mode == PETSC_MEMORY_ACCESS_READ) {
44 PetscCall(MatDenseGetColumnVecRead(basis->vecs, idx % basis->m, single));
45 } else {
46 PetscCall(MatDenseGetColumnVecWrite(basis->vecs, idx % basis->m, single));
47 }
48 PetscFunctionReturn(PETSC_SUCCESS);
49 }
50
LMBasisGetVec(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single)51 PETSC_INTERN PetscErrorCode LMBasisGetVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
52 {
53 PetscFunctionBegin;
54 PetscCall(LMBasisGetVec_Internal(basis, idx, mode, single, PETSC_TRUE));
55 PetscFunctionReturn(PETSC_SUCCESS);
56 }
57
LMBasisRestoreVec(LMBasis basis,PetscInt idx,PetscMemoryAccessMode mode,Vec * single)58 PETSC_INTERN PetscErrorCode LMBasisRestoreVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
59 {
60 PetscFunctionBegin;
61 PetscAssertPointer(basis, 1);
62 PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
63 if (mode == PETSC_MEMORY_ACCESS_READ) {
64 PetscCall(MatDenseRestoreColumnVecRead(basis->vecs, idx % basis->m, single));
65 } else {
66 PetscCall(MatDenseRestoreColumnVecWrite(basis->vecs, idx % basis->m, single));
67 }
68 *single = NULL;
69 PetscFunctionReturn(PETSC_SUCCESS);
70 }
71
LMBasisGetVecRead(LMBasis B,PetscInt i,Vec * b)72 PETSC_INTERN PetscErrorCode LMBasisGetVecRead(LMBasis B, PetscInt i, Vec *b)
73 {
74 return LMBasisGetVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
75 }
LMBasisRestoreVecRead(LMBasis B,PetscInt i,Vec * b)76 PETSC_INTERN PetscErrorCode LMBasisRestoreVecRead(LMBasis B, PetscInt i, Vec *b)
77 {
78 return LMBasisRestoreVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
79 }
80
LMBasisGetNextVec(LMBasis basis,Vec * single)81 PETSC_INTERN PetscErrorCode LMBasisGetNextVec(LMBasis basis, Vec *single)
82 {
83 PetscFunctionBegin;
84 PetscCall(LMBasisGetVec_Internal(basis, basis->k, PETSC_MEMORY_ACCESS_WRITE, single, PETSC_FALSE));
85 PetscFunctionReturn(PETSC_SUCCESS);
86 }
87
LMBasisRestoreNextVec(LMBasis basis,Vec * single)88 PETSC_INTERN PetscErrorCode LMBasisRestoreNextVec(LMBasis basis, Vec *single)
89 {
90 PetscFunctionBegin;
91 PetscAssertPointer(basis, 1);
92 PetscCall(LMBasisRestoreVec(basis, basis->k++, PETSC_MEMORY_ACCESS_WRITE, single));
93 // basis is updated, invalidate cached product
94 basis->cached_vec_id = 0;
95 basis->cached_vec_state = 0;
96 PetscFunctionReturn(PETSC_SUCCESS);
97 }
98
LMBasisSetNextVec(LMBasis basis,Vec single)99 PETSC_INTERN PetscErrorCode LMBasisSetNextVec(LMBasis basis, Vec single)
100 {
101 Vec next;
102
103 PetscFunctionBegin;
104 PetscCall(LMBasisGetNextVec(basis, &next));
105 PetscCall(VecCopy(single, next));
106 PetscCall(LMBasisRestoreNextVec(basis, &next));
107 PetscFunctionReturn(PETSC_SUCCESS);
108 }
109
LMBasisDestroy(LMBasis * basis_p)110 PETSC_INTERN PetscErrorCode LMBasisDestroy(LMBasis *basis_p)
111 {
112 LMBasis basis = *basis_p;
113
114 PetscFunctionBegin;
115 *basis_p = NULL;
116 if (basis == NULL) PetscFunctionReturn(PETSC_SUCCESS);
117 PetscCall(LMBasisReset(basis));
118 PetscCheck(basis->work_vecs_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work vecs are still checked out at destruction");
119 {
120 VecLink head = basis->work_vecs_available;
121
122 while (head) {
123 VecLink next = head->next;
124
125 PetscCall(VecDestroy(&head->vec));
126 PetscCall(PetscFree(head));
127 head = next;
128 }
129 }
130 PetscCheck(basis->work_rows_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work rows are still checked out at destruction");
131 {
132 VecLink head = basis->work_rows_available;
133
134 while (head) {
135 VecLink next = head->next;
136
137 PetscCall(VecDestroy(&head->vec));
138 PetscCall(PetscFree(head));
139 head = next;
140 }
141 }
142 PetscCall(MatDestroy(&basis->vecs));
143 PetscCall(PetscFree(basis));
144 PetscFunctionReturn(PETSC_SUCCESS);
145 }
146
LMBasisGetWorkVec(LMBasis basis,Vec * vec_p)147 PETSC_INTERN PetscErrorCode LMBasisGetWorkVec(LMBasis basis, Vec *vec_p)
148 {
149 VecLink link;
150
151 PetscFunctionBegin;
152 if (!basis->work_vecs_available) {
153 PetscCall(PetscNew(&basis->work_vecs_available));
154 PetscCall(MatCreateVecs(basis->vecs, NULL, &basis->work_vecs_available->vec));
155 }
156 link = basis->work_vecs_available;
157 basis->work_vecs_available = link->next;
158 link->next = basis->work_vecs_in_use;
159 basis->work_vecs_in_use = link;
160
161 *vec_p = link->vec;
162 link->vec = NULL;
163 PetscFunctionReturn(PETSC_SUCCESS);
164 }
165
LMBasisRestoreWorkVec(LMBasis basis,Vec * vec_p)166 PETSC_INTERN PetscErrorCode LMBasisRestoreWorkVec(LMBasis basis, Vec *vec_p)
167 {
168 Vec v = *vec_p;
169 VecLink link = NULL;
170
171 PetscFunctionBegin;
172 *vec_p = NULL;
173 PetscCheck(basis->work_vecs_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a vec that wasn't checked out");
174 link = basis->work_vecs_in_use;
175 basis->work_vecs_in_use = link->next;
176 link->next = basis->work_vecs_available;
177 basis->work_vecs_available = link;
178
179 PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
180 link->vec = v;
181 PetscFunctionReturn(PETSC_SUCCESS);
182 }
183
LMBasisCreateRow(LMBasis basis,Vec * row_p)184 PETSC_INTERN PetscErrorCode LMBasisCreateRow(LMBasis basis, Vec *row_p)
185 {
186 PetscFunctionBegin;
187 PetscCall(MatCreateVecs(basis->vecs, row_p, NULL));
188 PetscFunctionReturn(PETSC_SUCCESS);
189 }
190
LMBasisGetWorkRow(LMBasis basis,Vec * row_p)191 PETSC_INTERN PetscErrorCode LMBasisGetWorkRow(LMBasis basis, Vec *row_p)
192 {
193 VecLink link;
194
195 PetscFunctionBegin;
196 if (!basis->work_rows_available) {
197 PetscCall(PetscNew(&basis->work_rows_available));
198 PetscCall(MatCreateVecs(basis->vecs, &basis->work_rows_available->vec, NULL));
199 }
200 link = basis->work_rows_available;
201 basis->work_rows_available = link->next;
202 link->next = basis->work_rows_in_use;
203 basis->work_rows_in_use = link;
204
205 PetscCall(VecZeroEntries(link->vec));
206 *row_p = link->vec;
207 link->vec = NULL;
208 PetscFunctionReturn(PETSC_SUCCESS);
209 }
210
LMBasisRestoreWorkRow(LMBasis basis,Vec * row_p)211 PETSC_INTERN PetscErrorCode LMBasisRestoreWorkRow(LMBasis basis, Vec *row_p)
212 {
213 Vec v = *row_p;
214 VecLink link = NULL;
215
216 PetscFunctionBegin;
217 *row_p = NULL;
218 PetscCheck(basis->work_rows_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a row that wasn't checked out");
219 link = basis->work_rows_in_use;
220 basis->work_rows_in_use = link->next;
221 link->next = basis->work_rows_available;
222 basis->work_rows_available = link;
223
224 PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
225 link->vec = v;
226 PetscFunctionReturn(PETSC_SUCCESS);
227 }
228
LMBasisCopy(LMBasis basis_a,LMBasis basis_b)229 PETSC_INTERN PetscErrorCode LMBasisCopy(LMBasis basis_a, LMBasis basis_b)
230 {
231 PetscFunctionBegin;
232 PetscCheck(basis_a->m == basis_b->m, PetscObjectComm((PetscObject)basis_a), PETSC_ERR_ARG_SIZ, "Copy target has different number of vecs, %" PetscInt_FMT " != %" PetscInt_FMT, basis_b->m, basis_a->m);
233 basis_b->k = basis_a->k;
234 PetscCall(MatCopy(basis_a->vecs, basis_b->vecs, SAME_NONZERO_PATTERN));
235 basis_b->cached_vec_id = basis_a->cached_vec_id;
236 basis_b->cached_vec_state = basis_a->cached_vec_state;
237 if (basis_a->cached_product) {
238 if (!basis_b->cached_product) PetscCall(VecDuplicate(basis_a->cached_product, &basis_b->cached_product));
239 PetscCall(VecCopy(basis_a->cached_product, basis_b->cached_product));
240 }
241 PetscFunctionReturn(PETSC_SUCCESS);
242 }
243
LMBasisGetRange(LMBasis basis,PetscInt * oldest,PetscInt * next)244 PETSC_INTERN PetscErrorCode LMBasisGetRange(LMBasis basis, PetscInt *oldest, PetscInt *next)
245 {
246 PetscFunctionBegin;
247 *next = basis->k;
248 *oldest = PetscMax(0, basis->k - basis->m);
249 PetscFunctionReturn(PETSC_SUCCESS);
250 }
251
LMBasisMultCheck(LMBasis A,PetscInt oldest,PetscInt next)252 static PetscErrorCode LMBasisMultCheck(LMBasis A, PetscInt oldest, PetscInt next)
253 {
254 PetscInt basis_oldest, basis_next;
255
256 PetscFunctionBegin;
257 PetscCall(LMBasisGetRange(A, &basis_oldest, &basis_next));
258 PetscCheck(oldest >= basis_oldest && next <= basis_next, PetscObjectComm((PetscObject)A->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for vec that hasn't been computed or is no longer stored");
259 PetscFunctionReturn(PETSC_SUCCESS);
260 }
261
LMBasisGEMV(LMBasis A,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)262 PETSC_INTERN PetscErrorCode LMBasisGEMV(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
263 {
264 PetscInt lim = next - oldest;
265 PetscInt next_idx = ((next - 1) % A->m) + 1;
266 PetscInt oldest_idx = oldest % A->m;
267 Vec x_work = NULL;
268 Vec x_ = x;
269
270 PetscFunctionBegin;
271 if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
272 PetscCall(PetscLogEventBegin(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
273 PetscCall(LMBasisMultCheck(A, oldest, next));
274 if (alpha != 1.0) {
275 PetscCall(LMBasisGetWorkRow(A, &x_work));
276 PetscCall(VecAXPBYCyclic(oldest, next, alpha, x, 0.0, x_work));
277 x_ = x_work;
278 }
279 if (beta != 1.0 && beta != 0.0) PetscCall(VecScale(y, beta));
280 if (lim == A->m) {
281 // all vectors are used
282 if (beta == 0.0) PetscCall(MatMult(A->vecs, x_, y));
283 else PetscCall(MatMultAdd(A->vecs, x_, y, y));
284 } else if (oldest_idx < next_idx) {
285 // contiguous vectors are used
286 if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, oldest_idx, next_idx));
287 else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, next_idx));
288 } else {
289 if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, 0, next_idx));
290 else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, 0, next_idx));
291 PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, A->m));
292 }
293 if (alpha != 1.0) PetscCall(LMBasisRestoreWorkRow(A, &x_work));
294 PetscCall(PetscLogEventEnd(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
295 PetscFunctionReturn(PETSC_SUCCESS);
296 }
297
LMBasisGEMVH(LMBasis A,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)298 PETSC_INTERN PetscErrorCode LMBasisGEMVH(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
299 {
300 PetscInt lim = next - oldest;
301 PetscInt next_idx = ((next - 1) % A->m) + 1;
302 PetscInt oldest_idx = oldest % A->m;
303 Vec y_ = y;
304
305 PetscFunctionBegin;
306 if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
307 PetscCall(LMBasisMultCheck(A, oldest, next));
308 if (A->cached_product && A->cached_vec_id != 0 && A->cached_vec_state != 0) {
309 // see if x is the cached input vector
310 PetscObjectId x_id;
311 PetscObjectState x_state;
312
313 PetscCall(PetscObjectGetId((PetscObject)x, &x_id));
314 PetscCall(PetscObjectStateGet((PetscObject)x, &x_state));
315 if (x_id == A->cached_vec_id && x_state == A->cached_vec_state) {
316 PetscCall(VecAXPBYCyclic(oldest, next, alpha, A->cached_product, beta, y));
317 PetscFunctionReturn(PETSC_SUCCESS);
318 }
319 }
320 PetscCall(PetscLogEventBegin(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
321 if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) PetscCall(LMBasisGetWorkRow(A, &y_));
322 if (lim == A->m) {
323 // all vectors are used
324 if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAdd(A->vecs, x, y_, y_));
325 else PetscCall(MatMultHermitianTranspose(A->vecs, x, y_));
326 } else if (oldest_idx < next_idx) {
327 // contiguous vectors are used
328 if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, next_idx));
329 else PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, next_idx));
330 } else {
331 if (alpha == 1.0 && beta == 1.0) {
332 PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, 0, next_idx));
333 PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, A->m));
334 } else {
335 PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, 0, next_idx));
336 PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, A->m));
337 }
338 }
339 if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) {
340 PetscCall(VecAXPBYCyclic(oldest, next, alpha, y_, beta, y));
341 PetscCall(LMBasisRestoreWorkRow(A, &y_));
342 }
343 PetscCall(PetscLogEventEnd(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
344 PetscFunctionReturn(PETSC_SUCCESS);
345 }
346
LMBasisGEMMH_Internal(Mat A,Mat B,PetscScalar alpha,PetscScalar beta,Mat G)347 static PetscErrorCode LMBasisGEMMH_Internal(Mat A, Mat B, PetscScalar alpha, PetscScalar beta, Mat G)
348 {
349 PetscFunctionBegin;
350 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
351 if (beta != 0.0) {
352 Mat G_alloc;
353
354 if (beta != 1.0) PetscCall(MatScale(G, beta));
355 PetscCall(MatTransposeMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DECIDE, &G_alloc));
356 PetscCall(MatAXPY(G, alpha, G_alloc, DIFFERENT_NONZERO_PATTERN));
357 PetscCall(MatDestroy(&G_alloc));
358 } else {
359 PetscCall(MatProductClear(G));
360 PetscCall(MatProductCreateWithMat(A, B, NULL, G));
361 PetscCall(MatProductSetType(G, MATPRODUCT_AtB));
362 PetscCall(MatProductSetFromOptions(G));
363 PetscCall(MatProductSymbolic(G));
364 PetscCall(MatProductNumeric(G));
365 if (alpha != 1.0) PetscCall(MatScale(G, alpha));
366 }
367 if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
368 PetscFunctionReturn(PETSC_SUCCESS);
369 }
370
/*
  LMBasisGEMMH - G <- alpha * A^H B + beta * G over the history ranges [a_oldest, a_next) of A and
  [b_oldest, b_next) of B.

  History index i of A is stored in column i % A->m of A->vecs, and likewise for B. G is addressed by
  storage slot: the general case below writes the (i, j) result at row i % A->m, column j % B->m of G.
  NOTE(review): G is presumably a dense matrix of size A->m x B->m -- confirm with callers.

  Does nothing if either range is empty. Errors (via LMBasisMultCheck) if a range is not currently stored.
  Three strategies are used:
    - one vector of B: a single column of G via LMBasisGEMVH with A;
    - one vector of A: a single row of G via LMBasisGEMVH with B followed by conjugation;
    - otherwise: block products of contiguous storage intervals of A and B (LMBasisGEMMH_Internal).
*/
PETSC_INTERN PetscErrorCode LMBasisGEMMH(LMBasis A, PetscInt a_oldest, PetscInt a_next, LMBasis B, PetscInt b_oldest, PetscInt b_next, PetscScalar alpha, PetscScalar beta, Mat G)
{
  PetscInt a_lim = a_next - a_oldest;
  PetscInt b_lim = b_next - b_oldest;

  PetscFunctionBegin;
  if (a_lim <= 0 || b_lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscLogEventBegin(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
  PetscCall(LMBasisMultCheck(A, a_oldest, a_next));
  PetscCall(LMBasisMultCheck(B, b_oldest, b_next));
  if (b_lim == 1) {
    // single vector b of B: column (b_oldest % m) of G <- alpha * A^H b + beta * (that column)
    Vec b;
    Vec g;

    PetscCall(LMBasisGetVecRead(B, b_oldest, &b));
    PetscCall(MatDenseGetColumnVec(G, b_oldest % B->m, &g));
    PetscCall(LMBasisGEMVH(A, a_oldest, a_next, alpha, b, beta, g));
    PetscCall(MatDenseRestoreColumnVec(G, b_oldest % B->m, &g));
    PetscCall(LMBasisRestoreVecRead(B, b_oldest, &b));
  } else if (a_lim == 1) {
    // single vector a of A: g = B^H a, whose conjugate is the row a^H B, then the G row for a_oldest
    // is updated as alpha * g + beta * row
    // NOTE(review): MatSeqDenseRowAXPBYCyclic presumably maps a_oldest to its storage row itself -- confirm
    Vec a;
    Vec g;

    PetscCall(LMBasisGetVecRead(A, a_oldest, &a));
    PetscCall(LMBasisGetWorkRow(B, &g));
    PetscCall(LMBasisGEMVH(B, b_oldest, b_next, 1.0, a, 0.0, g));
    if (PetscDefined(USE_COMPLEX)) PetscCall(VecConjugate(g));
    PetscCall(MatSeqDenseRowAXPBYCyclic(b_oldest, b_next, alpha, g, beta, G, a_oldest));
    PetscCall(LMBasisRestoreWorkRow(B, &g));
    PetscCall(LMBasisRestoreVecRead(A, a_oldest, &a));
  } else {
    // Map each history range onto storage slots. A range that wraps around the end of storage
    // becomes two intervals, [0, next_idx) and [oldest_idx, m); a full or non-wrapping range
    // becomes one.
    PetscInt a_next_idx   = ((a_next - 1) % A->m) + 1;
    PetscInt a_oldest_idx = a_oldest % A->m;
    PetscInt b_next_idx   = ((b_next - 1) % B->m) + 1;
    PetscInt b_oldest_idx = b_oldest % B->m;
    PetscInt a_intervals[2][2] = {
      {0,            a_next_idx},
      {a_oldest_idx, A->m      }
    };
    PetscInt b_intervals[2][2] = {
      {0,            b_next_idx},
      {b_oldest_idx, B->m      }
    };
    PetscInt a_num_intervals = 2;
    PetscInt b_num_intervals = 2;

    if (a_lim == A->m || a_oldest_idx < a_next_idx) {
      a_num_intervals = 1;
      if (a_lim == A->m) {
        a_intervals[0][0] = 0;
        a_intervals[0][1] = A->m;
      } else {
        a_intervals[0][0] = a_oldest_idx;
        a_intervals[0][1] = a_next_idx;
      }
    }
    if (b_lim == B->m || b_oldest_idx < b_next_idx) {
      b_num_intervals = 1;
      if (b_lim == B->m) {
        b_intervals[0][0] = 0;
        b_intervals[0][1] = B->m;
      } else {
        b_intervals[0][0] = b_oldest_idx;
        b_intervals[0][1] = b_next_idx;
      }
    }
    // One block product per (A interval, B interval) pair, each written to the matching block of G
    for (PetscInt i = 0; i < a_num_intervals; i++) {
      // sub_A: the checked-out view of A's columns (or A->vecs itself for the full range)
      // sub_A_: the matrix actually multiplied; it becomes a private copy if the view must be released early
      Mat sub_A = A->vecs;
      Mat sub_A_;

      if (a_intervals[i][0] != 0 || a_intervals[i][1] != A->m) PetscCall(MatDenseGetSubMatrix(A->vecs, PETSC_DECIDE, PETSC_DECIDE, a_intervals[i][0], a_intervals[i][1], &sub_A));
      sub_A_ = sub_A;

      for (PetscInt j = 0; j < b_num_intervals; j++) {
        Mat sub_B = B->vecs;
        Mat sub_G = G;

        if (b_intervals[j][0] != 0 || b_intervals[j][1] != B->m) {
          if (sub_A_ == sub_A && sub_A != A->vecs && B->vecs == A->vecs) {
            /* We're hampered by the fact that you can only get one submatrix from a MatDense at a time. This case
             * should not happen often, copying here is acceptable */
            PetscCall(MatDuplicate(sub_A, MAT_COPY_VALUES, &sub_A_));
            PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
            sub_A = A->vecs;
          }
          PetscCall(MatDenseGetSubMatrix(B->vecs, PETSC_DECIDE, PETSC_DECIDE, b_intervals[j][0], b_intervals[j][1], &sub_B));
        }

        // rows of G follow A's storage interval, columns follow B's storage interval
        if (sub_A_ != A->vecs || sub_B != B->vecs) PetscCall(MatDenseGetSubMatrix(G, a_intervals[i][0], a_intervals[i][1], b_intervals[j][0], b_intervals[j][1], &sub_G));

        PetscCall(LMBasisGEMMH_Internal(sub_A_, sub_B, alpha, beta, sub_G));

        if (sub_G != G) PetscCall(MatDenseRestoreSubMatrix(G, &sub_G));
        if (sub_B != B->vecs) PetscCall(MatDenseRestoreSubMatrix(B->vecs, &sub_B));
      }

      // free the private copy if one was made, then release the view of A if one is still held
      if (sub_A_ != sub_A) PetscCall(MatDestroy(&sub_A_));
      if (sub_A != A->vecs) PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
    }
  }
  PetscCall(PetscLogEventEnd(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}
474
LMBasisReset(LMBasis basis)475 PETSC_INTERN PetscErrorCode LMBasisReset(LMBasis basis)
476 {
477 PetscFunctionBegin;
478 if (basis) {
479 basis->k = 0;
480 PetscCall(VecDestroy(&basis->cached_product));
481 basis->cached_vec_id = 0;
482 basis->cached_vec_state = 0;
483 basis->operator_id = 0;
484 basis->operator_state = 0;
485 }
486 PetscFunctionReturn(PETSC_SUCCESS);
487 }
488
LMBasisSetCachedProduct(LMBasis A,Vec x,Vec Ax)489 PETSC_INTERN PetscErrorCode LMBasisSetCachedProduct(LMBasis A, Vec x, Vec Ax)
490 {
491 PetscFunctionBegin;
492 if (x == NULL) {
493 A->cached_vec_id = 0;
494 A->cached_vec_state = 0;
495 } else {
496 PetscCall(PetscObjectGetId((PetscObject)x, &A->cached_vec_id));
497 PetscCall(PetscObjectStateGet((PetscObject)x, &A->cached_vec_state));
498 }
499 PetscCall(PetscObjectReference((PetscObject)Ax));
500 PetscCall(VecDestroy(&A->cached_product));
501 A->cached_product = Ax;
502 PetscFunctionReturn(PETSC_SUCCESS);
503 }
504