xref: /petsc/src/mat/impls/lrc/lrc.c (revision d2522c19e8fa9bca20aaca277941d9a63e71db6a)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec, VecType *);
5 
6 typedef struct {
7   Mat A;            /* sparse matrix */
8   Mat U, V;         /* dense tall-skinny matrices */
9   Vec c;            /* sequential vector containing the diagonal of C */
10   Vec work1, work2; /* sequential vectors that hold partial products */
11   Vec xl, yl;       /* auxiliary sequential vectors for matmult operation */
12 } Mat_LRC;
13 
14 static PetscErrorCode MatMult_LRC_kernel(Mat N, Vec x, Vec y, PetscBool transpose) {
15   Mat_LRC    *Na = (Mat_LRC *)N->data;
16   PetscMPIInt size;
17   Mat         U, V;
18 
19   PetscFunctionBegin;
20   U = transpose ? Na->V : Na->U;
21   V = transpose ? Na->U : Na->V;
22   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N), &size));
23   if (size == 1) {
24     PetscCall(MatMultHermitianTranspose(V, x, Na->work1));
25     if (Na->c) PetscCall(VecPointwiseMult(Na->work1, Na->c, Na->work1));
26     if (Na->A) {
27       if (transpose) {
28         PetscCall(MatMultTranspose(Na->A, x, y));
29       } else {
30         PetscCall(MatMult(Na->A, x, y));
31       }
32       PetscCall(MatMultAdd(U, Na->work1, y, y));
33     } else {
34       PetscCall(MatMult(U, Na->work1, y));
35     }
36   } else {
37     Mat                Uloc, Vloc;
38     Vec                yl, xl;
39     const PetscScalar *w1;
40     PetscScalar       *w2;
41     PetscInt           nwork;
42     PetscMPIInt        mpinwork;
43 
44     xl = transpose ? Na->yl : Na->xl;
45     yl = transpose ? Na->xl : Na->yl;
46     PetscCall(VecGetLocalVector(y, yl));
47     PetscCall(MatDenseGetLocalMatrix(U, &Uloc));
48     PetscCall(MatDenseGetLocalMatrix(V, &Vloc));
49 
50     /* multiply the local part of V with the local part of x */
51     PetscCall(VecGetLocalVectorRead(x, xl));
52     PetscCall(MatMultHermitianTranspose(Vloc, xl, Na->work1));
53     PetscCall(VecRestoreLocalVectorRead(x, xl));
54 
55     /* form the sum of all the local multiplies: this is work2 = V'*x =
56        sum_{all processors} work1 */
57     PetscCall(VecGetArrayRead(Na->work1, &w1));
58     PetscCall(VecGetArrayWrite(Na->work2, &w2));
59     PetscCall(VecGetLocalSize(Na->work1, &nwork));
60     PetscCall(PetscMPIIntCast(nwork, &mpinwork));
61     PetscCall(MPIU_Allreduce(w1, w2, mpinwork, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)N)));
62     PetscCall(VecRestoreArrayRead(Na->work1, &w1));
63     PetscCall(VecRestoreArrayWrite(Na->work2, &w2));
64 
65     if (Na->c) { /* work2 = C*work2 */
66       PetscCall(VecPointwiseMult(Na->work2, Na->c, Na->work2));
67     }
68 
69     if (Na->A) {
70       /* form y = A*x or A^t*x */
71       if (transpose) {
72         PetscCall(MatMultTranspose(Na->A, x, y));
73       } else {
74         PetscCall(MatMult(Na->A, x, y));
75       }
76       /* multiply-add y = y + U*work2 */
77       PetscCall(MatMultAdd(Uloc, Na->work2, yl, yl));
78     } else {
79       /* multiply y = U*work2 */
80       PetscCall(MatMult(Uloc, Na->work2, yl));
81     }
82 
83     PetscCall(VecRestoreLocalVector(y, yl));
84   }
85   PetscFunctionReturn(0);
86 }
87 
88 static PetscErrorCode MatMult_LRC(Mat N, Vec x, Vec y) {
89   PetscFunctionBegin;
90   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_FALSE));
91   PetscFunctionReturn(0);
92 }
93 
94 static PetscErrorCode MatMultTranspose_LRC(Mat N, Vec x, Vec y) {
95   PetscFunctionBegin;
96   PetscCall(MatMult_LRC_kernel(N, x, y, PETSC_TRUE));
97   PetscFunctionReturn(0);
98 }
99 
100 static PetscErrorCode MatDestroy_LRC(Mat N) {
101   Mat_LRC *Na = (Mat_LRC *)N->data;
102 
103   PetscFunctionBegin;
104   PetscCall(MatDestroy(&Na->A));
105   PetscCall(MatDestroy(&Na->U));
106   PetscCall(MatDestroy(&Na->V));
107   PetscCall(VecDestroy(&Na->c));
108   PetscCall(VecDestroy(&Na->work1));
109   PetscCall(VecDestroy(&Na->work2));
110   PetscCall(VecDestroy(&Na->xl));
111   PetscCall(VecDestroy(&Na->yl));
112   PetscCall(PetscFree(N->data));
113   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatLRCGetMats_C", NULL));
114   PetscFunctionReturn(0);
115 }
116 
117 static PetscErrorCode MatLRCGetMats_LRC(Mat N, Mat *A, Mat *U, Vec *c, Mat *V) {
118   Mat_LRC *Na = (Mat_LRC *)N->data;
119 
120   PetscFunctionBegin;
121   if (A) *A = Na->A;
122   if (U) *U = Na->U;
123   if (c) *c = Na->c;
124   if (V) *V = Na->V;
125   PetscFunctionReturn(0);
126 }
127 
128 /*@
129    MatLRCGetMats - Returns the constituents of an LRC matrix
130 
131    Collective on Mat
132 
133    Input Parameter:
134 .  N - matrix of type LRC
135 
136    Output Parameters:
137 +  A - the (sparse) matrix
138 .  U - first dense rectangular (tall and skinny) matrix
139 .  c - a sequential vector containing the diagonal of C
140 -  V - second dense rectangular (tall and skinny) matrix
141 
142    Note:
143    The returned matrices need not be destroyed by the caller.
144 
145    Level: intermediate
146 
147 .seealso: `MatCreateLRC()`
148 @*/
149 PetscErrorCode MatLRCGetMats(Mat N, Mat *A, Mat *U, Vec *c, Mat *V) {
150   PetscFunctionBegin;
151   PetscUseMethod(N, "MatLRCGetMats_C", (Mat, Mat *, Mat *, Vec *, Mat *), (N, A, U, c, V));
152   PetscFunctionReturn(0);
153 }
154 
155 /*@
156    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V'
157 
158    Collective on Mat
159 
160    Input Parameters:
161 +  A    - the (sparse) matrix (can be NULL)
162 .  U, V - two dense rectangular (tall and skinny) matrices
163 -  c    - a vector containing the diagonal of C (can be NULL)
164 
165    Output Parameter:
166 .  N    - the matrix that represents A + U*C*V'
167 
168    Notes:
169    The matrix A + U*C*V' is not formed! Rather the new matrix
170    object performs the matrix-vector product by first multiplying by
171    A and then adding the other term.
172 
173    C is a diagonal matrix (represented as a vector) of order k,
174    where k is the number of columns of both U and V.
175 
176    If A is NULL then the new object behaves like a low-rank matrix U*C*V'.
177 
178    Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'.
179 
180    If c is NULL then the low-rank correction is just U*V'.
181    If a sequential c vector is used for a parallel matrix,
182    PETSc assumes that the values of the vector are consistently set across processors.
183 
184    Level: intermediate
185 
186 .seealso: `MatLRCGetMats()`
187 @*/
188 PetscErrorCode MatCreateLRC(Mat A, Mat U, Vec c, Mat V, Mat *N) {
189   PetscBool   match;
190   PetscInt    m, n, k, m1, n1, k1;
191   Mat_LRC    *Na;
192   Mat         Uloc;
193   PetscMPIInt size, csize = 0;
194 
195   PetscFunctionBegin;
196   if (A) PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
197   PetscValidHeaderSpecific(U, MAT_CLASSID, 2);
198   if (c) PetscValidHeaderSpecific(c, VEC_CLASSID, 3);
199   if (V) {
200     PetscValidHeaderSpecific(V, MAT_CLASSID, 4);
201     PetscCheckSameComm(U, 2, V, 4);
202   }
203   if (A) PetscCheckSameComm(A, 1, U, 2);
204 
205   if (!V) V = U;
206   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U, &match, MATSEQDENSE, MATMPIDENSE, ""));
207   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix U must be of type dense, found %s", ((PetscObject)U)->type_name);
208   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V, &match, MATSEQDENSE, MATMPIDENSE, ""));
209   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_SUP, "Matrix V must be of type dense, found %s", ((PetscObject)V)->type_name);
210   PetscCall(PetscStrcmp(U->defaultvectype, V->defaultvectype, &match));
211   PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix U and V must have the same VecType %s != %s", U->defaultvectype, V->defaultvectype);
212   if (A) {
213     PetscCall(PetscStrcmp(A->defaultvectype, U->defaultvectype, &match));
214     PetscCheck(match, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_WRONG, "Matrix A and U must have the same VecType %s != %s", A->defaultvectype, U->defaultvectype);
215   }
216 
217   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U), &size));
218   PetscCall(MatGetSize(U, NULL, &k));
219   PetscCall(MatGetSize(V, NULL, &k1));
220   PetscCheck(k == k1, PetscObjectComm((PetscObject)U), PETSC_ERR_ARG_INCOMP, "U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")", k, k1);
221   PetscCall(MatGetLocalSize(U, &m, NULL));
222   PetscCall(MatGetLocalSize(V, &n, NULL));
223   if (A) {
224     PetscCall(MatGetLocalSize(A, &m1, &n1));
225     PetscCheck(m == m1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", m, m1);
226     PetscCheck(n == n1, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match", n, n1);
227   }
228   if (c) {
229     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c), &csize));
230     PetscCall(VecGetSize(c, &k1));
231     PetscCheck(k == k1, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")", k1, k);
232     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c), PETSC_ERR_ARG_INCOMP, "U and c must have the same communicator size %d != %d", size, csize);
233   }
234 
235   PetscCall(MatCreate(PetscObjectComm((PetscObject)U), N));
236   PetscCall(MatSetSizes(*N, m, n, PETSC_DECIDE, PETSC_DECIDE));
237   PetscCall(MatSetVecType(*N, U->defaultvectype));
238   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATLRC));
239   /* Flag matrix as symmetric if A is symmetric and U == V */
240   PetscCall(MatSetOption(*N, MAT_SYMMETRIC, (PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));
241 
242   PetscCall(PetscNewLog(*N, &Na));
243   (*N)->data = (void *)Na;
244   Na->A      = A;
245   Na->U      = U;
246   Na->c      = c;
247   Na->V      = V;
248 
249   PetscCall(PetscObjectReference((PetscObject)A));
250   PetscCall(PetscObjectReference((PetscObject)Na->U));
251   PetscCall(PetscObjectReference((PetscObject)Na->V));
252   PetscCall(PetscObjectReference((PetscObject)c));
253 
254   PetscCall(MatDenseGetLocalMatrix(Na->U, &Uloc));
255   PetscCall(MatCreateVecs(Uloc, &Na->work1, NULL));
256   if (size != 1) {
257     Mat Vloc;
258 
259     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
260       VecScatter sct;
261 
262       PetscCall(VecScatterCreateToAll(Na->c, &sct, &c));
263       PetscCall(VecScatterBegin(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
264       PetscCall(VecScatterEnd(sct, Na->c, c, INSERT_VALUES, SCATTER_FORWARD));
265       PetscCall(VecScatterDestroy(&sct));
266       PetscCall(VecDestroy(&Na->c));
267       PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)c));
268       Na->c = c;
269     }
270     PetscCall(MatDenseGetLocalMatrix(Na->V, &Vloc));
271     PetscCall(VecDuplicate(Na->work1, &Na->work2));
272     PetscCall(MatCreateVecs(Vloc, NULL, &Na->xl));
273     PetscCall(MatCreateVecs(Uloc, NULL, &Na->yl));
274   }
275   PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->work1));
276   PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->work1));
277   PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->xl));
278   PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)Na->yl));
279 
280   /* Internally create a scaling vector if roottypes do not match */
281   if (Na->c) {
282     VecType rt1, rt2;
283 
284     PetscCall(VecGetRootType_Private(Na->work1, &rt1));
285     PetscCall(VecGetRootType_Private(Na->c, &rt2));
286     PetscCall(PetscStrcmp(rt1, rt2, &match));
287     if (!match) {
288       PetscCall(VecDuplicate(Na->c, &c));
289       PetscCall(VecCopy(Na->c, c));
290       PetscCall(VecDestroy(&Na->c));
291       PetscCall(PetscLogObjectParent((PetscObject)*N, (PetscObject)c));
292       Na->c = c;
293     }
294   }
295 
296   (*N)->ops->destroy       = MatDestroy_LRC;
297   (*N)->ops->mult          = MatMult_LRC;
298   (*N)->ops->multtranspose = MatMultTranspose_LRC;
299 
300   (*N)->assembled    = PETSC_TRUE;
301   (*N)->preallocated = PETSC_TRUE;
302 
303   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatLRCGetMats_C", MatLRCGetMats_LRC));
304   PetscCall(MatSetUp(*N));
305   PetscFunctionReturn(0);
306 }
307