xref: /petsc/src/mat/impls/submat/submat.c (revision e6e75211d226c622f451867f53ce5d558649ff4f)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   VecScatter  lrestrict,rprolong;
10   Mat         A;
11   PetscScalar scale;
12 } Mat_SubMatrix;
13 
14 #undef __FUNCT__
15 #define __FUNCT__ "PreScaleLeft"
16 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
17 {
18   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
19   PetscErrorCode ierr;
20 
21   PetscFunctionBegin;
22   if (!Na->left) {
23     *xx = x;
24   } else {
25     if (!Na->olwork) {
26       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
27     }
28     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
29     *xx  = Na->olwork;
30   }
31   PetscFunctionReturn(0);
32 }
33 
34 #undef __FUNCT__
35 #define __FUNCT__ "PreScaleRight"
36 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
37 {
38   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
39   PetscErrorCode ierr;
40 
41   PetscFunctionBegin;
42   if (!Na->right) {
43     *xx = x;
44   } else {
45     if (!Na->orwork) {
46       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
47     }
48     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
49     *xx  = Na->orwork;
50   }
51   PetscFunctionReturn(0);
52 }
53 
54 #undef __FUNCT__
55 #define __FUNCT__ "PostScaleLeft"
56 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
57 {
58   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
59   PetscErrorCode ierr;
60 
61   PetscFunctionBegin;
62   if (Na->left) {
63     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
64   }
65   PetscFunctionReturn(0);
66 }
67 
68 #undef __FUNCT__
69 #define __FUNCT__ "PostScaleRight"
70 static PetscErrorCode PostScaleRight(Mat N,Vec x)
71 {
72   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
73   PetscErrorCode ierr;
74 
75   PetscFunctionBegin;
76   if (Na->right) {
77     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
78   }
79   PetscFunctionReturn(0);
80 }
81 
82 #undef __FUNCT__
83 #define __FUNCT__ "MatScale_SubMatrix"
84 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
85 {
86   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
87 
88   PetscFunctionBegin;
89   Na->scale *= scale;
90   PetscFunctionReturn(0);
91 }
92 
93 #undef __FUNCT__
94 #define __FUNCT__ "MatDiagonalScale_SubMatrix"
95 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
96 {
97   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
98   PetscErrorCode ierr;
99 
100   PetscFunctionBegin;
101   if (left) {
102     if (!Na->left) {
103       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
104       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
105     } else {
106       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
107     }
108   }
109   if (right) {
110     if (!Na->right) {
111       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
112       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
113     } else {
114       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
115     }
116   }
117   PetscFunctionReturn(0);
118 }
119 
120 #undef __FUNCT__
121 #define __FUNCT__ "MatMult_SubMatrix"
122 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
123 {
124   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
125   Vec            xx  = 0;
126   PetscErrorCode ierr;
127 
128   PetscFunctionBegin;
129   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
130   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
131   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
132   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
133   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
134   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
135   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
136   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
137   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
138   PetscFunctionReturn(0);
139 }
140 
141 #undef __FUNCT__
142 #define __FUNCT__ "MatMultAdd_SubMatrix"
143 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
144 {
145   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
146   Vec            xx  = 0;
147   PetscErrorCode ierr;
148 
149   PetscFunctionBegin;
150   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
151   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
152   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
153   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
154   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
155   if (v2 == v3) {
156     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
157       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
158       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
159     } else {
160       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
161       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
162       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
163       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
164       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
165     }
166   } else {
167     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
168     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
169     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
170     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
171   }
172   PetscFunctionReturn(0);
173 }
174 
175 #undef __FUNCT__
176 #define __FUNCT__ "MatMultTranspose_SubMatrix"
177 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
178 {
179   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
180   Vec            xx  = 0;
181   PetscErrorCode ierr;
182 
183   PetscFunctionBegin;
184   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
185   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
186   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
187   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
188   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
189   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
190   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
191   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
192   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
193   PetscFunctionReturn(0);
194 }
195 
196 #undef __FUNCT__
197 #define __FUNCT__ "MatMultTransposeAdd_SubMatrix"
198 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
199 {
200   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
201   Vec            xx  = 0;
202   PetscErrorCode ierr;
203 
204   PetscFunctionBegin;
205   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
206   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
207   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
208   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
209   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
210   if (v2 == v3) {
211     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
212       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
213       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
214     } else {
215       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
216       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
217       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
218       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
219       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
220     }
221   } else {
222     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
223     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
224     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
225     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
226   }
227   PetscFunctionReturn(0);
228 }
229 
230 #undef __FUNCT__
231 #define __FUNCT__ "MatDestroy_SubMatrix"
232 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
233 {
234   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
235   PetscErrorCode ierr;
236 
237   PetscFunctionBegin;
238   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
239   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
240   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
241   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
242   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
243   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
244   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
245   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
246   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
247   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
248   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
249   ierr = PetscFree(N->data);CHKERRQ(ierr);
250   PetscFunctionReturn(0);
251 }
252 
253 #undef __FUNCT__
254 #define __FUNCT__ "MatCreateSubMatrix"
255 /*@
256    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix
257 
258    Collective on Mat
259 
260    Input Parameters:
261 +  A - matrix that we will extract a submatrix of
262 .  isrow - rows to be present in the submatrix
263 -  iscol - columns to be present in the submatrix
264 
265    Output Parameters:
266 .  newmat - new matrix
267 
268    Level: developer
269 
270    Notes:
271    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
272 
273 .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
274 @*/
275 PetscErrorCode  MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
276 {
277   Vec            left,right;
278   PetscInt       m,n;
279   Mat            N;
280   Mat_SubMatrix  *Na;
281   PetscErrorCode ierr;
282 
283   PetscFunctionBegin;
284   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
285   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
286   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
287   PetscValidPointer(newmat,4);
288   *newmat = 0;
289 
290   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
291   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
292   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
293   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
294   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
295 
296   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
297   N->data   = (void*)Na;
298   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
299   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
300   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
301   Na->A     = A;
302   Na->isrow = isrow;
303   Na->iscol = iscol;
304   Na->scale = 1.0;
305 
306   N->ops->destroy          = MatDestroy_SubMatrix;
307   N->ops->mult             = MatMult_SubMatrix;
308   N->ops->multadd          = MatMultAdd_SubMatrix;
309   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
310   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
311   N->ops->scale            = MatScale_SubMatrix;
312   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
313 
314   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
315   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
316   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
317 
318   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
319   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
320   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
321   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
322   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
323   ierr = VecSetUp(left);CHKERRQ(ierr);
324   ierr = VecSetUp(right);CHKERRQ(ierr);
325   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
326   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
327   ierr = VecDestroy(&left);CHKERRQ(ierr);
328   ierr = VecDestroy(&right);CHKERRQ(ierr);
329 
330   N->assembled = PETSC_TRUE;
331 
332   ierr = MatSetUp(N);CHKERRQ(ierr);
333 
334   *newmat      = N;
335   PetscFunctionReturn(0);
336 }
337 
338 
339 #undef __FUNCT__
340 #define __FUNCT__ "MatSubMatrixUpdate"
341 /*@
342    MatSubMatrixUpdate - Updates a submatrix
343 
344    Collective on Mat
345 
346    Input Parameters:
347 +  N - submatrix to update
348 .  A - full matrix in the submatrix
349 .  isrow - rows in the update (same as the first time the submatrix was created)
350 -  iscol - columns in the update (same as the first time the submatrix was created)
351 
352    Level: developer
353 
354    Notes:
355    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
356 
357 .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
358 @*/
359 PetscErrorCode  MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
360 {
361   PetscErrorCode ierr;
362   PetscBool      flg;
363   Mat_SubMatrix  *Na;
364 
365   PetscFunctionBegin;
366   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
367   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
368   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
369   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
370   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
371   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
372 
373   Na   = (Mat_SubMatrix*)N->data;
374   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
375   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
376   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
377   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
378 
379   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
380   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
381   Na->A = A;
382 
383   Na->scale = 1.0;
384   ierr      = VecDestroy(&Na->left);CHKERRQ(ierr);
385   ierr      = VecDestroy(&Na->right);CHKERRQ(ierr);
386   PetscFunctionReturn(0);
387 }
388