1 #include <petsc/private/petscimpl.h>
2 #include <petscmat.h>
3 #include <petscblaslapack.h>
4 #include <petscdevice.h>
5 #include "lmproducts.h"
6 #include "blas_cyclic/blas_cyclic.h"
7
8 PetscLogEvent LMPROD_Mult, LMPROD_Solve, LMPROD_Update;
9
LMProductsCreate(LMBasis basis,LMBlockType block_type,LMProducts * dots)10 PETSC_INTERN PetscErrorCode LMProductsCreate(LMBasis basis, LMBlockType block_type, LMProducts *dots)
11 {
12 PetscInt m, m_local;
13
14 PetscFunctionBegin;
15 PetscAssertPointer(basis, 1);
16 PetscValidHeaderSpecific(basis->vecs, MAT_CLASSID, 1);
17 PetscCheck(block_type >= 0 && block_type < LMBLOCK_END, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Invalid LMBlockType");
18 PetscCall(PetscNew(dots));
19 (*dots)->m = m = basis->m;
20 (*dots)->block_type = block_type;
21 PetscCall(MatGetLocalSize(basis->vecs, NULL, &m_local));
22 (*dots)->m_local = m_local;
23 if (block_type == LMBLOCK_DIAGONAL) {
24 VecType vec_type;
25
26 PetscCall(MatCreateVecs(basis->vecs, &(*dots)->diagonal_global, NULL));
27 PetscCall(VecCreateLocalVector((*dots)->diagonal_global, &(*dots)->diagonal_local));
28 PetscCall(VecGetType((*dots)->diagonal_local, &vec_type));
29 PetscCall(VecCreate(PETSC_COMM_SELF, &(*dots)->diagonal_dup));
30 PetscCall(VecSetSizes((*dots)->diagonal_dup, m, m));
31 PetscCall(VecSetType((*dots)->diagonal_dup, vec_type));
32 PetscCall(VecSetUp((*dots)->diagonal_dup));
33 } else {
34 VecType vec_type;
35
36 PetscCall(MatGetVecType(basis->vecs, &vec_type));
37 PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)basis->vecs), vec_type, m_local, m_local, m, m, m_local, NULL, &(*dots)->full));
38 }
39 PetscFunctionReturn(PETSC_SUCCESS);
40 }
41
LMProductsDestroy(LMProducts * dots_p)42 PETSC_INTERN PetscErrorCode LMProductsDestroy(LMProducts *dots_p)
43 {
44 PetscFunctionBegin;
45 LMProducts dots = *dots_p;
46 if (dots == NULL) PetscFunctionReturn(PETSC_SUCCESS);
47 PetscCall(MatDestroy(&dots->full));
48 PetscCall(VecDestroy(&dots->diagonal_dup));
49 PetscCall(VecDestroy(&dots->diagonal_local));
50 PetscCall(VecDestroy(&dots->diagonal_global));
51 PetscCall(VecDestroy(&dots->rhs_local));
52 PetscCall(VecDestroy(&dots->lhs_local));
53 PetscCall(PetscFree(dots));
54 PetscFunctionReturn(PETSC_SUCCESS);
55 }
56
LMProductsPrepare_Internal(LMProducts dots,PetscObjectId operator_id,PetscObjectState operator_state,PetscInt oldest,PetscInt next)57 static PetscErrorCode LMProductsPrepare_Internal(LMProducts dots, PetscObjectId operator_id, PetscObjectState operator_state, PetscInt oldest, PetscInt next)
58 {
59 PetscFunctionBegin;
60 if (dots->operator_id != operator_id || dots->operator_state != operator_state) {
61 // invalidate the block
62 dots->operator_id = operator_id;
63 dots->operator_state = operator_state;
64 dots->k = oldest;
65 }
66 dots->k = PetscMax(oldest, dots->k);
67 PetscFunctionReturn(PETSC_SUCCESS);
68 }
69
LMProductsPrepareFromBases(LMProducts dots,LMBasis X,LMBasis Y)70 static PetscErrorCode LMProductsPrepareFromBases(LMProducts dots, LMBasis X, LMBasis Y)
71 {
72 PetscInt oldest, next;
73 PetscObjectId operator_id = (X->operator_id == 0) ? Y->operator_id : X->operator_id;
74 PetscObjectId operator_state = (X->operator_id == 0) ? Y->operator_state : X->operator_state;
75
76 PetscFunctionBegin;
77 PetscCall(LMBasisGetRange(X, &oldest, &next));
78 PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
79 PetscFunctionReturn(PETSC_SUCCESS);
80 }
81
LMProductsPrepare(LMProducts dots,Mat op,PetscInt oldest,PetscInt next)82 PETSC_INTERN PetscErrorCode LMProductsPrepare(LMProducts dots, Mat op, PetscInt oldest, PetscInt next)
83 {
84 PetscObjectId operator_id;
85 PetscObjectState operator_state;
86
87 PetscFunctionBegin;
88 PetscCall(PetscObjectGetId((PetscObject)op, &operator_id));
89 PetscCall(PetscObjectStateGet((PetscObject)op, &operator_state));
90 PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
91 PetscFunctionReturn(PETSC_SUCCESS);
92 }
93
LMProductsUpdate_Internal(LMProducts dots,LMBasis X,LMBasis Y,PetscInt oldest,PetscInt next)94 static PetscErrorCode LMProductsUpdate_Internal(LMProducts dots, LMBasis X, LMBasis Y, PetscInt oldest, PetscInt next)
95 {
96 MPI_Comm comm = PetscObjectComm((PetscObject)X->vecs);
97 PetscInt start;
98
99 PetscFunctionBegin;
100 PetscAssert(X->m == Y->m && X->m == dots->m, comm, PETSC_ERR_ARG_INCOMP, "X vecs, Y vecs, and dots incompatible in size, (%d, %d, %d)", (int)X->m, (int)Y->m, (int)dots->m);
101 PetscAssert(X->k == Y->k, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are incompatible in state, (%d, %d)", (int)X->k, (int)Y->k);
102 PetscAssert(dots->k <= X->k, comm, PETSC_ERR_ARG_INCOMP, "Dot products are ahead of X and Y, (%d, %d)", (int)dots->k, (int)X->k);
103 PetscAssert(X->operator_id == 0 || Y->operator_id == 0 || X->operator_id == Y->operator_id, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operators");
104 PetscAssert(X->operator_id != Y->operator_id || Y->operator_state == X->operator_state, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operator states");
105
106 PetscCall(LMProductsPrepareFromBases(dots, X, Y));
107
108 start = dots->k;
109 if (start == next) PetscFunctionReturn(PETSC_SUCCESS);
110 PetscCall(PetscLogEventBegin(LMPROD_Update, NULL, NULL, NULL, NULL));
111 switch (dots->block_type) {
112 case LMBLOCK_DIAGONAL:
113 for (PetscInt i = start; i < next; i++) {
114 Vec x, y;
115 PetscScalar xTy;
116
117 PetscCall(LMBasisGetVecRead(X, i, &x));
118 y = x;
119 if (Y != X) PetscCall(LMBasisGetVecRead(Y, i, &y));
120 PetscCall(VecDot(y, x, &xTy));
121 if (Y != X) PetscCall(LMBasisRestoreVecRead(Y, i, &y));
122 PetscCall(LMBasisRestoreVecRead(X, i, &x));
123 PetscCall(LMProductsInsertNextDiagonalValue(dots, i, xTy));
124 }
125 break;
126 case LMBLOCK_STRICT_UPPER_TRIANGLE: {
127 Mat local;
128
129 PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
130 // we have to proceed index by index because we want to zero each row after we compute the corresponding column
131 for (PetscInt i = start; i < next; i++) {
132 Mat row;
133 Vec column, y;
134
135 PetscCall(LMBasisGetVecRead(Y, i, &y));
136 PetscCall(MatDenseGetColumnVec(dots->full, i % dots->m, &column));
137 PetscCall(LMBasisGEMVH(X, oldest, next, 1.0, y, 0.0, column));
138 PetscCall(MatDenseRestoreColumnVec(dots->full, i % dots->m, &column));
139 PetscCall(LMBasisRestoreVecRead(Y, i, &y));
140
141 // zero out the new row
142 if (dots->m_local) {
143 PetscCall(MatDenseGetSubMatrix(local, i % dots->m, (i % dots->m) + 1, PETSC_DECIDE, PETSC_DECIDE, &row));
144 PetscCall(MatZeroEntries(row));
145 PetscCall(MatDenseRestoreSubMatrix(local, &row));
146 }
147 }
148 } break;
149 case LMBLOCK_UPPER_TRIANGLE: {
150 PetscInt mid = next - (next % dots->m);
151 PetscInt start_idx = start % dots->m;
152 PetscInt next_idx = ((next - 1) % dots->m) + 1;
153
154 if (next_idx > start_idx) {
155 PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
156 } else {
157 PetscCall(LMBasisGEMMH(X, oldest, mid, Y, start, mid, 1.0, 0.0, dots->full));
158 PetscCall(LMBasisGEMMH(X, oldest, next, Y, mid, next, 1.0, 0.0, dots->full));
159 }
160 } break;
161 case LMBLOCK_FULL:
162 PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
163 PetscCall(LMBasisGEMMH(X, start, next, Y, oldest, start, 1.0, 0.0, dots->full));
164 break;
165 default:
166 PetscUnreachable();
167 }
168 dots->k = next;
169 if (dots->debug) {
170 const PetscScalar *values = NULL;
171 PetscInt lda;
172 PetscInt N;
173
174 PetscCall(MatGetSize(X->vecs, &N, NULL));
175 if (dots->block_type == LMBLOCK_DIAGONAL) {
176 lda = 0;
177 if (dots->update_diagonal_global) {
178 PetscCall(VecGetArrayRead(dots->diagonal_global, &values));
179 } else {
180 PetscCall(VecGetArrayRead(dots->diagonal_dup, &values));
181 }
182 } else {
183 PetscCall(MatDenseGetLDA(dots->full, &lda));
184 PetscCall(MatDenseGetArrayRead(dots->full, &values));
185 }
186 for (PetscInt i = oldest; i < next; i++) {
187 Vec x_i_, x_i;
188 PetscReal x_norm;
189 PetscInt j_start = oldest;
190 PetscInt j_end = next;
191
192 PetscCall(LMBasisGetVecRead(X, i, &x_i_));
193 PetscCall(VecNorm(x_i_, NORM_1, &x_norm));
194 PetscCall(VecDuplicate(x_i_, &x_i));
195 PetscCall(VecCopy(x_i_, x_i));
196 PetscCall(LMBasisRestoreVecRead(X, i, &x_i_));
197
198 switch (dots->block_type) {
199 case LMBLOCK_DIAGONAL:
200 j_start = i;
201 j_end = i + 1;
202 break;
203 case LMBLOCK_UPPER_TRIANGLE:
204 j_start = i;
205 break;
206 case LMBLOCK_STRICT_UPPER_TRIANGLE:
207 j_start = i + 1;
208 break;
209 default:
210 break;
211 }
212 for (PetscInt j = j_start; j < j_end; j++) {
213 Vec y_j;
214 PetscScalar dot_true, dot = 0.0, diff;
215 PetscReal y_norm;
216
217 PetscCall(LMBasisGetVecRead(Y, j, &y_j));
218 PetscCall(VecDot(y_j, x_i, &dot_true));
219 PetscCall(VecNorm(y_j, NORM_1, &y_norm));
220 if (dots->m_local) dot = values[(j % dots->m) * lda + (i % dots->m)];
221 PetscCallMPI(MPI_Bcast(&dot, 1, MPIU_SCALAR, 0, comm));
222 diff = dot_true - dot;
223 if (PetscDefined(USE_COMPLEX)) {
224 PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g + i*%g != VecDot() = %g + i*%g", i, j, (double)PetscRealPart(dot), (double)PetscImaginaryPart(dot), (double)PetscRealPart(dot_true), (double)PetscImaginaryPart(dot_true));
225 } else {
226 PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g != VecDot() = %g", i, j, (double)PetscRealPart(dot), (double)PetscRealPart(dot_true));
227 }
228 PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
229 }
230
231 PetscCall(VecDestroy(&x_i));
232 }
233
234 if (dots->block_type == LMBLOCK_DIAGONAL) {
235 if (dots->update_diagonal_global) {
236 PetscCall(VecRestoreArrayRead(dots->diagonal_global, &values));
237 } else {
238 PetscCall(VecRestoreArrayRead(dots->diagonal_dup, &values));
239 }
240 } else {
241 PetscCall(MatDenseRestoreArrayRead(dots->full, &values));
242 }
243 }
244 PetscCall(PetscLogEventEnd(LMPROD_Update, NULL, NULL, NULL, NULL));
245 PetscFunctionReturn(PETSC_SUCCESS);
246 }
247
248 // dots = X^H Y
LMProductsUpdate(LMProducts dots,LMBasis X,LMBasis Y)249 PETSC_INTERN PetscErrorCode LMProductsUpdate(LMProducts dots, LMBasis X, LMBasis Y)
250 {
251 PetscInt oldest, next;
252
253 PetscFunctionBegin;
254 PetscCall(LMBasisGetRange(X, &oldest, &next));
255 PetscCall(LMProductsUpdate_Internal(dots, X, Y, oldest, next));
256 PetscFunctionReturn(PETSC_SUCCESS);
257 }
258
LMProductsCopy(LMProducts src,LMProducts dest)259 PETSC_INTERN PetscErrorCode LMProductsCopy(LMProducts src, LMProducts dest)
260 {
261 PetscFunctionBegin;
262 PetscCheck(dest->m == src->m, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
263 PetscCheck(dest->m_local == src->m_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
264 PetscCheck(dest->block_type == src->block_type, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different block type");
265 dest->k = src->k;
266 dest->m_local = src->m_local;
267 if (src->full) PetscCall(MatCopy(src->full, dest->full, DIFFERENT_NONZERO_PATTERN));
268 if (src->diagonal_dup) PetscCall(VecCopy(src->diagonal_dup, dest->diagonal_dup));
269 if (src->diagonal_global) PetscCall(VecCopy(src->diagonal_global, dest->diagonal_global));
270 dest->update_diagonal_global = src->update_diagonal_global;
271 dest->operator_id = src->operator_id;
272 dest->operator_state = src->operator_state;
273 PetscFunctionReturn(PETSC_SUCCESS);
274 }
275
LMProductsScale(LMProducts dots,PetscScalar scale)276 PETSC_INTERN PetscErrorCode LMProductsScale(LMProducts dots, PetscScalar scale)
277 {
278 PetscFunctionBegin;
279 if (dots->full) PetscCall(MatScale(dots->full, scale));
280 if (dots->diagonal_dup) PetscCall(VecScale(dots->diagonal_dup, scale));
281 if (dots->diagonal_global) PetscCall(VecScale(dots->diagonal_global, scale));
282 PetscFunctionReturn(PETSC_SUCCESS);
283 }
284
LMProductsGetLocalMatrix(LMProducts dots,Mat * G_local,PetscInt * k,PetscBool * local_is_nonempty)285 PETSC_INTERN PetscErrorCode LMProductsGetLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k, PetscBool *local_is_nonempty)
286 {
287 PetscFunctionBegin;
288 PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for full matrix of diagonal products");
289 PetscCall(MatDenseGetLocalMatrix(dots->full, G_local));
290 if (k) *k = dots->k;
291 if (local_is_nonempty) *local_is_nonempty = (dots->m_local == dots->m) ? PETSC_TRUE : PETSC_FALSE;
292 PetscFunctionReturn(PETSC_SUCCESS);
293 }
294
LMProductsRestoreLocalMatrix(LMProducts dots,Mat * G_local,PetscInt * k)295 PETSC_INTERN PetscErrorCode LMProductsRestoreLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k)
296 {
297 PetscFunctionBegin;
298 if (G_local) *G_local = NULL;
299 if (k) dots->k = *k;
300 PetscFunctionReturn(PETSC_SUCCESS);
301 }
302
LMProductsGetUpdatedDiagonal(LMProducts dots,Vec * diagonal)303 static PetscErrorCode LMProductsGetUpdatedDiagonal(LMProducts dots, Vec *diagonal)
304 {
305 PetscFunctionBegin;
306 if (!dots->update_diagonal_global) {
307 PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
308 if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
309 PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
310 dots->update_diagonal_global = PETSC_TRUE;
311 }
312 if (diagonal) *diagonal = dots->diagonal_global;
313 PetscFunctionReturn(PETSC_SUCCESS);
314 }
315
LMProductsGetLocalDiagonal(LMProducts dots,Vec * D_local)316 PETSC_INTERN PetscErrorCode LMProductsGetLocalDiagonal(LMProducts dots, Vec *D_local)
317 {
318 PetscFunctionBegin;
319 PetscCall(LMProductsGetUpdatedDiagonal(dots, NULL));
320 PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
321 *D_local = dots->diagonal_local;
322 PetscFunctionReturn(PETSC_SUCCESS);
323 }
324
LMProductsRestoreLocalDiagonal(LMProducts dots,Vec * D_local)325 PETSC_INTERN PetscErrorCode LMProductsRestoreLocalDiagonal(LMProducts dots, Vec *D_local)
326 {
327 PetscFunctionBegin;
328 PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
329 *D_local = NULL;
330 PetscFunctionReturn(PETSC_SUCCESS);
331 }
332
LMProductsGetNextColumn(LMProducts dots,Vec * col)333 PETSC_INTERN PetscErrorCode LMProductsGetNextColumn(LMProducts dots, Vec *col)
334 {
335 PetscFunctionBegin;
336 PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for column of diagonal products");
337 PetscCall(MatDenseGetColumnVecWrite(dots->full, dots->k % dots->m, col));
338 PetscFunctionReturn(PETSC_SUCCESS);
339 }
340
LMProductsRestoreNextColumn(LMProducts dots,Vec * col)341 PETSC_INTERN PetscErrorCode LMProductsRestoreNextColumn(LMProducts dots, Vec *col)
342 {
343 PetscFunctionBegin;
344 PetscCall(MatDenseRestoreColumnVecWrite(dots->full, dots->k % dots->m, col));
345 dots->k++;
346 PetscFunctionReturn(PETSC_SUCCESS);
347 }
348
349 // copy conj(triu(G)) into tril(G)
LMProductsMakeHermitian(Mat local,PetscInt oldest,PetscInt next)350 PETSC_INTERN PetscErrorCode LMProductsMakeHermitian(Mat local, PetscInt oldest, PetscInt next)
351 {
352 PetscInt m;
353
354 PetscFunctionBegin;
355 PetscCall(MatGetLocalSize(local, &m, NULL));
356 if (m) {
357 // TODO: implement on device?
358 PetscScalar *a;
359 PetscInt lda;
360
361 PetscCall(MatDenseGetLDA(local, &lda));
362 PetscCall(MatDenseGetArray(local, &a));
363 for (PetscInt j_ = oldest; j_ < next; j_++) {
364 PetscInt j = j_ % m;
365
366 a[j + j * lda] = PetscRealPart(a[j + j * lda]);
367 for (PetscInt i_ = j_ + 1; i_ < next; i_++) {
368 PetscInt i = i_ % m;
369
370 a[i + j * lda] = PetscConj(a[j + i * lda]);
371 }
372 }
373 }
374 PetscFunctionReturn(PETSC_SUCCESS);
375 }
376
LMProductsSolve(LMProducts dots,PetscInt oldest,PetscInt next,Vec b,Vec x,PetscBool hermitian_transpose)377 PETSC_INTERN PetscErrorCode LMProductsSolve(LMProducts dots, PetscInt oldest, PetscInt next, Vec b, Vec x, PetscBool hermitian_transpose)
378 {
379 PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
380 PetscInt dots_next = dots->k;
381 Mat local;
382 Vec diag = NULL;
383
384 PetscFunctionBegin;
385 PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
386 if (oldest >= next) PetscFunctionReturn(PETSC_SUCCESS);
387 PetscCall(PetscLogEventBegin(LMPROD_Solve, NULL, NULL, NULL, NULL));
388 if (!dots->rhs_local) PetscCall(VecCreateLocalVector(b, &dots->rhs_local));
389 if (!dots->lhs_local) PetscCall(VecDuplicate(dots->rhs_local, &dots->lhs_local));
390 switch (dots->block_type) {
391 case LMBLOCK_DIAGONAL:
392 PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
393 PetscCall(VecDSVCyclic(hermitian_transpose, oldest, next, diag, b, x));
394 break;
395 case LMBLOCK_UPPER_TRIANGLE:
396 PetscCall(MatSeqDenseTRSVCyclic(hermitian_transpose, oldest, next, dots->full, b, x));
397 break;
398 default: {
399 PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
400 PetscCall(VecGetLocalVector(b, dots->rhs_local));
401 PetscCall(VecGetLocalVector(x, dots->lhs_local));
402 if (dots->m_local) {
403 if (!hermitian_transpose) {
404 PetscCall(MatSolve(local, dots->rhs_local, dots->lhs_local));
405 } else {
406 Vec rhs_conj = dots->rhs_local;
407
408 if (PetscDefined(USE_COMPLEX)) {
409 PetscCall(VecDuplicate(dots->rhs_local, &rhs_conj));
410 PetscCall(VecCopy(dots->rhs_local, rhs_conj));
411 PetscCall(VecConjugate(rhs_conj));
412 }
413 PetscCall(MatSolveTranspose(local, rhs_conj, dots->lhs_local));
414 if (PetscDefined(USE_COMPLEX)) {
415 PetscCall(VecConjugate(dots->lhs_local));
416 PetscCall(VecDestroy(&rhs_conj));
417 }
418 }
419 }
420 if (x != b) PetscCall(VecRestoreLocalVector(x, dots->lhs_local));
421 PetscCall(VecRestoreLocalVector(b, dots->rhs_local));
422 } break;
423 }
424 PetscCall(PetscLogEventEnd(LMPROD_Solve, NULL, NULL, NULL, NULL));
425 PetscFunctionReturn(PETSC_SUCCESS);
426 }
427
LMProductsMult(LMProducts dots,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y,PetscBool hermitian_transpose)428 PETSC_INTERN PetscErrorCode LMProductsMult(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y, PetscBool hermitian_transpose)
429 {
430 PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
431 PetscInt dots_next = dots->k;
432 Vec diag = NULL;
433
434 PetscFunctionBegin;
435 PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
436 PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
437 switch (dots->block_type) {
438 case LMBLOCK_DIAGONAL: {
439 PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
440 PetscCall(VecDMVCyclic(hermitian_transpose, oldest, next, alpha, diag, x, beta, y));
441 } break;
442 case LMBLOCK_STRICT_UPPER_TRIANGLE: // the lower triangle has been zeroed, MatMult() is safe
443 case LMBLOCK_FULL:
444 PetscCall(MatSeqDenseGEMVCyclic(hermitian_transpose, oldest, next, alpha, dots->full, x, beta, y));
445 break;
446 default:
447 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
448 }
449 PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
450 PetscFunctionReturn(PETSC_SUCCESS);
451 }
452
LMProductsMultHermitian(LMProducts dots,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)453 PETSC_INTERN PetscErrorCode LMProductsMultHermitian(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
454 {
455 PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
456 PetscInt dots_next = dots->k;
457
458 PetscFunctionBegin;
459 PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
460 if (dots->block_type == LMBLOCK_DIAGONAL) PetscCall(LMProductsMult(dots, oldest, next, alpha, x, beta, y, PETSC_FALSE));
461 else {
462 PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
463 PetscCall(MatSeqDenseHEMVCyclic(oldest, next, alpha, dots->full, x, beta, y));
464 PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
465 }
466 PetscFunctionReturn(PETSC_SUCCESS);
467 }
468
LMProductsReset(LMProducts dots)469 PETSC_INTERN PetscErrorCode LMProductsReset(LMProducts dots)
470 {
471 PetscFunctionBegin;
472 if (dots) {
473 dots->k = 0;
474 dots->operator_id = 0;
475 dots->operator_state = 0;
476 if (dots->full) {
477 Mat full_local;
478
479 PetscCall(MatDenseGetLocalMatrix(dots->full, &full_local));
480 PetscCall(MatSetUnfactored(full_local));
481 PetscCall(MatZeroEntries(full_local));
482 }
483 if (dots->diagonal_global) PetscCall(VecZeroEntries(dots->diagonal_dup));
484 if (dots->diagonal_dup) PetscCall(VecZeroEntries(dots->diagonal_dup));
485 }
486 PetscFunctionReturn(PETSC_SUCCESS);
487 }
488
LMProductsGetDiagonalValue(LMProducts dots,PetscInt i,PetscScalar * v)489 PETSC_INTERN PetscErrorCode LMProductsGetDiagonalValue(LMProducts dots, PetscInt i, PetscScalar *v)
490 {
491 PetscFunctionBegin;
492 PetscInt oldest = PetscMax(0, dots->k - dots->m);
493 PetscInt next = dots->k;
494 PetscInt idx = i % dots->m;
495 PetscCheck(i >= oldest && i < next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Inserting value %d out of range [%d, %d)", (int)i, (int)oldest, (int)next);
496 PetscCall(VecGetValues(dots->diagonal_dup, 1, &idx, v));
497 PetscFunctionReturn(PETSC_SUCCESS);
498 }
499
LMProductsInsertNextDiagonalValue(LMProducts dots,PetscInt i,PetscScalar v)500 PETSC_INTERN PetscErrorCode LMProductsInsertNextDiagonalValue(LMProducts dots, PetscInt i, PetscScalar v)
501 {
502 PetscInt idx = i % dots->m;
503
504 PetscFunctionBegin;
505 PetscCheck(i == dots->k, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is not the next index (%" PetscInt_FMT ")", i, dots->k);
506 PetscCall(VecSetValue(dots->diagonal_dup, idx, v, INSERT_VALUES));
507 if (dots->update_diagonal_global) {
508 PetscScalar *array;
509 PetscMemType memtype;
510
511 PetscCall(VecGetArrayAndMemType(dots->diagonal_global, &array, &memtype));
512 if (dots->m_local > 0) {
513 if (PetscMemTypeHost(memtype)) {
514 array[idx] = v;
515 PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
516 } else {
517 PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
518 PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
519 if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
520 PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
521 }
522 } else {
523 PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
524 }
525 }
526 dots->k++;
527 PetscFunctionReturn(PETSC_SUCCESS);
528 }
529
LMProductsOnesOnUnusedDiagonal(Mat A,PetscInt oldest,PetscInt next)530 PETSC_INTERN PetscErrorCode LMProductsOnesOnUnusedDiagonal(Mat A, PetscInt oldest, PetscInt next)
531 {
532 PetscInt m;
533 Mat sub;
534
535 PetscFunctionBegin;
536 PetscCall(MatGetSize(A, &m, NULL));
537 // we could handle the general case but this is the only case used by MatLMVM
538 PetscCheck((next < m && oldest == 0) || next - oldest == m, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "General case not implemented");
539 if (next - oldest == m) PetscFunctionReturn(PETSC_SUCCESS); // nothing to do if all entries are used
540 PetscCall(MatDenseGetSubMatrix(A, next, m, next, m, &sub));
541 PetscCall(MatShift(sub, 1.0));
542 PetscCall(MatDenseRestoreSubMatrix(A, &sub));
543 PetscFunctionReturn(PETSC_SUCCESS);
544 }
545