xref: /petsc/src/mat/impls/lrc/lrc.c (revision 9d13fa56c5c6523e02c36edc0e4e22bf2d0334a8)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *);
5 
6 typedef struct {
7   Mat A;            /* sparse matrix */
8   Mat U, V;         /* dense tall-skinny matrices */
9   Vec c;            /* sequential vector containing the diagonal of C */
10   Vec work1, work2; /* sequential vectors that hold partial products */
11   Vec xl, yl;       /* auxiliary sequential vectors for matmult operation */
12 } Mat_LRC;
13 
14 static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose)
15 {
16   Mat_LRC    *Na = (Mat_LRC *)N->data;
17   PetscMPIInt size;
18   Mat         U, V;
19 
20   PetscFunctionBegin;
21   U = transpose ? Na->V : Na->U;
22   V = transpose ? Na->U : Na->V;
23   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size));
24   if (size == 1) {
25     PetscCall(MatMultHermitianTranspose(V, x, Na->work1));
26     if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1));
27     if (Na->A) {
28       if (transpose) {
29         PetscCall(MatMultTranspose(Na->A, x, y));
30       } else {
31         PetscCall(MatMult(Na->A, x, y));
32       }
33       PetscCall(MatMultAdd(U, Na->work1, y, y));
34     } else {
35       PetscCall(MatMult(U, Na->work1, y));
36     }
37   } else {
38     Mat                Uloc, Vloc;
39     Vec                yl, xl;
40     const PetscScalar *w1;
41     PetscScalar       *w2;
42     PetscInt           nwork;
43     PetscMPIInt        mpinwork;
44 
45     xl = transpose ? Na->yl : Na->xl;
46     yl = transpose ? Na->xl : Na->yl;
47     PetscCall(VecGetLocalVector(y, yl));
48     PetscCall(MatDenseGetLocalMatrix(U, &Uloc));
49     PetscCall(MatDenseGetLocalMatrix(V, &Vloc));
50 
51     /* multiply the local part of V with the local part of x */
52     PetscCall(VecGetLocalVectorRead(x, xl));
53     PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1));
54     PetscCall(VecRestoreLocalVectorRead(x, xl));
55 
56     /* form the sum of all the local multiplies: this is work2 = V'*x =
57        sum_{all processors} work1 */
58     PetscCall(VecGetArrayRead(Na->work1, &w1));
59     PetscCall(VecGetArrayWrite(Na->work2, &w2));
60     PetscCall(VecGetLocalSize(Na->work1, &nwork));
61     PetscCall(PetscMPIIntCast(nwork, &mpinwork));
62     PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N)));
63     PetscCall(VecRestoreArrayRead(Na->work1, &w1));
64     PetscCall(VecRestoreArrayWrite(Na->work2, &w2));
65 
66     if (Na->c) { /* work2 = C*work2 */
67       PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2));
68     }
69 
70     if (Na->A) {
71       /* form y = A*x or A^t*x */
72       if (transpose) {
73         PetscCall(MatMultTranspose(Na->A, x, y));
74       } else {
75         PetscCall(MatMult(Na->A, x, y));
76       }
77       /* multiply-add y = y + U*work2 */
78       PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl));
79     } else {
80       /* multiply y = U*work2 */
81       PetscCall(MatMult(Uloc, Na->work2, yl));
82     }
83 
84     PetscCall(VecRestoreLocalVector(y, yl));
85   }
86   PetscFunctionReturn(PETSC_SUCCESS);
87 }
88 
89 static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y)
90 {
91   PetscFunctionBegin;
92   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE));
93   PetscFunctionReturn(PETSC_SUCCESS);
94 }
95 
96 static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y)
97 {
98   PetscFunctionBegin;
99   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE));
100   PetscFunctionReturn(PETSC_SUCCESS);
101 }
102 
103 static PetscErrorCode MatDestroy_LRC(Mat N)
104 {
105   Mat_LRC *Na = (Mat_LRC *)N->data;
106 
107   PetscFunctionBegin;
108   PetscCall(MatDestroy(&Na->A));
109   PetscCall(MatDestroy(&Na->U));
110   PetscCall(MatDestroy(&Na->V));
111   PetscCall(VecDestroy(&Na->c));
112   PetscCall(VecDestroy(&Na->work1));
113   PetscCall(VecDestroy(&Na->work2));
114   PetscCall(VecDestroy(&Na->xl));
115   PetscCall(VecDestroy(&Na->yl));
116   PetscCall(PetscFree(N->data));
117   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL));
118   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCSetMats_C", NULL));
119   PetscFunctionReturn(PETSC_SUCCESS);
120 }
121 
122 static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
123 {
124   Mat_LRC *Na = (Mat_LRC *)N->data;
125 
126   PetscFunctionBegin;
127   if (A) *A = Na->A;
128   if (U) *U = Na->U;
129   if (c) *c = Na->c;
130   if (V) *V = Na->V;
131   PetscFunctionReturn(PETSC_SUCCESS);
132 }
133 
134 static PetscErrorCode MatLRCSetMats_LRC(Mat N, Mat A, Mat U, Vec c, Mat V)
135 {
136   Mat_LRC *Na = (Mat_LRC *)N->data;
137 
138   PetscFunctionBegin;
139   PetscCall(PetscObjectReference((PetscObject)A));
140   PetscCall(PetscObjectReference((PetscObject)U));
141   PetscCall(PetscObjectReference((PetscObject)V));
142   PetscCall(PetscObjectReference((PetscObject)c));
143   PetscCall(MatDestroy(&Na->A));
144   PetscCall(MatDestroy(&Na->U));
145   PetscCall(MatDestroy(&Na->V));
146   PetscCall(VecDestroy(&Na->c));
147   Na->A = A;
148   Na->U = U;
149   Na->c = c;
150   Na->V = V;
151   PetscFunctionReturn(PETSC_SUCCESS);
152 }
153 
154 /*@
155    MatLRCGetMats - Returns the constituents of an LRC matrix
156 
157    Not collective
158 
159    Input Parameter:
160 .  N - matrix of type `MATLRC`
161 
162    Output Parameters:
163 +  A - the (sparse) matrix
164 .  U - first dense rectangular (tall and skinny) matrix
165 .  c - a sequential vector containing the diagonal of C
166 -  V - second dense rectangular (tall and skinny) matrix
167 
168    Level: intermediate
169 
170    Notes:
171    The returned matrices should not be destroyed by the caller.
172 
173    `U`, `c`, `V` may be `NULL` if not needed
174 
175 .seealso: [](ch_matrices), `MatLRCSetMats()`, `Mat`, `MATLRC`, `MatCreateLRC()`
176 @*/
177 PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
178 {
179   PetscFunctionBegin;
180   PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V));
181   PetscFunctionReturn(PETSC_SUCCESS);
182 }
183 
184 /*@
185    MatLRCSetMats - Sets the constituents of an LRC matrix
186 
187    Logically collective
188 
189    Input Parameters:
190 +  N - matrix of type `MATLRC`
191 .  A - the (sparse) matrix
192 .  U - first dense rectangular (tall and skinny) matrix
193 .  c - a sequential vector containing the diagonal of C
194 -  V - second dense rectangular (tall and skinny) matrix
195 
196    Level: intermediate
197 
198    Note:
199    If `V` is `NULL`, then it is assumed to be identical to `U`.
200 
201 .seealso: [](ch_matrices), `MatLRCGetMats()`, `Mat`, `MATLRC`, `MatCreateLRC()`
202 @*/
203 PetscErrorCode MatLRCSetMats(Mat N, Mat A, Mat U, Vec c, Mat V)
204 {
205   PetscInt  k, k1, m, n, m1, n1;
206   PetscBool match;
207 
208   PetscFunctionBegin;
209   if (A) PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
210   PetscValidHeaderSpecific(U, MAT_CLASSID, 2);
211   if (c) PetscValidHeaderSpecific(c, VEC_CLASSID, 3);
212   if (V) {
213     PetscValidHeaderSpecific(V, MAT_CLASSID, 4);
214     PetscCheckSameComm(U, 2, V, 4);
215   }
216   if (A) PetscCheckSameComm(A, 1, U, 2);
217   if (!V) V = U;
218   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, ""));
219   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name);
220   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, ""));
221   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name);
222   PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match));
223   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype);
224   if (A) {
225     PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match));
226     PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype);
227   }
228   PetscCall(MatGetSize(U, NULL, &k));
229   PetscCall(MatGetSize(V, NULL, &k1));
230   PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1);
231   PetscCall(MatGetLocalSize(U, &m, NULL));
232   PetscCall(MatGetLocalSize(V, &n, NULL));
233   if (A) {
234     PetscCall(MatGetLocalSize(A, &m1, &n1));
235     PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1);
236     PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1);
237   }
238   if (c) {
239     PetscMPIInt size, csize;
240 
241     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
242     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
243     PetscCall(VecGetSize(c, &k1));
244     PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k);
245     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize);
246   }
247   PetscCall(MatSetSizes(N, m, n, PETSC_DECIDE, PETSC_DECIDE));
248 
249   PetscUseMethod(N, "MatLRCSetMats_C", (Mat, Mat, Mat, Vec, Mat), (N, A, U, c, V));
250   PetscFunctionReturn(PETSC_SUCCESS);
251 }
252 
253 static PetscErrorCode MatSetUp_LRC(Mat N)
254 {
255   Mat_LRC    *Na = (Mat_LRC *)N->data;
256   Mat         A  = Na->A;
257   Mat         U  = Na->U;
258   Mat         V  = Na->V;
259   Vec         c  = Na->c;
260   Mat         Uloc;
261   PetscMPIInt size, csize = 0;
262 
263   PetscFunctionBegin;
264   PetscCall(MatSetVecType(N, U->defaultvectype));
265   // Flag matrix as symmetric if A is symmetric and U == V
266   PetscCall(MatSetOption(N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));
267   PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc));
268   PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL));
269 
270   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
271   if (c) PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
272   if (size != 1) {
273     Mat Vloc;
274 
275     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
276       VecScatter sct;
277 
278       PetscCall(VecScatterCreateToAll(Na->c, &sct, &c));
279       PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
280       PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
281       PetscCall(VecScatterDestroy(&sct));
282       PetscCall(VecDestroy(&Na->c));
283       Na->c = c;
284     }
285     PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc));
286     PetscCall(VecDuplicate(Na->work1, &Na->work2));
287     PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl));
288     PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl));
289   }
290   // Internally create a scaling vector if roottypes do not match
291   if (Na->c) {
292     VecType   rt1, rt2;
293     PetscBool match;
294 
295     PetscCall(VecGetRootType_Private(Na->work1, &rt1));
296     PetscCall(VecGetRootType_Private(Na->c, &rt2));
297     PetscCall(PetscStrcmp(rt1, rt2, &match));
298     if (!match) {
299       PetscCall(VecDuplicate(Na->c, &c));
300       PetscCall(VecCopy(Na->c, c));
301       PetscCall(VecDestroy(&Na->c));
302       Na->c = c;
303     }
304   }
305   N->assembled    = PETSC_TRUE;
306   N->preallocated = PETSC_TRUE;
307   PetscFunctionReturn(PETSC_SUCCESS);
308 }
309 
310 PETSC_EXTERN PetscErrorCode MatCreate_LRC(Mat N)
311 {
312   Mat_LRC *Na;
313 
314   PetscFunctionBegin;
315   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATLRC));
316   PetscCall(PetscNew(&Na));
317   N->data               = (void *)Na;
318   N->ops->destroy       = MatDestroy_LRC;
319   N->ops->setup         = MatSetUp_LRC;
320   N->ops->mult          = MatMult_LRC;
321   N->ops->multtranspose = MatMultTranspose_LRC;
322 
323   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", MatLRCGetMats_LRC));
324   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCSetMats_C", MatLRCSetMats_LRC));
325   PetscFunctionReturn(PETSC_SUCCESS);
326 }
327 
328 /*MC
329   MATLRC -  "lrc" - a matrix object that behaves like A + U*C*V'
330 
331   Note:
   The matrix A + U*C*V' is not formed explicitly. Instead, the matrix object computes the matrix-vector products
   `MatMult()` and `MatMultTranspose()` by first multiplying by A (or its transpose) and then adding the low-rank term.
334 
335   Level: advanced
336 
337 .seealso: [](ch_matrices), `Mat`, `MatCreateLRC()`, `MatMult()`, `MatLRCGetMats()`, `MatLRCSetMats()`
338 M*/
339 
340 /*@
341    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' of type `MATLRC`
342 
343    Collective
344 
345    Input Parameters:
346 +  A    - the (sparse) matrix (can be `NULL`)
347 .  U    - dense rectangular (tall and skinny) matrix
348 .  V    - dense rectangular (tall and skinny) matrix
349 -  c    - a vector containing the diagonal of C (can be `NULL`)
350 
351    Output Parameter:
352 .  N    - the matrix that represents A + U*C*V'
353 
354    Level: intermediate
355 
356    Notes:
357    The matrix A + U*C*V' is not formed! Rather the new matrix
358    object performs the matrix-vector product `MatMult()`, by first multiplying by
359    A and then adding the other term.
360 
361    `C` is a diagonal matrix (represented as a vector) of order k,
362    where k is the number of columns of both `U` and `V`.
363 
364    If `A` is `NULL` then the new object behaves like a low-rank matrix U*C*V'.
365 
366    Use `V`=`U` (or `V`=`NULL`) for a symmetric low-rank correction, A + U*C*U'.
367 
368    If `c` is `NULL` then the low-rank correction is just U*V'.
369    If a sequential `c` vector is used for a parallel matrix,
370    PETSc assumes that the values of the vector are consistently set across processors.
371 
372 .seealso: [](ch_matrices), `Mat`, `MATLRC`, `MatLRCGetMats()`
373 @*/
374 PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N)
375 {
376   PetscFunctionBegin;
377   PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N));
378   PetscCall(MatSetType(*N, MATLRC));
379   PetscCall(MatLRCSetMats(*N, A, U, c, V));
380   PetscCall(MatSetUp(*N));
381   PetscFunctionReturn(PETSC_SUCCESS);
382 }
383