xref: /petsc/src/mat/impls/lrc/lrc.c (revision df4cd43f92eaa320656440c40edb1046daee8f75)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *);
5 
6 typedef struct {
7   Mat A;            /* sparse matrix */
8   Mat U, V;         /* dense tall-skinny matrices */
9   Vec c;            /* sequential vector containing the diagonal of C */
10   Vec work1, work2; /* sequential vectors that hold partial products */
11   Vec xl, yl;       /* auxiliary sequential vectors for matmult operation */
12 } Mat_LRC;
13 
14 static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose)
15 {
16   Mat_LRC    *Na = (Mat_LRC *)N->data;
17   PetscMPIInt size;
18   Mat         U, V;
19 
20   PetscFunctionBegin;
21   U = transpose ? Na->V : Na->U;
22   V = transpose ? Na->U : Na->V;
23   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size));
24   if (size == 1) {
25     PetscCall(MatMultHermitianTranspose(V, x, Na->work1));
26     if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1));
27     if (Na->A) {
28       if (transpose) {
29         PetscCall(MatMultTranspose(Na->A, x, y));
30       } else {
31         PetscCall(MatMult(Na->A, x, y));
32       }
33       PetscCall(MatMultAdd(U, Na->work1, y, y));
34     } else {
35       PetscCall(MatMult(U, Na->work1, y));
36     }
37   } else {
38     Mat                Uloc, Vloc;
39     Vec                yl, xl;
40     const PetscScalar *w1;
41     PetscScalar       *w2;
42     PetscInt           nwork;
43     PetscMPIInt        mpinwork;
44 
45     xl = transpose ? Na->yl : Na->xl;
46     yl = transpose ? Na->xl : Na->yl;
47     PetscCall(VecGetLocalVector(y, yl));
48     PetscCall(MatDenseGetLocalMatrix(U, &Uloc));
49     PetscCall(MatDenseGetLocalMatrix(V, &Vloc));
50 
51     /* multiply the local part of V with the local part of x */
52     PetscCall(VecGetLocalVectorRead(x, xl));
53     PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1));
54     PetscCall(VecRestoreLocalVectorRead(x, xl));
55 
56     /* form the sum of all the local multiplies: this is work2 = V'*x =
57        sum_{all processors} work1 */
58     PetscCall(VecGetArrayRead(Na->work1, &w1));
59     PetscCall(VecGetArrayWrite(Na->work2, &w2));
60     PetscCall(VecGetLocalSize(Na->work1, &nwork));
61     PetscCall(PetscMPIIntCast(nwork, &mpinwork));
62     PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N)));
63     PetscCall(VecRestoreArrayRead(Na->work1, &w1));
64     PetscCall(VecRestoreArrayWrite(Na->work2, &w2));
65 
66     if (Na->c) { /* work2 = C*work2 */
67       PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2));
68     }
69 
70     if (Na->A) {
71       /* form y = A*x or A^t*x */
72       if (transpose) {
73         PetscCall(MatMultTranspose(Na->A, x, y));
74       } else {
75         PetscCall(MatMult(Na->A, x, y));
76       }
77       /* multiply-add y = y + U*work2 */
78       PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl));
79     } else {
80       /* multiply y = U*work2 */
81       PetscCall(MatMult(Uloc, Na->work2, yl));
82     }
83 
84     PetscCall(VecRestoreLocalVector(y, yl));
85   }
86   PetscFunctionReturn(PETSC_SUCCESS);
87 }
88 
89 static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y)
90 {
91   PetscFunctionBegin;
92   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE));
93   PetscFunctionReturn(PETSC_SUCCESS);
94 }
95 
96 static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y)
97 {
98   PetscFunctionBegin;
99   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE));
100   PetscFunctionReturn(PETSC_SUCCESS);
101 }
102 
103 static PetscErrorCode MatDestroy_LRC(Mat N)
104 {
105   Mat_LRC *Na = (Mat_LRC *)N->data;
106 
107   PetscFunctionBegin;
108   PetscCall(MatDestroy(&Na->A));
109   PetscCall(MatDestroy(&Na->U));
110   PetscCall(MatDestroy(&Na->V));
111   PetscCall(VecDestroy(&Na->c));
112   PetscCall(VecDestroy(&Na->work1));
113   PetscCall(VecDestroy(&Na->work2));
114   PetscCall(VecDestroy(&Na->xl));
115   PetscCall(VecDestroy(&Na->yl));
116   PetscCall(PetscFree(N->data));
117   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL));
118   PetscFunctionReturn(PETSC_SUCCESS);
119 }
120 
121 static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
122 {
123   Mat_LRC *Na = (Mat_LRC *)N->data;
124 
125   PetscFunctionBegin;
126   if (A) *A = Na->A;
127   if (U) *U = Na->U;
128   if (c) *c = Na->c;
129   if (V) *V = Na->V;
130   PetscFunctionReturn(PETSC_SUCCESS);
131 }
132 
133 /*@
134    MatLRCGetMats - Returns the constituents of an LRC matrix
135 
136    Collective
137 
138    Input Parameter:
139 .  N - matrix of type `MATLRC`
140 
141    Output Parameters:
142 +  A - the (sparse) matrix
143 .  U - first dense rectangular (tall and skinny) matrix
144 .  c - a sequential vector containing the diagonal of C
145 -  V - second dense rectangular (tall and skinny) matrix
146 
147    Level: intermediate
148 
149    Notes:
150    The returned matrices need not be destroyed by the caller.
151 
152    `U`, `c`, `V` may be `NULL` if not needed
153 
154 .seealso: [](chapter_matrices), `Mat`, `MATLRC`, `MatCreateLRC()`
155 @*/
156 PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V)
157 {
158   PetscFunctionBegin;
159   PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V));
160   PetscFunctionReturn(PETSC_SUCCESS);
161 }
162 
/*MC
  MATLRC -  "lrc" - a matrix object that behaves like A + U*C*V'

  Note:
   The matrix A + U*C*V' is not formed! Rather the matrix object performs the matrix-vector product `MatMult()`, by first multiplying by
   A and then adding the other term.

  Level: advanced

.seealso: [](chapter_matrices), `Mat`, `MatCreateLRC()`, `MatMult()`, `MatLRCGetMats()`
M*/
174 
175 /*@
176    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V' of type `MATLRC`
177 
178    Collective
179 
180    Input Parameters:
181 +  A    - the (sparse) matrix (can be `NULL`)
182 .  U    - dense rectangular (tall and skinny) matrix
183 .  V    - dense rectangular (tall and skinny) matrix
184 -  c    - a vector containing the diagonal of C (can be `NULL`)
185 
186    Output Parameter:
187 .  N    - the matrix that represents A + U*C*V'
188 
189    Level: intermediate
190 
191    Notes:
192    The matrix A + U*C*V' is not formed! Rather the new matrix
193    object performs the matrix-vector product `MatMult()`, by first multiplying by
194    A and then adding the other term.
195 
196    `C` is a diagonal matrix (represented as a vector) of order k,
197    where k is the number of columns of both `U` and `V`.
198 
199    If `A` is `NULL` then the new object behaves like a low-rank matrix U*C*V'.
200 
201    Use `V`=`U` (or `V`=`NULL`) for a symmetric low-rank correction, A + U*C*U'.
202 
203    If `c` is `NULL` then the low-rank correction is just U*V'.
204    If a sequential `c` vector is used for a parallel matrix,
205    PETSc assumes that the values of the vector are consistently set across processors.
206 
207 .seealso: [](chapter_matrices), `Mat`, `MATLRC`, `MatLRCGetMats()`
208 @*/
209 PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N)
210 {
211   PetscBool   match;
212   PetscInt    m, n, k, m1, n1, k1;
213   Mat_LRC    *Na;
214   Mat         Uloc;
215   PetscMPIInt size, csize = 0;
216 
217   PetscFunctionBegin;
218   if (A) PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
219   PetscValidHeaderSpecific(U, MAT_CLASSID, 2);
220   if (c) PetscValidHeaderSpecific(c, VEC_CLASSID, 3);
221   if (V) {
222     PetscValidHeaderSpecific(V, MAT_CLASSID, 4);
223     PetscCheckSameComm(U, 2, V, 4);
224   }
225   if (A) PetscCheckSameComm(A, 1, U, 2);
226 
227   if (!V) V = U;
228   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, ""));
229   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name);
230   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, ""));
231   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name);
232   PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match));
233   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype);
234   if (A) {
235     PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match));
236     PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype);
237   }
238 
239   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
240   PetscCall(MatGetSize(U, NULL, &k));
241   PetscCall(MatGetSize(V, NULL, &k1));
242   PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1);
243   PetscCall(MatGetLocalSize(U, &m, NULL));
244   PetscCall(MatGetLocalSize(V, &n, NULL));
245   if (A) {
246     PetscCall(MatGetLocalSize(A, &m1, &n1));
247     PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1);
248     PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1);
249   }
250   if (c) {
251     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
252     PetscCall(VecGetSize(c, &k1));
253     PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k);
254     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize);
255   }
256 
257   PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N));
258   PetscCall(MatSetSizes(*N, m, n, PETSC_DECIDE, PETSC_DECIDE));
259   PetscCall(MatSetVecType(*N, U->defaultvectype));
260   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATLRC));
261   /* Flag matrix as symmetric if A is symmetric and U == V */
262   PetscCall(MatSetOption(*N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));
263 
264   PetscCall(PetscNew(&Na));
265   (*N)->data = (void *)Na;
266   Na->A      = A;
267   Na->U      = U;
268   Na->c      = c;
269   Na->V      = V;
270 
271   PetscCall(PetscObjectReference((PetscObject)A));
272   PetscCall(PetscObjectReference((PetscObject)Na->U));
273   PetscCall(PetscObjectReference((PetscObject)Na->V));
274   PetscCall(PetscObjectReference((PetscObject)c));
275 
276   PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc));
277   PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL));
278   if (size != 1) {
279     Mat Vloc;
280 
281     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
282       VecScatter sct;
283 
284       PetscCall(VecScatterCreateToAll(Na->c, &sct, &c));
285       PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
286       PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
287       PetscCall(VecScatterDestroy(&sct));
288       PetscCall(VecDestroy(&Na->c));
289       Na->c = c;
290     }
291     PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc));
292     PetscCall(VecDuplicate(Na->work1, &Na->work2));
293     PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl));
294     PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl));
295   }
296 
297   /* Internally create a scaling vector if roottypes do not match */
298   if (Na->c) {
299     VecType rt1, rt2;
300 
301     PetscCall(VecGetRootType_Private(Na->work1, &rt1));
302     PetscCall(VecGetRootType_Private(Na->c, &rt2));
303     PetscCall(PetscStrcmp(rt1, rt2, &match));
304     if (!match) {
305       PetscCall(VecDuplicate(Na->c, &c));
306       PetscCall(VecCopy(Na->c, c));
307       PetscCall(VecDestroy(&Na->c));
308       Na->c = c;
309     }
310   }
311 
312   (*N)->ops->destroy       = MatDestroy_LRC;
313   (*N)->ops->mult          = MatMult_LRC;
314   (*N)->ops->multtranspose = MatMultTranspose_LRC;
315 
316   (*N)->assembled    = PETSC_TRUE;
317   (*N)->preallocated = PETSC_TRUE;
318 
319   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatLRCGetMats_C", MatLRCGetMats_LRC));
320   PetscCall(MatSetUp(*N));
321   PetscFunctionReturn(PETSC_SUCCESS);
322 }
323