xref: /petsc/src/mat/impls/submat/submat.c (revision 7b6bb2c608b6fc6714ef38fda02c2dbb91c82665)
1 
2 #include <private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS isrow,iscol;               /* rows and columns in submatrix, only used to check consistency */
6   Vec left,right;               /* optional scaling */
7   Vec olwork,orwork;            /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec lwork,rwork;              /* work vectors inside the scatters */
9   VecScatter lrestrict,rprolong;
10   Mat A;
11   PetscScalar scale;
12 } Mat_SubMatrix;
13 
14 #undef __FUNCT__
15 #define __FUNCT__ "PreScaleLeft"
16 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
17 {
18   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
19   PetscErrorCode ierr;
20 
21   PetscFunctionBegin;
22   if (!Na->left) {
23     *xx = x;
24   } else {
25     if (!Na->olwork) {
26       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
27     }
28     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
29     *xx = Na->olwork;
30   }
31   PetscFunctionReturn(0);
32 }
33 
34 #undef __FUNCT__
35 #define __FUNCT__ "PreScaleRight"
36 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
37 {
38   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
39   PetscErrorCode ierr;
40 
41   PetscFunctionBegin;
42   if (!Na->right) {
43     *xx = x;
44   } else {
45     if (!Na->orwork) {
46       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
47     }
48     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
49     *xx = Na->orwork;
50   }
51   PetscFunctionReturn(0);
52 }
53 
54 #undef __FUNCT__
55 #define __FUNCT__ "PostScaleLeft"
56 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
57 {
58   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
59   PetscErrorCode ierr;
60 
61   PetscFunctionBegin;
62   if (Na->left) {
63     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
64   }
65   PetscFunctionReturn(0);
66 }
67 
68 #undef __FUNCT__
69 #define __FUNCT__ "PostScaleRight"
70 static PetscErrorCode PostScaleRight(Mat N,Vec x)
71 {
72   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
73   PetscErrorCode ierr;
74 
75   PetscFunctionBegin;
76   if (Na->right) {
77     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
78   }
79   PetscFunctionReturn(0);
80 }
81 
82 #undef __FUNCT__
83 #define __FUNCT__ "MatScale_SubMatrix"
84 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
85 {
86   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
87 
88   PetscFunctionBegin;
89   Na->scale *= scale;
90   PetscFunctionReturn(0);
91 }
92 
93 #undef __FUNCT__
94 #define __FUNCT__ "MatDiagonalScale_SubMatrix"
95 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
96 {
97   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
98   PetscErrorCode ierr;
99 
100   PetscFunctionBegin;
101   if (left) {
102     if (!Na->left) {
103       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
104       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
105     } else {
106       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
107     }
108   }
109   if (right) {
110     if (!Na->right) {
111       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
112       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
113     } else {
114       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
115     }
116   }
117   PetscFunctionReturn(0);
118 }
119 
120 #undef __FUNCT__
121 #define __FUNCT__ "MatMult_SubMatrix"
122 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
123 {
124   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
125   Vec             xx=0;
126   PetscErrorCode  ierr;
127 
128   PetscFunctionBegin;
129   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
130   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
131   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
132   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
133   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
134   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
135   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
136   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
137   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
138   PetscFunctionReturn(0);
139 }
140 
141 #undef __FUNCT__
142 #define __FUNCT__ "MatMultAdd_SubMatrix"
143 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
144 {
145   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
146   Vec             xx=0;
147   PetscErrorCode  ierr;
148 
149   PetscFunctionBegin;
150   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
151   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
152   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
153   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
154   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
155   if (v2 == v3) {
156     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
157       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
158       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
159     } else {
160       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
161       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
162       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
163       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
164       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
165     }
166   } else {
167     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
168     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
169     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
170     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
171   }
172   PetscFunctionReturn(0);
173 }
174 
175 #undef __FUNCT__
176 #define __FUNCT__ "MatMultTranspose_SubMatrix"
177 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
178 {
179   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
180   Vec             xx=0;
181   PetscErrorCode ierr;
182 
183   PetscFunctionBegin;
184   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
185   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
186   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
187   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
188   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
189   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
190   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
191   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
192   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
193   PetscFunctionReturn(0);
194 }
195 
196 #undef __FUNCT__
197 #define __FUNCT__ "MatMultTransposeAdd_SubMatrix"
198 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
199 {
200   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
201   Vec             xx =0;
202   PetscErrorCode ierr;
203 
204   PetscFunctionBegin;
205   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
206   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
207   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
208   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
209   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
210   if (v2 == v3) {
211     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
212       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
213       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
214     } else {
215       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
216       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
217       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
218       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
219       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
220     }
221   } else {
222     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
223     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
224     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
225     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
226   }
227   PetscFunctionReturn(0);
228 }
229 
230 #undef __FUNCT__
231 #define __FUNCT__ "MatDestroy_SubMatrix"
232 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
233 {
234   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
235   PetscErrorCode ierr;
236 
237   PetscFunctionBegin;
238   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
239   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
240   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
241   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
242   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
243   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
244   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
245   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
246   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
247   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
248   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
249   ierr = PetscFree(N->data);CHKERRQ(ierr);
250   PetscFunctionReturn(0);
251 }
252 
253 #undef __FUNCT__
254 #define __FUNCT__ "MatCreateSubMatrix"
255 /*@
256    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix
257 
258    Collective on Mat
259 
260    Input Parameters:
261 +  A - matrix that we will extract a submatrix of
262 .  isrow - rows to be present in the submatrix
263 -  iscol - columns to be present in the submatrix
264 
265    Output Parameters:
266 .  newmat - new matrix
267 
268    Level: developer
269 
270    Notes:
271    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
272 
273 .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
274 @*/
275 PetscErrorCode  MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
276 {
277   Vec            left,right;
278   PetscInt       m,n;
279   Mat            N;
280   Mat_SubMatrix *Na;
281   PetscErrorCode ierr;
282 
283   PetscFunctionBegin;
284   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
285   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
286   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
287   PetscValidPointer(newmat,4);
288   *newmat = 0;
289 
290   ierr = MatCreate(((PetscObject)A)->comm,&N);CHKERRQ(ierr);
291   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
292   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
293   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
294   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
295 
296   ierr = PetscNewLog(N,Mat_SubMatrix,&Na);CHKERRQ(ierr);
297   N->data   = (void*)Na;
298   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
299   ierr = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
300   ierr = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
301   Na->A     = A;
302   Na->isrow = isrow;
303   Na->iscol = iscol;
304   Na->scale = 1.0;
305 
306   N->ops->destroy          = MatDestroy_SubMatrix;
307   N->ops->mult             = MatMult_SubMatrix;
308   N->ops->multadd          = MatMultAdd_SubMatrix;
309   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
310   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
311   N->ops->scale            = MatScale_SubMatrix;
312   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
313 
314   N->assembled = PETSC_TRUE;
315 
316   ierr = PetscLayoutSetBlockSize(N->rmap,A->rmap->bs);CHKERRQ(ierr);
317   ierr = PetscLayoutSetBlockSize(N->cmap,A->cmap->bs);CHKERRQ(ierr);
318   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
319   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
320 
321   ierr = MatGetVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
322   ierr = VecCreate(((PetscObject)isrow)->comm,&left);CHKERRQ(ierr);
323   ierr = VecCreate(((PetscObject)iscol)->comm,&right);CHKERRQ(ierr);
324   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
325   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
326   ierr = VecSetUp(left);CHKERRQ(ierr);
327   ierr = VecSetUp(right);CHKERRQ(ierr);
328   ierr = VecScatterCreate(Na->lwork,isrow,left,PETSC_NULL,&Na->lrestrict);CHKERRQ(ierr);
329   ierr = VecScatterCreate(right,PETSC_NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
330   ierr = VecDestroy(&left);CHKERRQ(ierr);
331   ierr = VecDestroy(&right);CHKERRQ(ierr);
332 
333   *newmat = N;
334   PetscFunctionReturn(0);
335 }
336 
337 
338 #undef __FUNCT__
339 #define __FUNCT__ "MatSubMatrixUpdate"
340 /*@
341    MatSubMatrixUpdate - Updates a submatrix
342 
343    Collective on Mat
344 
345    Input Parameters:
346 +  N - submatrix to update
347 .  A - full matrix in the submatrix
348 .  isrow - rows in the update (same as the first time the submatrix was created)
349 -  iscol - columns in the update (same as the first time the submatrix was created)
350 
351    Level: developer
352 
353    Notes:
354    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
355 
356 .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
357 @*/
358 PetscErrorCode  MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
359 {
360   PetscErrorCode  ierr;
361   PetscBool       flg;
362   Mat_SubMatrix  *Na;
363 
364   PetscFunctionBegin;
365   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
366   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
367   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
368   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
369   ierr = PetscTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
370   if (!flg) SETERRQ(((PetscObject)A)->comm,PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
371 
372   Na = (Mat_SubMatrix*)N->data;
373   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
374   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
375   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
376   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
377 
378   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
379   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
380   Na->A = A;
381 
382   Na->scale = 1.0;
383   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
384   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
385   PetscFunctionReturn(0);
386 }
387