xref: /petsc/src/mat/impls/lrc/lrc.c (revision fdf6c4e30aafdbc795e4f76379caa977fd5cdf5a)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec,VecType*);
5 
6 typedef struct {
7   Mat A;           /* sparse matrix */
8   Mat U,V;         /* dense tall-skinny matrices */
9   Vec c;           /* sequential vector containing the diagonal of C */
10   Vec work1,work2; /* sequential vectors that hold partial products */
11   Vec xl,yl;       /* auxiliary sequential vectors for matmult operation */
12 } Mat_LRC;
13 
14 static PetscErrorCode MatMult_LRC_kernel(Mat N,Vec x,Vec y,PetscBool transpose)
15 {
16   Mat_LRC        *Na = (Mat_LRC*)N->data;
17   PetscMPIInt    size;
18   Mat            U,V;
19 
20   PetscFunctionBegin;
21   U = transpose ? Na->V : Na->U;
22   V = transpose ? Na->U : Na->V;
23   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N),&size));
24   if (size == 1) {
25     PetscCall(MatMultHermitianTranspose(V,x,Na->work1));
26     if (Na->c) PetscCall(VecPointwiseMult(Na->work1,Na->c,Na->work1));
27     if (Na->A) {
28       if (transpose) {
29         PetscCall(MatMultTranspose(Na->A,x,y));
30       } else {
31         PetscCall(MatMult(Na->A,x,y));
32       }
33       PetscCall(MatMultAdd(U,Na->work1,y,y));
34     } else {
35       PetscCall(MatMult(U,Na->work1,y));
36     }
37   } else {
38     Mat               Uloc,Vloc;
39     Vec               yl,xl;
40     const PetscScalar *w1;
41     PetscScalar       *w2;
42     PetscInt          nwork;
43     PetscMPIInt       mpinwork;
44 
45     xl = transpose ? Na->yl : Na->xl;
46     yl = transpose ? Na->xl : Na->yl;
47     PetscCall(VecGetLocalVector(y,yl));
48     PetscCall(MatDenseGetLocalMatrix(U,&Uloc));
49     PetscCall(MatDenseGetLocalMatrix(V,&Vloc));
50 
51     /* multiply the local part of V with the local part of x */
52     PetscCall(VecGetLocalVectorRead(x,xl));
53     PetscCall(MatMultHermitianTranspose(Vloc,xl,Na->work1));
54     PetscCall(VecRestoreLocalVectorRead(x,xl));
55 
56     /* form the sum of all the local multiplies: this is work2 = V'*x =
57        sum_{all processors} work1 */
58     PetscCall(VecGetArrayRead(Na->work1,&w1));
59     PetscCall(VecGetArrayWrite(Na->work2,&w2));
60     PetscCall(VecGetLocalSize(Na->work1,&nwork));
61     PetscCall(PetscMPIIntCast(nwork,&mpinwork));
62     PetscCall(MPIU_Allreduce(w1,w2,mpinwork,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)N)));
63     PetscCall(VecRestoreArrayRead(Na->work1,&w1));
64     PetscCall(VecRestoreArrayWrite(Na->work2,&w2));
65 
66     if (Na->c) {  /* work2 = C*work2 */
67       PetscCall(VecPointwiseMult(Na->work2,Na->c,Na->work2));
68     }
69 
70     if (Na->A) {
71       /* form y = A*x or A^t*x */
72       if (transpose) {
73         PetscCall(MatMultTranspose(Na->A,x,y));
74       } else {
75         PetscCall(MatMult(Na->A,x,y));
76       }
77       /* multiply-add y = y + U*work2 */
78       PetscCall(MatMultAdd(Uloc,Na->work2,yl,yl));
79     } else {
80       /* multiply y = U*work2 */
81       PetscCall(MatMult(Uloc,Na->work2,yl));
82     }
83 
84     PetscCall(VecRestoreLocalVector(y,yl));
85   }
86   PetscFunctionReturn(0);
87 }
88 
89 static PetscErrorCode MatMult_LRC(Mat N,Vec x,Vec y)
90 {
91   PetscFunctionBegin;
92   PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_FALSE));
93   PetscFunctionReturn(0);
94 }
95 
96 static PetscErrorCode MatMultTranspose_LRC(Mat N,Vec x,Vec y)
97 {
98   PetscFunctionBegin;
99   PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_TRUE));
100   PetscFunctionReturn(0);
101 }
102 
103 static PetscErrorCode MatDestroy_LRC(Mat N)
104 {
105   Mat_LRC        *Na = (Mat_LRC*)N->data;
106 
107   PetscFunctionBegin;
108   PetscCall(MatDestroy(&Na->A));
109   PetscCall(MatDestroy(&Na->U));
110   PetscCall(MatDestroy(&Na->V));
111   PetscCall(VecDestroy(&Na->c));
112   PetscCall(VecDestroy(&Na->work1));
113   PetscCall(VecDestroy(&Na->work2));
114   PetscCall(VecDestroy(&Na->xl));
115   PetscCall(VecDestroy(&Na->yl));
116   PetscCall(PetscFree(N->data));
117   PetscCall(PetscObjectComposeFunction((PetscObject)N,"MatLRCGetMats_C",NULL));
118   PetscFunctionReturn(0);
119 }
120 
121 static PetscErrorCode MatLRCGetMats_LRC(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
122 {
123   Mat_LRC *Na = (Mat_LRC*)N->data;
124 
125   PetscFunctionBegin;
126   if (A) *A = Na->A;
127   if (U) *U = Na->U;
128   if (c) *c = Na->c;
129   if (V) *V = Na->V;
130   PetscFunctionReturn(0);
131 }
132 
133 /*@
134    MatLRCGetMats - Returns the constituents of an LRC matrix
135 
136    Collective on Mat
137 
138    Input Parameter:
139 .  N - matrix of type LRC
140 
141    Output Parameters:
142 +  A - the (sparse) matrix
143 .  U - first dense rectangular (tall and skinny) matrix
144 .  c - a sequential vector containing the diagonal of C
145 -  V - second dense rectangular (tall and skinny) matrix
146 
147    Note:
148    The returned matrices need not be destroyed by the caller.
149 
150    Level: intermediate
151 
152 .seealso: `MatCreateLRC()`
153 @*/
154 PetscErrorCode MatLRCGetMats(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
155 {
156   PetscFunctionBegin;
157   PetscUseMethod(N,"MatLRCGetMats_C",(Mat,Mat*,Mat*,Vec*,Mat*),(N,A,U,c,V));
158   PetscFunctionReturn(0);
159 }
160 
161 /*@
162    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V'
163 
164    Collective on Mat
165 
166    Input Parameters:
167 +  A    - the (sparse) matrix (can be NULL)
168 .  U, V - two dense rectangular (tall and skinny) matrices
169 -  c    - a vector containing the diagonal of C (can be NULL)
170 
171    Output Parameter:
172 .  N    - the matrix that represents A + U*C*V'
173 
174    Notes:
175    The matrix A + U*C*V' is not formed! Rather the new matrix
176    object performs the matrix-vector product by first multiplying by
177    A and then adding the other term.
178 
179    C is a diagonal matrix (represented as a vector) of order k,
180    where k is the number of columns of both U and V.
181 
182    If A is NULL then the new object behaves like a low-rank matrix U*C*V'.
183 
184    Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'.
185 
186    If c is NULL then the low-rank correction is just U*V'.
187    If a sequential c vector is used for a parallel matrix,
188    PETSc assumes that the values of the vector are consistently set across processors.
189 
190    Level: intermediate
191 
192 .seealso: `MatLRCGetMats()`
193 @*/
194 PetscErrorCode MatCreateLRC(Mat A,Mat U,Vec c,Mat V,Mat *N)
195 {
196   PetscBool      match;
197   PetscInt       m,n,k,m1,n1,k1;
198   Mat_LRC        *Na;
199   Mat            Uloc;
200   PetscMPIInt    size, csize = 0;
201 
202   PetscFunctionBegin;
203   if (A) PetscValidHeaderSpecific(A,MAT_CLASSID,1);
204   PetscValidHeaderSpecific(U,MAT_CLASSID,2);
205   if (c) PetscValidHeaderSpecific(c,VEC_CLASSID,3);
206   if (V) {
207     PetscValidHeaderSpecific(V,MAT_CLASSID,4);
208     PetscCheckSameComm(U,2,V,4);
209   }
210   if (A) PetscCheckSameComm(A,1,U,2);
211 
212   if (!V) V = U;
213   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U,&match,MATSEQDENSE,MATMPIDENSE,""));
214   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix U must be of type dense, found %s",((PetscObject)U)->type_name);
215   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V,&match,MATSEQDENSE,MATMPIDENSE,""));
216   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix V must be of type dense, found %s",((PetscObject)V)->type_name);
217   PetscCall(PetscStrcmp(U->defaultvectype,V->defaultvectype,&match));
218   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix U and V must have the same VecType %s != %s",U->defaultvectype,V->defaultvectype);
219   if (A) {
220     PetscCall(PetscStrcmp(A->defaultvectype,U->defaultvectype,&match));
221     PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix A and U must have the same VecType %s != %s",A->defaultvectype,U->defaultvectype);
222   }
223 
224   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U),&size));
225   PetscCall(MatGetSize(U,NULL,&k));
226   PetscCall(MatGetSize(V,NULL,&k1));
227   PetscCheck(k == k1,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_INCOMP,"U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")",k,k1);
228   PetscCall(MatGetLocalSize(U,&m,NULL));
229   PetscCall(MatGetLocalSize(V,&n,NULL));
230   if (A) {
231     PetscCall(MatGetLocalSize(A,&m1,&n1));
232     PetscCheck(m == m1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",m,m1);
233     PetscCheck(n == n1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",n,n1);
234   }
235   if (c) {
236     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c),&csize));
237     PetscCall(VecGetSize(c,&k1));
238     PetscCheck(k == k1,PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")",k1,k);
239     PetscCheck(csize == 1 || csize == size, PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"U and c must have the same communicator size %d != %d",size,csize);
240   }
241 
242   PetscCall(MatCreate(PetscObjectComm((PetscObject)U),N));
243   PetscCall(MatSetSizes(*N,m,n,PETSC_DECIDE,PETSC_DECIDE));
244   PetscCall(MatSetVecType(*N,U->defaultvectype));
245   PetscCall(PetscObjectChangeTypeName((PetscObject)*N,MATLRC));
246   /* Flag matrix as symmetric if A is symmetric and U == V */
247   PetscCall(MatSetOption(*N,MAT_SYMMETRIC,(PetscBool)((A ? A->symmetric == PETSC_BOOL3_TRUE : PETSC_TRUE) && U == V)));
248 
249   PetscCall(PetscNewLog(*N,&Na));
250   (*N)->data = (void*)Na;
251   Na->A      = A;
252   Na->U      = U;
253   Na->c      = c;
254   Na->V      = V;
255 
256   PetscCall(PetscObjectReference((PetscObject)A));
257   PetscCall(PetscObjectReference((PetscObject)Na->U));
258   PetscCall(PetscObjectReference((PetscObject)Na->V));
259   PetscCall(PetscObjectReference((PetscObject)c));
260 
261   PetscCall(MatDenseGetLocalMatrix(Na->U,&Uloc));
262   PetscCall(MatCreateVecs(Uloc,&Na->work1,NULL));
263   if (size != 1) {
264     Mat Vloc;
265 
266     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
267       VecScatter sct;
268 
269       PetscCall(VecScatterCreateToAll(Na->c,&sct,&c));
270       PetscCall(VecScatterBegin(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD));
271       PetscCall(VecScatterEnd(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD));
272       PetscCall(VecScatterDestroy(&sct));
273       PetscCall(VecDestroy(&Na->c));
274       PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c));
275       Na->c = c;
276     }
277     PetscCall(MatDenseGetLocalMatrix(Na->V,&Vloc));
278     PetscCall(VecDuplicate(Na->work1,&Na->work2));
279     PetscCall(MatCreateVecs(Vloc,NULL,&Na->xl));
280     PetscCall(MatCreateVecs(Uloc,NULL,&Na->yl));
281   }
282   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1));
283   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1));
284   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->xl));
285   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->yl));
286 
287   /* Internally create a scaling vector if roottypes do not match */
288   if (Na->c) {
289     VecType rt1,rt2;
290 
291     PetscCall(VecGetRootType_Private(Na->work1,&rt1));
292     PetscCall(VecGetRootType_Private(Na->c,&rt2));
293     PetscCall(PetscStrcmp(rt1,rt2,&match));
294     if (!match) {
295       PetscCall(VecDuplicate(Na->c,&c));
296       PetscCall(VecCopy(Na->c,c));
297       PetscCall(VecDestroy(&Na->c));
298       PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c));
299       Na->c = c;
300     }
301   }
302 
303   (*N)->ops->destroy       = MatDestroy_LRC;
304   (*N)->ops->mult          = MatMult_LRC;
305   (*N)->ops->multtranspose = MatMultTranspose_LRC;
306 
307   (*N)->assembled    = PETSC_TRUE;
308   (*N)->preallocated = PETSC_TRUE;
309 
310   PetscCall(PetscObjectComposeFunction((PetscObject)(*N),"MatLRCGetMats_C",MatLRCGetMats_LRC));
311   PetscCall(MatSetUp(*N));
312   PetscFunctionReturn(0);
313 }
314