xref: /petsc/src/ksp/ksp/utils/lmvm/lmproducts.c (revision 58bddbc0aeb8e2276be3739270a4176cb222ba3a)
1 #include <petsc/private/petscimpl.h>
2 #include <petscmat.h>
3 #include <petscblaslapack.h>
4 #include <petscdevice.h>
5 #include "lmproducts.h"
6 #include "blas_cyclic/blas_cyclic.h"
7 
8 PetscLogEvent LMPROD_Mult, LMPROD_Solve, LMPROD_Update;
9 
LMProductsCreate(LMBasis basis,LMBlockType block_type,LMProducts * dots)10 PETSC_INTERN PetscErrorCode LMProductsCreate(LMBasis basis, LMBlockType block_type, LMProducts *dots)
11 {
12   PetscInt m, m_local;
13 
14   PetscFunctionBegin;
15   PetscAssertPointer(basis, 1);
16   PetscValidHeaderSpecific(basis->vecs, MAT_CLASSID, 1);
17   PetscCheck(block_type >= 0 && block_type < LMBLOCK_END, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Invalid LMBlockType");
18   PetscCall(PetscNew(dots));
19   (*dots)->m = m      = basis->m;
20   (*dots)->block_type = block_type;
21   PetscCall(MatGetLocalSize(basis->vecs, NULL, &m_local));
22   (*dots)->m_local = m_local;
23   if (block_type == LMBLOCK_DIAGONAL) {
24     VecType vec_type;
25 
26     PetscCall(MatCreateVecs(basis->vecs, &(*dots)->diagonal_global, NULL));
27     PetscCall(VecCreateLocalVector((*dots)->diagonal_global, &(*dots)->diagonal_local));
28     PetscCall(VecGetType((*dots)->diagonal_local, &vec_type));
29     PetscCall(VecCreate(PETSC_COMM_SELF, &(*dots)->diagonal_dup));
30     PetscCall(VecSetSizes((*dots)->diagonal_dup, m, m));
31     PetscCall(VecSetType((*dots)->diagonal_dup, vec_type));
32     PetscCall(VecSetUp((*dots)->diagonal_dup));
33   } else {
34     VecType vec_type;
35 
36     PetscCall(MatGetVecType(basis->vecs, &vec_type));
37     PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)basis->vecs), vec_type, m_local, m_local, m, m, m_local, NULL, &(*dots)->full));
38   }
39   PetscFunctionReturn(PETSC_SUCCESS);
40 }
41 
LMProductsDestroy(LMProducts * dots_p)42 PETSC_INTERN PetscErrorCode LMProductsDestroy(LMProducts *dots_p)
43 {
44   PetscFunctionBegin;
45   LMProducts dots = *dots_p;
46   if (dots == NULL) PetscFunctionReturn(PETSC_SUCCESS);
47   PetscCall(MatDestroy(&dots->full));
48   PetscCall(VecDestroy(&dots->diagonal_dup));
49   PetscCall(VecDestroy(&dots->diagonal_local));
50   PetscCall(VecDestroy(&dots->diagonal_global));
51   PetscCall(VecDestroy(&dots->rhs_local));
52   PetscCall(VecDestroy(&dots->lhs_local));
53   PetscCall(PetscFree(dots));
54   PetscFunctionReturn(PETSC_SUCCESS);
55 }
56 
LMProductsPrepare_Internal(LMProducts dots,PetscObjectId operator_id,PetscObjectState operator_state,PetscInt oldest,PetscInt next)57 static PetscErrorCode LMProductsPrepare_Internal(LMProducts dots, PetscObjectId operator_id, PetscObjectState operator_state, PetscInt oldest, PetscInt next)
58 {
59   PetscFunctionBegin;
60   if (dots->operator_id != operator_id || dots->operator_state != operator_state) {
61     // invalidate the block
62     dots->operator_id    = operator_id;
63     dots->operator_state = operator_state;
64     dots->k              = oldest;
65   }
66   dots->k = PetscMax(oldest, dots->k);
67   PetscFunctionReturn(PETSC_SUCCESS);
68 }
69 
LMProductsPrepareFromBases(LMProducts dots,LMBasis X,LMBasis Y)70 static PetscErrorCode LMProductsPrepareFromBases(LMProducts dots, LMBasis X, LMBasis Y)
71 {
72   PetscInt      oldest, next;
73   PetscObjectId operator_id    = (X->operator_id == 0) ? Y->operator_id : X->operator_id;
74   PetscObjectId operator_state = (X->operator_id == 0) ? Y->operator_state : X->operator_state;
75 
76   PetscFunctionBegin;
77   PetscCall(LMBasisGetRange(X, &oldest, &next));
78   PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
79   PetscFunctionReturn(PETSC_SUCCESS);
80 }
81 
LMProductsPrepare(LMProducts dots,Mat op,PetscInt oldest,PetscInt next)82 PETSC_INTERN PetscErrorCode LMProductsPrepare(LMProducts dots, Mat op, PetscInt oldest, PetscInt next)
83 {
84   PetscObjectId    operator_id;
85   PetscObjectState operator_state;
86 
87   PetscFunctionBegin;
88   PetscCall(PetscObjectGetId((PetscObject)op, &operator_id));
89   PetscCall(PetscObjectStateGet((PetscObject)op, &operator_state));
90   PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
91   PetscFunctionReturn(PETSC_SUCCESS);
92 }
93 
LMProductsUpdate_Internal(LMProducts dots,LMBasis X,LMBasis Y,PetscInt oldest,PetscInt next)94 static PetscErrorCode LMProductsUpdate_Internal(LMProducts dots, LMBasis X, LMBasis Y, PetscInt oldest, PetscInt next)
95 {
96   MPI_Comm comm = PetscObjectComm((PetscObject)X->vecs);
97   PetscInt start;
98 
99   PetscFunctionBegin;
100   PetscAssert(X->m == Y->m && X->m == dots->m, comm, PETSC_ERR_ARG_INCOMP, "X vecs, Y vecs, and dots incompatible in size, (%d, %d, %d)", (int)X->m, (int)Y->m, (int)dots->m);
101   PetscAssert(X->k == Y->k, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are incompatible in state, (%d, %d)", (int)X->k, (int)Y->k);
102   PetscAssert(dots->k <= X->k, comm, PETSC_ERR_ARG_INCOMP, "Dot products are ahead of X and Y, (%d, %d)", (int)dots->k, (int)X->k);
103   PetscAssert(X->operator_id == 0 || Y->operator_id == 0 || X->operator_id == Y->operator_id, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operators");
104   PetscAssert(X->operator_id != Y->operator_id || Y->operator_state == X->operator_state, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operator states");
105 
106   PetscCall(LMProductsPrepareFromBases(dots, X, Y));
107 
108   start = dots->k;
109   if (start == next) PetscFunctionReturn(PETSC_SUCCESS);
110   PetscCall(PetscLogEventBegin(LMPROD_Update, NULL, NULL, NULL, NULL));
111   switch (dots->block_type) {
112   case LMBLOCK_DIAGONAL:
113     for (PetscInt i = start; i < next; i++) {
114       Vec         x, y;
115       PetscScalar xTy;
116 
117       PetscCall(LMBasisGetVecRead(X, i, &x));
118       y = x;
119       if (Y != X) PetscCall(LMBasisGetVecRead(Y, i, &y));
120       PetscCall(VecDot(y, x, &xTy));
121       if (Y != X) PetscCall(LMBasisRestoreVecRead(Y, i, &y));
122       PetscCall(LMBasisRestoreVecRead(X, i, &x));
123       PetscCall(LMProductsInsertNextDiagonalValue(dots, i, xTy));
124     }
125     break;
126   case LMBLOCK_STRICT_UPPER_TRIANGLE: {
127     Mat local;
128 
129     PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
130     // we have to proceed index by index because we want to zero each row after we compute the corresponding column
131     for (PetscInt i = start; i < next; i++) {
132       Mat row;
133       Vec column, y;
134 
135       PetscCall(LMBasisGetVecRead(Y, i, &y));
136       PetscCall(MatDenseGetColumnVec(dots->full, i % dots->m, &column));
137       PetscCall(LMBasisGEMVH(X, oldest, next, 1.0, y, 0.0, column));
138       PetscCall(MatDenseRestoreColumnVec(dots->full, i % dots->m, &column));
139       PetscCall(LMBasisRestoreVecRead(Y, i, &y));
140 
141       // zero out the new row
142       if (dots->m_local) {
143         PetscCall(MatDenseGetSubMatrix(local, i % dots->m, (i % dots->m) + 1, PETSC_DECIDE, PETSC_DECIDE, &row));
144         PetscCall(MatZeroEntries(row));
145         PetscCall(MatDenseRestoreSubMatrix(local, &row));
146       }
147     }
148   } break;
149   case LMBLOCK_UPPER_TRIANGLE: {
150     PetscInt mid       = next - (next % dots->m);
151     PetscInt start_idx = start % dots->m;
152     PetscInt next_idx  = ((next - 1) % dots->m) + 1;
153 
154     if (next_idx > start_idx) {
155       PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
156     } else {
157       PetscCall(LMBasisGEMMH(X, oldest, mid, Y, start, mid, 1.0, 0.0, dots->full));
158       PetscCall(LMBasisGEMMH(X, oldest, next, Y, mid, next, 1.0, 0.0, dots->full));
159     }
160   } break;
161   case LMBLOCK_FULL:
162     PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
163     PetscCall(LMBasisGEMMH(X, start, next, Y, oldest, start, 1.0, 0.0, dots->full));
164     break;
165   default:
166     PetscUnreachable();
167   }
168   dots->k = next;
169   if (dots->debug) {
170     const PetscScalar *values = NULL;
171     PetscInt           lda;
172     PetscInt           N;
173 
174     PetscCall(MatGetSize(X->vecs, &N, NULL));
175     if (dots->block_type == LMBLOCK_DIAGONAL) {
176       lda = 0;
177       if (dots->update_diagonal_global) {
178         PetscCall(VecGetArrayRead(dots->diagonal_global, &values));
179       } else {
180         PetscCall(VecGetArrayRead(dots->diagonal_dup, &values));
181       }
182     } else {
183       PetscCall(MatDenseGetLDA(dots->full, &lda));
184       PetscCall(MatDenseGetArrayRead(dots->full, &values));
185     }
186     for (PetscInt i = oldest; i < next; i++) {
187       Vec       x_i_, x_i;
188       PetscReal x_norm;
189       PetscInt  j_start = oldest;
190       PetscInt  j_end   = next;
191 
192       PetscCall(LMBasisGetVecRead(X, i, &x_i_));
193       PetscCall(VecNorm(x_i_, NORM_1, &x_norm));
194       PetscCall(VecDuplicate(x_i_, &x_i));
195       PetscCall(VecCopy(x_i_, x_i));
196       PetscCall(LMBasisRestoreVecRead(X, i, &x_i_));
197 
198       switch (dots->block_type) {
199       case LMBLOCK_DIAGONAL:
200         j_start = i;
201         j_end   = i + 1;
202         break;
203       case LMBLOCK_UPPER_TRIANGLE:
204         j_start = i;
205         break;
206       case LMBLOCK_STRICT_UPPER_TRIANGLE:
207         j_start = i + 1;
208         break;
209       default:
210         break;
211       }
212       for (PetscInt j = j_start; j < j_end; j++) {
213         Vec         y_j;
214         PetscScalar dot_true, dot = 0.0, diff;
215         PetscReal   y_norm;
216 
217         PetscCall(LMBasisGetVecRead(Y, j, &y_j));
218         PetscCall(VecDot(y_j, x_i, &dot_true));
219         PetscCall(VecNorm(y_j, NORM_1, &y_norm));
220         if (dots->m_local) dot = values[(j % dots->m) * lda + (i % dots->m)];
221         PetscCallMPI(MPI_Bcast(&dot, 1, MPIU_SCALAR, 0, comm));
222         diff = dot_true - dot;
223         if (PetscDefined(USE_COMPLEX)) {
224           PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g + i*%g != VecDot() = %g + i*%g", i, j, (double)PetscRealPart(dot), (double)PetscImaginaryPart(dot), (double)PetscRealPart(dot_true), (double)PetscImaginaryPart(dot_true));
225         } else {
226           PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g != VecDot() = %g", i, j, (double)PetscRealPart(dot), (double)PetscRealPart(dot_true));
227         }
228         PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
229       }
230 
231       PetscCall(VecDestroy(&x_i));
232     }
233 
234     if (dots->block_type == LMBLOCK_DIAGONAL) {
235       if (dots->update_diagonal_global) {
236         PetscCall(VecRestoreArrayRead(dots->diagonal_global, &values));
237       } else {
238         PetscCall(VecRestoreArrayRead(dots->diagonal_dup, &values));
239       }
240     } else {
241       PetscCall(MatDenseRestoreArrayRead(dots->full, &values));
242     }
243   }
244   PetscCall(PetscLogEventEnd(LMPROD_Update, NULL, NULL, NULL, NULL));
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
248 // dots = X^H Y
LMProductsUpdate(LMProducts dots,LMBasis X,LMBasis Y)249 PETSC_INTERN PetscErrorCode LMProductsUpdate(LMProducts dots, LMBasis X, LMBasis Y)
250 {
251   PetscInt oldest, next;
252 
253   PetscFunctionBegin;
254   PetscCall(LMBasisGetRange(X, &oldest, &next));
255   PetscCall(LMProductsUpdate_Internal(dots, X, Y, oldest, next));
256   PetscFunctionReturn(PETSC_SUCCESS);
257 }
258 
LMProductsCopy(LMProducts src,LMProducts dest)259 PETSC_INTERN PetscErrorCode LMProductsCopy(LMProducts src, LMProducts dest)
260 {
261   PetscFunctionBegin;
262   PetscCheck(dest->m == src->m, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
263   PetscCheck(dest->m_local == src->m_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
264   PetscCheck(dest->block_type == src->block_type, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different block type");
265   dest->k       = src->k;
266   dest->m_local = src->m_local;
267   if (src->full) PetscCall(MatCopy(src->full, dest->full, DIFFERENT_NONZERO_PATTERN));
268   if (src->diagonal_dup) PetscCall(VecCopy(src->diagonal_dup, dest->diagonal_dup));
269   if (src->diagonal_global) PetscCall(VecCopy(src->diagonal_global, dest->diagonal_global));
270   dest->update_diagonal_global = src->update_diagonal_global;
271   dest->operator_id            = src->operator_id;
272   dest->operator_state         = src->operator_state;
273   PetscFunctionReturn(PETSC_SUCCESS);
274 }
275 
LMProductsScale(LMProducts dots,PetscScalar scale)276 PETSC_INTERN PetscErrorCode LMProductsScale(LMProducts dots, PetscScalar scale)
277 {
278   PetscFunctionBegin;
279   if (dots->full) PetscCall(MatScale(dots->full, scale));
280   if (dots->diagonal_dup) PetscCall(VecScale(dots->diagonal_dup, scale));
281   if (dots->diagonal_global) PetscCall(VecScale(dots->diagonal_global, scale));
282   PetscFunctionReturn(PETSC_SUCCESS);
283 }
284 
LMProductsGetLocalMatrix(LMProducts dots,Mat * G_local,PetscInt * k,PetscBool * local_is_nonempty)285 PETSC_INTERN PetscErrorCode LMProductsGetLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k, PetscBool *local_is_nonempty)
286 {
287   PetscFunctionBegin;
288   PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for full matrix of diagonal products");
289   PetscCall(MatDenseGetLocalMatrix(dots->full, G_local));
290   if (k) *k = dots->k;
291   if (local_is_nonempty) *local_is_nonempty = (dots->m_local == dots->m) ? PETSC_TRUE : PETSC_FALSE;
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
LMProductsRestoreLocalMatrix(LMProducts dots,Mat * G_local,PetscInt * k)295 PETSC_INTERN PetscErrorCode LMProductsRestoreLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k)
296 {
297   PetscFunctionBegin;
298   if (G_local) *G_local = NULL;
299   if (k) dots->k = *k;
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
LMProductsGetUpdatedDiagonal(LMProducts dots,Vec * diagonal)303 static PetscErrorCode LMProductsGetUpdatedDiagonal(LMProducts dots, Vec *diagonal)
304 {
305   PetscFunctionBegin;
306   if (!dots->update_diagonal_global) {
307     PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
308     if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
309     PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
310     dots->update_diagonal_global = PETSC_TRUE;
311   }
312   if (diagonal) *diagonal = dots->diagonal_global;
313   PetscFunctionReturn(PETSC_SUCCESS);
314 }
315 
LMProductsGetLocalDiagonal(LMProducts dots,Vec * D_local)316 PETSC_INTERN PetscErrorCode LMProductsGetLocalDiagonal(LMProducts dots, Vec *D_local)
317 {
318   PetscFunctionBegin;
319   PetscCall(LMProductsGetUpdatedDiagonal(dots, NULL));
320   PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
321   *D_local = dots->diagonal_local;
322   PetscFunctionReturn(PETSC_SUCCESS);
323 }
324 
LMProductsRestoreLocalDiagonal(LMProducts dots,Vec * D_local)325 PETSC_INTERN PetscErrorCode LMProductsRestoreLocalDiagonal(LMProducts dots, Vec *D_local)
326 {
327   PetscFunctionBegin;
328   PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
329   *D_local = NULL;
330   PetscFunctionReturn(PETSC_SUCCESS);
331 }
332 
LMProductsGetNextColumn(LMProducts dots,Vec * col)333 PETSC_INTERN PetscErrorCode LMProductsGetNextColumn(LMProducts dots, Vec *col)
334 {
335   PetscFunctionBegin;
336   PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for column of diagonal products");
337   PetscCall(MatDenseGetColumnVecWrite(dots->full, dots->k % dots->m, col));
338   PetscFunctionReturn(PETSC_SUCCESS);
339 }
340 
LMProductsRestoreNextColumn(LMProducts dots,Vec * col)341 PETSC_INTERN PetscErrorCode LMProductsRestoreNextColumn(LMProducts dots, Vec *col)
342 {
343   PetscFunctionBegin;
344   PetscCall(MatDenseRestoreColumnVecWrite(dots->full, dots->k % dots->m, col));
345   dots->k++;
346   PetscFunctionReturn(PETSC_SUCCESS);
347 }
348 
349 // copy conj(triu(G)) into tril(G)
LMProductsMakeHermitian(Mat local,PetscInt oldest,PetscInt next)350 PETSC_INTERN PetscErrorCode LMProductsMakeHermitian(Mat local, PetscInt oldest, PetscInt next)
351 {
352   PetscInt m;
353 
354   PetscFunctionBegin;
355   PetscCall(MatGetLocalSize(local, &m, NULL));
356   if (m) {
357     // TODO: implement on device?
358     PetscScalar *a;
359     PetscInt     lda;
360 
361     PetscCall(MatDenseGetLDA(local, &lda));
362     PetscCall(MatDenseGetArray(local, &a));
363     for (PetscInt j_ = oldest; j_ < next; j_++) {
364       PetscInt j = j_ % m;
365 
366       a[j + j * lda] = PetscRealPart(a[j + j * lda]);
367       for (PetscInt i_ = j_ + 1; i_ < next; i_++) {
368         PetscInt i = i_ % m;
369 
370         a[i + j * lda] = PetscConj(a[j + i * lda]);
371       }
372     }
373   }
374   PetscFunctionReturn(PETSC_SUCCESS);
375 }
376 
LMProductsSolve(LMProducts dots,PetscInt oldest,PetscInt next,Vec b,Vec x,PetscBool hermitian_transpose)377 PETSC_INTERN PetscErrorCode LMProductsSolve(LMProducts dots, PetscInt oldest, PetscInt next, Vec b, Vec x, PetscBool hermitian_transpose)
378 {
379   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
380   PetscInt dots_next   = dots->k;
381   Mat      local;
382   Vec      diag = NULL;
383 
384   PetscFunctionBegin;
385   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
386   if (oldest >= next) PetscFunctionReturn(PETSC_SUCCESS);
387   PetscCall(PetscLogEventBegin(LMPROD_Solve, NULL, NULL, NULL, NULL));
388   if (!dots->rhs_local) PetscCall(VecCreateLocalVector(b, &dots->rhs_local));
389   if (!dots->lhs_local) PetscCall(VecDuplicate(dots->rhs_local, &dots->lhs_local));
390   switch (dots->block_type) {
391   case LMBLOCK_DIAGONAL:
392     PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
393     PetscCall(VecDSVCyclic(hermitian_transpose, oldest, next, diag, b, x));
394     break;
395   case LMBLOCK_UPPER_TRIANGLE:
396     PetscCall(MatSeqDenseTRSVCyclic(hermitian_transpose, oldest, next, dots->full, b, x));
397     break;
398   default: {
399     PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
400     PetscCall(VecGetLocalVector(b, dots->rhs_local));
401     PetscCall(VecGetLocalVector(x, dots->lhs_local));
402     if (dots->m_local) {
403       if (!hermitian_transpose) {
404         PetscCall(MatSolve(local, dots->rhs_local, dots->lhs_local));
405       } else {
406         Vec rhs_conj = dots->rhs_local;
407 
408         if (PetscDefined(USE_COMPLEX)) {
409           PetscCall(VecDuplicate(dots->rhs_local, &rhs_conj));
410           PetscCall(VecCopy(dots->rhs_local, rhs_conj));
411           PetscCall(VecConjugate(rhs_conj));
412         }
413         PetscCall(MatSolveTranspose(local, rhs_conj, dots->lhs_local));
414         if (PetscDefined(USE_COMPLEX)) {
415           PetscCall(VecConjugate(dots->lhs_local));
416           PetscCall(VecDestroy(&rhs_conj));
417         }
418       }
419     }
420     if (x != b) PetscCall(VecRestoreLocalVector(x, dots->lhs_local));
421     PetscCall(VecRestoreLocalVector(b, dots->rhs_local));
422   } break;
423   }
424   PetscCall(PetscLogEventEnd(LMPROD_Solve, NULL, NULL, NULL, NULL));
425   PetscFunctionReturn(PETSC_SUCCESS);
426 }
427 
LMProductsMult(LMProducts dots,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y,PetscBool hermitian_transpose)428 PETSC_INTERN PetscErrorCode LMProductsMult(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y, PetscBool hermitian_transpose)
429 {
430   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
431   PetscInt dots_next   = dots->k;
432   Vec      diag        = NULL;
433 
434   PetscFunctionBegin;
435   PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
436   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
437   switch (dots->block_type) {
438   case LMBLOCK_DIAGONAL: {
439     PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
440     PetscCall(VecDMVCyclic(hermitian_transpose, oldest, next, alpha, diag, x, beta, y));
441   } break;
442   case LMBLOCK_STRICT_UPPER_TRIANGLE: // the lower triangle has been zeroed, MatMult() is safe
443   case LMBLOCK_FULL:
444     PetscCall(MatSeqDenseGEMVCyclic(hermitian_transpose, oldest, next, alpha, dots->full, x, beta, y));
445     break;
446   default:
447     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
448   }
449   PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
450   PetscFunctionReturn(PETSC_SUCCESS);
451 }
452 
LMProductsMultHermitian(LMProducts dots,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)453 PETSC_INTERN PetscErrorCode LMProductsMultHermitian(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
454 {
455   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
456   PetscInt dots_next   = dots->k;
457 
458   PetscFunctionBegin;
459   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
460   if (dots->block_type == LMBLOCK_DIAGONAL) PetscCall(LMProductsMult(dots, oldest, next, alpha, x, beta, y, PETSC_FALSE));
461   else {
462     PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
463     PetscCall(MatSeqDenseHEMVCyclic(oldest, next, alpha, dots->full, x, beta, y));
464     PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
465   }
466   PetscFunctionReturn(PETSC_SUCCESS);
467 }
468 
LMProductsReset(LMProducts dots)469 PETSC_INTERN PetscErrorCode LMProductsReset(LMProducts dots)
470 {
471   PetscFunctionBegin;
472   if (dots) {
473     dots->k              = 0;
474     dots->operator_id    = 0;
475     dots->operator_state = 0;
476     if (dots->full) {
477       Mat full_local;
478 
479       PetscCall(MatDenseGetLocalMatrix(dots->full, &full_local));
480       PetscCall(MatSetUnfactored(full_local));
481       PetscCall(MatZeroEntries(full_local));
482     }
483     if (dots->diagonal_global) PetscCall(VecZeroEntries(dots->diagonal_dup));
484     if (dots->diagonal_dup) PetscCall(VecZeroEntries(dots->diagonal_dup));
485   }
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488 
LMProductsGetDiagonalValue(LMProducts dots,PetscInt i,PetscScalar * v)489 PETSC_INTERN PetscErrorCode LMProductsGetDiagonalValue(LMProducts dots, PetscInt i, PetscScalar *v)
490 {
491   PetscFunctionBegin;
492   PetscInt oldest = PetscMax(0, dots->k - dots->m);
493   PetscInt next   = dots->k;
494   PetscInt idx    = i % dots->m;
495   PetscCheck(i >= oldest && i < next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Inserting value %d out of range [%d, %d)", (int)i, (int)oldest, (int)next);
496   PetscCall(VecGetValues(dots->diagonal_dup, 1, &idx, v));
497   PetscFunctionReturn(PETSC_SUCCESS);
498 }
499 
LMProductsInsertNextDiagonalValue(LMProducts dots,PetscInt i,PetscScalar v)500 PETSC_INTERN PetscErrorCode LMProductsInsertNextDiagonalValue(LMProducts dots, PetscInt i, PetscScalar v)
501 {
502   PetscInt idx = i % dots->m;
503 
504   PetscFunctionBegin;
505   PetscCheck(i == dots->k, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is not the next index (%" PetscInt_FMT ")", i, dots->k);
506   PetscCall(VecSetValue(dots->diagonal_dup, idx, v, INSERT_VALUES));
507   if (dots->update_diagonal_global) {
508     PetscScalar *array;
509     PetscMemType memtype;
510 
511     PetscCall(VecGetArrayAndMemType(dots->diagonal_global, &array, &memtype));
512     if (dots->m_local > 0) {
513       if (PetscMemTypeHost(memtype)) {
514         array[idx] = v;
515         PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
516       } else {
517         PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
518         PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
519         if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
520         PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
521       }
522     } else {
523       PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
524     }
525   }
526   dots->k++;
527   PetscFunctionReturn(PETSC_SUCCESS);
528 }
529 
LMProductsOnesOnUnusedDiagonal(Mat A,PetscInt oldest,PetscInt next)530 PETSC_INTERN PetscErrorCode LMProductsOnesOnUnusedDiagonal(Mat A, PetscInt oldest, PetscInt next)
531 {
532   PetscInt m;
533   Mat      sub;
534 
535   PetscFunctionBegin;
536   PetscCall(MatGetSize(A, &m, NULL));
537   // we could handle the general case but this is the only case used by MatLMVM
538   PetscCheck((next < m && oldest == 0) || next - oldest == m, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "General case not implemented");
539   if (next - oldest == m) PetscFunctionReturn(PETSC_SUCCESS); // nothing to do if all entries are used
540   PetscCall(MatDenseGetSubMatrix(A, next, m, next, m, &sub));
541   PetscCall(MatShift(sub, 1.0));
542   PetscCall(MatDenseRestoreSubMatrix(A, &sub));
543   PetscFunctionReturn(PETSC_SUCCESS);
544 }
545