xref: /petsc/src/mat/impls/lrc/lrc.c (revision daa037dfd3c3bec8dc8659548d2b20b07c1dc6de)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec,VecType*);
5 
6 typedef struct {
7   Mat A;           /* sparse matrix */
8   Mat U,V;         /* dense tall-skinny matrices */
9   Vec c;           /* sequential vector containing the diagonal of C */
10   Vec work1,work2; /* sequential vectors that hold partial products */
11   Vec xl,yl;       /* auxiliary sequential vectors for matmult operation */
12 } Mat_LRC;
13 
14 static PetscErrorCode MatMult_LRC_kernel(Mat N,Vec x,Vec y,PetscBool transpose)
15 {
16   Mat_LRC        *Na = (Mat_LRC*)N->data;
17   PetscMPIInt    size;
18   Mat            U,V;
19 
20   PetscFunctionBegin;
21   U = transpose ? Na->V : Na->U;
22   V = transpose ? Na->U : Na->V;
23   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)N),&size));
24   if (size == 1) {
25     PetscCall(MatMultHermitianTranspose(V,x,Na->work1));
26     if (Na->c) {
27       PetscCall(VecPointwiseMult(Na->work1,Na->c,Na->work1));
28     }
29     if (Na->A) {
30       if (transpose) {
31         PetscCall(MatMultTranspose(Na->A,x,y));
32       } else {
33         PetscCall(MatMult(Na->A,x,y));
34       }
35       PetscCall(MatMultAdd(U,Na->work1,y,y));
36     } else {
37       PetscCall(MatMult(U,Na->work1,y));
38     }
39   } else {
40     Mat               Uloc,Vloc;
41     Vec               yl,xl;
42     const PetscScalar *w1;
43     PetscScalar       *w2;
44     PetscInt          nwork;
45     PetscMPIInt       mpinwork;
46 
47     xl = transpose ? Na->yl : Na->xl;
48     yl = transpose ? Na->xl : Na->yl;
49     PetscCall(VecGetLocalVector(y,yl));
50     PetscCall(MatDenseGetLocalMatrix(U,&Uloc));
51     PetscCall(MatDenseGetLocalMatrix(V,&Vloc));
52 
53     /* multiply the local part of V with the local part of x */
54     PetscCall(VecGetLocalVectorRead(x,xl));
55     PetscCall(MatMultHermitianTranspose(Vloc,xl,Na->work1));
56     PetscCall(VecRestoreLocalVectorRead(x,xl));
57 
58     /* form the sum of all the local multiplies: this is work2 = V'*x =
59        sum_{all processors} work1 */
60     PetscCall(VecGetArrayRead(Na->work1,&w1));
61     PetscCall(VecGetArrayWrite(Na->work2,&w2));
62     PetscCall(VecGetLocalSize(Na->work1,&nwork));
63     PetscCall(PetscMPIIntCast(nwork,&mpinwork));
64     PetscCall(MPIU_Allreduce(w1,w2,mpinwork,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)N)));
65     PetscCall(VecRestoreArrayRead(Na->work1,&w1));
66     PetscCall(VecRestoreArrayWrite(Na->work2,&w2));
67 
68     if (Na->c) {  /* work2 = C*work2 */
69       PetscCall(VecPointwiseMult(Na->work2,Na->c,Na->work2));
70     }
71 
72     if (Na->A) {
73       /* form y = A*x or A^t*x */
74       if (transpose) {
75         PetscCall(MatMultTranspose(Na->A,x,y));
76       } else {
77         PetscCall(MatMult(Na->A,x,y));
78       }
79       /* multiply-add y = y + U*work2 */
80       PetscCall(MatMultAdd(Uloc,Na->work2,yl,yl));
81     } else {
82       /* multiply y = U*work2 */
83       PetscCall(MatMult(Uloc,Na->work2,yl));
84     }
85 
86     PetscCall(VecRestoreLocalVector(y,yl));
87   }
88   PetscFunctionReturn(0);
89 }
90 
91 static PetscErrorCode MatMult_LRC(Mat N,Vec x,Vec y)
92 {
93   PetscFunctionBegin;
94   PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_FALSE));
95   PetscFunctionReturn(0);
96 }
97 
98 static PetscErrorCode MatMultTranspose_LRC(Mat N,Vec x,Vec y)
99 {
100   PetscFunctionBegin;
101   PetscCall(MatMult_LRC_kernel(N,x,y,PETSC_TRUE));
102   PetscFunctionReturn(0);
103 }
104 
105 static PetscErrorCode MatDestroy_LRC(Mat N)
106 {
107   Mat_LRC        *Na = (Mat_LRC*)N->data;
108 
109   PetscFunctionBegin;
110   PetscCall(MatDestroy(&Na->A));
111   PetscCall(MatDestroy(&Na->U));
112   PetscCall(MatDestroy(&Na->V));
113   PetscCall(VecDestroy(&Na->c));
114   PetscCall(VecDestroy(&Na->work1));
115   PetscCall(VecDestroy(&Na->work2));
116   PetscCall(VecDestroy(&Na->xl));
117   PetscCall(VecDestroy(&Na->yl));
118   PetscCall(PetscFree(N->data));
119   PetscCall(PetscObjectComposeFunction((PetscObject)N,"MatLRCGetMats_C",NULL));
120   PetscFunctionReturn(0);
121 }
122 
123 static PetscErrorCode MatLRCGetMats_LRC(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
124 {
125   Mat_LRC *Na = (Mat_LRC*)N->data;
126 
127   PetscFunctionBegin;
128   if (A) *A = Na->A;
129   if (U) *U = Na->U;
130   if (c) *c = Na->c;
131   if (V) *V = Na->V;
132   PetscFunctionReturn(0);
133 }
134 
135 /*@
136    MatLRCGetMats - Returns the constituents of an LRC matrix
137 
138    Collective on Mat
139 
140    Input Parameter:
141 .  N - matrix of type LRC
142 
143    Output Parameters:
144 +  A - the (sparse) matrix
145 .  U - first dense rectangular (tall and skinny) matrix
146 .  c - a sequential vector containing the diagonal of C
147 -  V - second dense rectangular (tall and skinny) matrix
148 
149    Note:
150    The returned matrices need not be destroyed by the caller.
151 
152    Level: intermediate
153 
154 .seealso: MatCreateLRC()
155 @*/
156 PetscErrorCode MatLRCGetMats(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
157 {
158   PetscFunctionBegin;
159   PetscUseMethod(N,"MatLRCGetMats_C",(Mat,Mat*,Mat*,Vec*,Mat*),(N,A,U,c,V));
160   PetscFunctionReturn(0);
161 }
162 
163 /*@
164    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V'
165 
166    Collective on Mat
167 
168    Input Parameters:
169 +  A    - the (sparse) matrix (can be NULL)
170 .  U, V - two dense rectangular (tall and skinny) matrices
171 -  c    - a vector containing the diagonal of C (can be NULL)
172 
173    Output Parameter:
174 .  N    - the matrix that represents A + U*C*V'
175 
176    Notes:
177    The matrix A + U*C*V' is not formed! Rather the new matrix
178    object performs the matrix-vector product by first multiplying by
179    A and then adding the other term.
180 
181    C is a diagonal matrix (represented as a vector) of order k,
182    where k is the number of columns of both U and V.
183 
184    If A is NULL then the new object behaves like a low-rank matrix U*C*V'.
185 
186    Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'.
187 
188    If c is NULL then the low-rank correction is just U*V'.
189    If a sequential c vector is used for a parallel matrix,
190    PETSc assumes that the values of the vector are consistently set across processors.
191 
192    Level: intermediate
193 
194 .seealso: MatLRCGetMats()
195 @*/
196 PetscErrorCode MatCreateLRC(Mat A,Mat U,Vec c,Mat V,Mat *N)
197 {
198   PetscBool      match;
199   PetscInt       m,n,k,m1,n1,k1;
200   Mat_LRC        *Na;
201   Mat            Uloc;
202   PetscMPIInt    size, csize = 0;
203 
204   PetscFunctionBegin;
205   if (A) PetscValidHeaderSpecific(A,MAT_CLASSID,1);
206   PetscValidHeaderSpecific(U,MAT_CLASSID,2);
207   if (c) PetscValidHeaderSpecific(c,VEC_CLASSID,3);
208   if (V) {
209     PetscValidHeaderSpecific(V,MAT_CLASSID,4);
210     PetscCheckSameComm(U,2,V,4);
211   }
212   if (A) PetscCheckSameComm(A,1,U,2);
213 
214   if (!V) V = U;
215   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)U,&match,MATSEQDENSE,MATMPIDENSE,""));
216   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix U must be of type dense, found %s",((PetscObject)U)->type_name);
217   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)V,&match,MATSEQDENSE,MATMPIDENSE,""));
218   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_SUP,"Matrix V must be of type dense, found %s",((PetscObject)V)->type_name);
219   PetscCall(PetscStrcmp(U->defaultvectype,V->defaultvectype,&match));
220   PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix U and V must have the same VecType %s != %s",U->defaultvectype,V->defaultvectype);
221   if (A) {
222     PetscCall(PetscStrcmp(A->defaultvectype,U->defaultvectype,&match));
223     PetscCheck(match,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_WRONG,"Matrix A and U must have the same VecType %s != %s",A->defaultvectype,U->defaultvectype);
224   }
225 
226   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)U),&size));
227   PetscCall(MatGetSize(U,NULL,&k));
228   PetscCall(MatGetSize(V,NULL,&k1));
229   PetscCheckFalse(k != k1,PetscObjectComm((PetscObject)U),PETSC_ERR_ARG_INCOMP,"U and V have different number of columns (%" PetscInt_FMT " vs %" PetscInt_FMT ")",k,k1);
230   PetscCall(MatGetLocalSize(U,&m,NULL));
231   PetscCall(MatGetLocalSize(V,&n,NULL));
232   if (A) {
233     PetscCall(MatGetLocalSize(A,&m1,&n1));
234     PetscCheckFalse(m != m1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of U %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",m,m1);
235     PetscCheckFalse(n != n1,PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Local dimensions of V %" PetscInt_FMT " and A %" PetscInt_FMT " do not match",n,n1);
236   }
237   if (c) {
238     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)c),&csize));
239     PetscCall(VecGetSize(c,&k1));
240     PetscCheckFalse(k != k1,PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"The length of c %" PetscInt_FMT " does not match the number of columns of U and V (%" PetscInt_FMT ")",k1,k);
241     PetscCheckFalse(csize != 1 && csize != size, PetscObjectComm((PetscObject)c),PETSC_ERR_ARG_INCOMP,"U and c must have the same communicator size %d != %d",size,csize);
242   }
243 
244   PetscCall(MatCreate(PetscObjectComm((PetscObject)U),N));
245   PetscCall(MatSetSizes(*N,m,n,PETSC_DECIDE,PETSC_DECIDE));
246   PetscCall(MatSetVecType(*N,U->defaultvectype));
247   PetscCall(PetscObjectChangeTypeName((PetscObject)*N,MATLRC));
248   /* Flag matrix as symmetric if A is symmetric and U == V */
249   PetscCall(MatSetOption(*N,MAT_SYMMETRIC,(PetscBool)((A ? A->symmetric : PETSC_TRUE) && U == V)));
250 
251   PetscCall(PetscNewLog(*N,&Na));
252   (*N)->data = (void*)Na;
253   Na->A      = A;
254   Na->U      = U;
255   Na->c      = c;
256   Na->V      = V;
257 
258   PetscCall(PetscObjectReference((PetscObject)A));
259   PetscCall(PetscObjectReference((PetscObject)Na->U));
260   PetscCall(PetscObjectReference((PetscObject)Na->V));
261   PetscCall(PetscObjectReference((PetscObject)c));
262 
263   PetscCall(MatDenseGetLocalMatrix(Na->U,&Uloc));
264   PetscCall(MatCreateVecs(Uloc,&Na->work1,NULL));
265   if (size != 1) {
266     Mat Vloc;
267 
268     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
269       VecScatter sct;
270 
271       PetscCall(VecScatterCreateToAll(Na->c,&sct,&c));
272       PetscCall(VecScatterBegin(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD));
273       PetscCall(VecScatterEnd(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD));
274       PetscCall(VecScatterDestroy(&sct));
275       PetscCall(VecDestroy(&Na->c));
276       PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c));
277       Na->c = c;
278     }
279     PetscCall(MatDenseGetLocalMatrix(Na->V,&Vloc));
280     PetscCall(VecDuplicate(Na->work1,&Na->work2));
281     PetscCall(MatCreateVecs(Vloc,NULL,&Na->xl));
282     PetscCall(MatCreateVecs(Uloc,NULL,&Na->yl));
283   }
284   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1));
285   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1));
286   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->xl));
287   PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->yl));
288 
289   /* Internally create a scaling vector if roottypes do not match */
290   if (Na->c) {
291     VecType rt1,rt2;
292 
293     PetscCall(VecGetRootType_Private(Na->work1,&rt1));
294     PetscCall(VecGetRootType_Private(Na->c,&rt2));
295     PetscCall(PetscStrcmp(rt1,rt2,&match));
296     if (!match) {
297       PetscCall(VecDuplicate(Na->c,&c));
298       PetscCall(VecCopy(Na->c,c));
299       PetscCall(VecDestroy(&Na->c));
300       PetscCall(PetscLogObjectParent((PetscObject)*N,(PetscObject)c));
301       Na->c = c;
302     }
303   }
304 
305   (*N)->ops->destroy       = MatDestroy_LRC;
306   (*N)->ops->mult          = MatMult_LRC;
307   (*N)->ops->multtranspose = MatMultTranspose_LRC;
308 
309   (*N)->assembled    = PETSC_TRUE;
310   (*N)->preallocated = PETSC_TRUE;
311 
312   PetscCall(PetscObjectComposeFunction((PetscObject)(*N),"MatLRCGetMats_C",MatLRCGetMats_LRC));
313   PetscCall(MatSetUp(*N));
314   PetscFunctionReturn(0);
315 }
316