xref: /petsc/src/mat/impls/lrc/lrc.c (revision 98d129c30f3ee9fdddc40fdbc5a989b7be64f888)
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 
3 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *);
4 
5 typedef struct {
6   Mat A;            /* sparse matrix */
7   Mat U, V;         /* dense tall-skinny matrices */
8   Vec c;            /* sequential vector containing the diagonal of C */
9   Vec work1, work2; /* sequential vectors that hold partial products */
10   Vec xl, yl;       /* auxiliary sequential vectors for matmult operation */
11 } Mat_LRC;
12 
13 static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose)
14 {
15   Mat_LRC    *Na = (Mat_LRC *)N->data;
16   PetscMPIInt size;
17   Mat         U, V;
18 
19   PetscFunctionBegin;
20   U = transpose ? Na->V : Na->U;
21   V = transpose ? Na->U : Na->V;
22   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size));
23   if (size == 1) {
24     PetscCall(MatMultHermitianTranspose(V, x, Na->work1));
25     if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1));
26     if (Na->A) {
27       if (transpose) {
28         PetscCall(MatMultTranspose(Na->A, x, y));
29       } else {
30         PetscCall(MatMult(Na->A, x, y));
31       }
32       PetscCall(MatMultAdd(U, Na->work1, y, y));
33     } else {
34       PetscCall(MatMult(U, Na->work1, y));
35     }
36   } else {
37     Mat                Uloc, Vloc;
38     Vec                yl, xl;
39     const PetscScalar *w1;
40     PetscScalar       *w2;
41     PetscInt           nwork;
42     PetscMPIInt        mpinwork;
43 
44     xl = transpose ? Na->yl : Na->xl;
45     yl = transpose ? Na->xl : Na->yl;
46     PetscCall(VecGetLocalVector(y, yl));
47     PetscCall(MatDenseGetLocalMatrix(U, &Uloc));
48     PetscCall(MatDenseGetLocalMatrix(V, &Vloc));
49 
50     /* multiply the local part of V with the local part of x */
51     PetscCall(VecGetLocalVectorRead(x, xl));
52     PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1));
53     PetscCall(VecRestoreLocalVectorRead(x, xl));
54 
55     /* form the sum of all the local multiplies: this is work2 = V'*x =
56        sum_{all processors} work1 */
57     PetscCall(VecGetArrayRead(Na->work1, &w1));
58     PetscCall(VecGetArrayWrite(Na->work2, &w2));
59     PetscCall(VecGetLocalSize(Na->work1, &nwork));
60     PetscCall(PetscMPIIntCast(nwork, &mpinwork));
61     PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N)));
62     PetscCall(VecRestoreArrayRead(Na->work1, &w1));
63     PetscCall(VecRestoreArrayWrite(Na->work2, &w2));
64 
65     if (Na->c) { /* work2 = C*work2 */
66       PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2));
67     }
68 
69     if (Na->A) {
70       /* form y = A*x or A^t*x */
71       if (transpose) {
72         PetscCall(MatMultTranspose(Na->A, x, y));
73       } else {
74         PetscCall(MatMult(Na->A, x, y));
75       }
76       /* multiply-add y = y + U*work2 */
77       PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl));
78     } else {
79       /* multiply y = U*work2 */
80       PetscCall(MatMult(Uloc, Na->work2, yl));
81     }
82 
83     PetscCall(VecRestoreLocalVector(y, yl));
84   }
85   PetscFunctionReturn(PETSC_SUCCESS);
86 }
87 
88 static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y)
89 {
90   PetscFunctionBegin;
91   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE));
92   PetscFunctionReturn(PETSC_SUCCESS);
93 }
94 
95 static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y)
96 {
97   PetscFunctionBegin;
98   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE));
99   PetscFunctionReturn(PETSC_SUCCESS);
100 }
101 
102 static PetscErrorCode MatDestroy_LRC(Mat N)
103 {
104   Mat_LRC *Na = (Mat_LRC *)N->data;
105 
106   PetscFunctionBegin;
107   PetscCall(MatDestroy(&Na->A));
108   PetscCall(MatDestroy(&Na->U));
109   PetscCall(MatDestroy(&Na->V));
110   PetscCall(VecDestroy(&Na->c));
111   PetscCall(VecDestroy(&Na->work1));
112   PetscCall(VecDestroy(&Na->work2));
113   PetscCall(VecDestroy(&Na->xl));
114   PetscCall(VecDestroy(&Na->yl));
115   PetscCall(PetscFree(N->data));
116   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL));
117   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCSetMats_C", NULL));
118   PetscFunctionReturn(PETSC_SUCCESS);
119 }
120 
121 static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
122 {
123   Mat_LRC *Na = (Mat_LRC *)N->data;
124 
125   PetscFunctionBegin;
126   if (A) *A = Na->A;
127   if (U) *U = Na->U;
128   if (c) *c = Na->c;
129   if (V) *V = Na->V;
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
133 static PetscErrorCode MatLRCSetMats_LRC(Mat N, Mat A, Mat U, Vec c, Mat V)
134 {
135   Mat_LRC *Na = (Mat_LRC *)N->data;
136 
137   PetscFunctionBegin;
138   PetscCall(PetscObjectReference((PetscObject)A));
139   PetscCall(PetscObjectReference((PetscObject)U));
140   PetscCall(PetscObjectReference((PetscObject)V));
141   PetscCall(PetscObjectReference((PetscObject)c));
142   PetscCall(MatDestroy(&Na->A));
143   PetscCall(MatDestroy(&Na->U));
144   PetscCall(MatDestroy(&Na->V));
145   PetscCall(VecDestroy(&Na->c));
146   Na->A = A;
147   Na->U = U;
148   Na->c = c;
149   Na->V = V;
150   PetscFunctionReturn(PETSC_SUCCESS);
151 }
152 
153 /*@
154   MatLRCGetMats - Returns the constituents of an LRC matrix
155 
156   Not collective
157 
158   Input Parameter:
159 . N - matrix of type `MATLRC`
160 
161   Output Parameters:
162 + A - the (sparse) matrix
163 . U - first dense rectangular (tall and skinny) matrix
164 . c - a sequential vector containing the diagonal of C
165 - V - second dense rectangular (tall and skinny) matrix
166 
167   Level: intermediate
168 
169   Notes:
170   The returned matrices should not be destroyed by the caller.
171 
172   `U`, `c`, `V` may be `NULL` if not needed
173 
174 .seealso: [](ch_matrices), `MatLRCSetMats()`, `Mat`, `MATLRC`, `MatCreateLRC()`
175 @*/
176 PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
177 {
178   PetscFunctionBegin;
179   PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V));
180   PetscFunctionReturn(PETSC_SUCCESS);
181 }
182 
183 /*@
184   MatLRCSetMats - Sets the constituents of an LRC matrix
185 
186   Logically collective
187 
188   Input Parameters:
189 + N - matrix of type `MATLRC`
190 . A - the (sparse) matrix
191 . U - first dense rectangular (tall and skinny) matrix
192 . c - a sequential vector containing the diagonal of C
193 - V - second dense rectangular (tall and skinny) matrix
194 
195   Level: intermediate
196 
197   Note:
198   If `V` is `NULL`, then it is assumed to be identical to `U`.
199 
200 .seealso: [](ch_matrices), `MatLRCGetMats()`, `Mat`, `MATLRC`, `MatCreateLRC()`
201 @*/
202 PetscErrorCode MatLRCSetMats(Mat N, Mat A, Mat U, Vec c, Mat V)
203 {
204   PetscInt  k, k1, m, n, m1, n1;
205   PetscBool match;
206 
207   PetscFunctionBegin;
208   if (A) PetscValidHeaderSpecific(A, MAT_CLASSID, 2);
209   PetscValidHeaderSpecific(U, MAT_CLASSID, 3);
210   if (c) PetscValidHeaderSpecific(c, VEC_CLASSID, 4);
211   if (V) {
212     PetscValidHeaderSpecific(V, MAT_CLASSID, 5);
213     PetscCheckSameComm(U, 3, V, 5);
214   }
215   if (A) PetscCheckSameComm(A, 2, U, 3);
216   if (!V) V = U;
217   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, ""));
218   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name);
219   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, ""));
220   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name);
221   PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match));
222   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype);
223   if (A) {
224     PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match));
225     PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype);
226   }
227   PetscCall(MatGetSize(U, NULL, &k));
228   PetscCall(MatGetSize(V, NULL, &k1));
229   PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1);
230   PetscCall(MatGetLocalSize(U, &m, NULL));
231   PetscCall(MatGetLocalSize(V, &n, NULL));
232   if (A) {
233     PetscCall(MatGetLocalSize(A, &m1, &n1));
234     PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1);
235     PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1);
236   }
237   if (c) {
238     PetscMPIInt size, csize;
239 
240     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
241     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
242     PetscCall(VecGetSize(c, &k1));
243     PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k);
244     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize);
245   }
246   PetscCall(MatSetSizes(N, m, n, PETSC_DECIDE, PETSC_DECIDE));
247 
248   PetscUseMethod(N, "MatLRCSetMats_C", (Mat, Mat, Mat, Vec, Mat), (N, A, U, c, V));
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
252 static PetscErrorCode MatSetUp_LRC(Mat N)
253 {
254   Mat_LRC    *Na = (Mat_LRC *)N->data;
255   Mat         A  = Na->A;
256   Mat         U  = Na->U;
257   Mat         V  = Na->V;
258   Vec         c  = Na->c;
259   Mat         Uloc;
260   PetscMPIInt size, csize = 0;
261 
262   PetscFunctionBegin;
263   PetscCall(MatSetVecType(N, U->defaultvectype));
264   // Flag matrix as symmetric if A is symmetric and U == V
265   PetscCall(MatSetOption(N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));
266   PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc));
267   PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL));
268 
269   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
270   if (c) PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
271   if (size != 1) {
272     Mat Vloc;
273 
274     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
275       VecScatter sct;
276 
277       PetscCall(VecScatterCreateToAll(Na->c, &sct, &c));
278       PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
279       PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
280       PetscCall(VecScatterDestroy(&sct));
281       PetscCall(VecDestroy(&Na->c));
282       Na->c = c;
283     }
284     PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc));
285     PetscCall(VecDuplicate(Na->work1, &Na->work2));
286     PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl));
287     PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl));
288   }
289   // Internally create a scaling vector if roottypes do not match
290   if (Na->c) {
291     VecType   rt1, rt2;
292     PetscBool match;
293 
294     PetscCall(VecGetRootType_Private(Na->work1, &rt1));
295     PetscCall(VecGetRootType_Private(Na->c, &rt2));
296     PetscCall(PetscStrcmp(rt1, rt2, &match));
297     if (!match) {
298       PetscCall(VecDuplicate(Na->c, &c));
299       PetscCall(VecCopy(Na->c, c));
300       PetscCall(VecDestroy(&Na->c));
301       Na->c = c;
302     }
303   }
304   N->assembled    = PETSC_TRUE;
305   N->preallocated = PETSC_TRUE;
306   PetscFunctionReturn(PETSC_SUCCESS);
307 }
308 
309 PETSC_EXTERN PetscErrorCode MatCreate_LRC(Mat N)
310 {
311   Mat_LRC *Na;
312 
313   PetscFunctionBegin;
314   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATLRC));
315   PetscCall(PetscNew(&Na));
316   N->data               = (void *)Na;
317   N->ops->destroy       = MatDestroy_LRC;
318   N->ops->setup         = MatSetUp_LRC;
319   N->ops->mult          = MatMult_LRC;
320   N->ops->multtranspose = MatMultTranspose_LRC;
321 
322   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", MatLRCGetMats_LRC));
323   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCSetMats_C", MatLRCSetMats_LRC));
324   PetscFunctionReturn(PETSC_SUCCESS);
325 }
326 
/*MC
  MATLRC - "lrc" - a matrix object that behaves like A + U*C*V'

  Note:
  The matrix A + U*C*V' is never formed explicitly. Instead, the matrix object computes the
  matrix-vector product `MatMult()` by multiplying by A and then adding the low-rank term.
  `MatMultTranspose()` is supported as well.

  Level: advanced

.seealso: [](ch_matrices), `Mat`, `MatCreateLRC()`, `MatMult()`, `MatLRCGetMats()`, `MatLRCSetMats()`
M*/
338 
339 /*@
340   MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' of type `MATLRC`
341 
342   Collective
343 
344   Input Parameters:
345 + A - the (sparse) matrix (can be `NULL`)
346 . U - dense rectangular (tall and skinny) matrix
347 . V - dense rectangular (tall and skinny) matrix
348 - c - a vector containing the diagonal of C (can be `NULL`)
349 
350   Output Parameter:
351 . N - the matrix that represents A + U*C*V'
352 
353   Level: intermediate
354 
355   Notes:
356   The matrix A + U*C*V' is not formed! Rather the new matrix
357   object performs the matrix-vector product `MatMult()`, by first multiplying by
358   A and then adding the other term.
359 
360   `C` is a diagonal matrix (represented as a vector) of order k,
361   where k is the number of columns of both `U` and `V`.
362 
363   If `A` is `NULL` then the new object behaves like a low-rank matrix U*C*V'.
364 
365   Use `V`=`U` (or `V`=`NULL`) for a symmetric low-rank correction, A + U*C*U'.
366 
367   If `c` is `NULL` then the low-rank correction is just U*V'.
368   If a sequential `c` vector is used for a parallel matrix,
369   PETSc assumes that the values of the vector are consistently set across processors.
370 
371 .seealso: [](ch_matrices), `Mat`, `MATLRC`, `MatLRCGetMats()`
372 @*/
373 PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N)
374 {
375   PetscFunctionBegin;
376   PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N));
377   PetscCall(MatSetType(*N, MATLRC));
378   PetscCall(MatLRCSetMats(*N, A, U, c, V));
379   PetscCall(MatSetUp(*N));
380   PetscFunctionReturn(PETSC_SUCCESS);
381 }
382