xref: /petsc/src/mat/impls/submat/submat.c (revision 94ef8dde638caef1d0cd84a7dc8a2db65fcda8b6)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   VecScatter  lrestrict,rprolong;
10   Mat         A;
11   PetscScalar scale;
12 } Mat_SubVirtual;
13 
14 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
15 {
16   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
17   PetscErrorCode ierr;
18 
19   PetscFunctionBegin;
20   if (!Na->left) {
21     *xx = x;
22   } else {
23     if (!Na->olwork) {
24       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
25     }
26     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
27     *xx  = Na->olwork;
28   }
29   PetscFunctionReturn(0);
30 }
31 
32 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
33 {
34   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
35   PetscErrorCode ierr;
36 
37   PetscFunctionBegin;
38   if (!Na->right) {
39     *xx = x;
40   } else {
41     if (!Na->orwork) {
42       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
43     }
44     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
45     *xx  = Na->orwork;
46   }
47   PetscFunctionReturn(0);
48 }
49 
50 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
51 {
52   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
53   PetscErrorCode ierr;
54 
55   PetscFunctionBegin;
56   if (Na->left) {
57     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
58   }
59   PetscFunctionReturn(0);
60 }
61 
62 static PetscErrorCode PostScaleRight(Mat N,Vec x)
63 {
64   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
65   PetscErrorCode ierr;
66 
67   PetscFunctionBegin;
68   if (Na->right) {
69     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
70   }
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
75 {
76   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
77 
78   PetscFunctionBegin;
79   Na->scale *= scale;
80   PetscFunctionReturn(0);
81 }
82 
83 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
84 {
85   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
86   PetscErrorCode ierr;
87 
88   PetscFunctionBegin;
89   if (left) {
90     if (!Na->left) {
91       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
92       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
93     } else {
94       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
95     }
96   }
97   if (right) {
98     if (!Na->right) {
99       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
100       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
101     } else {
102       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
103     }
104   }
105   PetscFunctionReturn(0);
106 }
107 
108 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
109 {
110   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
111   Vec            xx  = 0;
112   PetscErrorCode ierr;
113 
114   PetscFunctionBegin;
115   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
116   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
117   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
118   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
119   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
120   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
121   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
122   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
123   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
124   PetscFunctionReturn(0);
125 }
126 
127 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
128 {
129   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
130   Vec            xx  = 0;
131   PetscErrorCode ierr;
132 
133   PetscFunctionBegin;
134   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
135   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
136   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
137   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
138   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
139   if (v2 == v3) {
140     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
141       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
142       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
143     } else {
144       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
145       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
146       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
147       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
148       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
149     }
150   } else {
151     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
152     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
153     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
154     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
155   }
156   PetscFunctionReturn(0);
157 }
158 
159 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
160 {
161   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
162   Vec            xx  = 0;
163   PetscErrorCode ierr;
164 
165   PetscFunctionBegin;
166   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
167   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
168   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
169   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
170   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
171   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
172   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
173   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
174   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
175   PetscFunctionReturn(0);
176 }
177 
178 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
179 {
180   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
181   Vec            xx  = 0;
182   PetscErrorCode ierr;
183 
184   PetscFunctionBegin;
185   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
186   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
187   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
188   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
189   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
190   if (v2 == v3) {
191     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
192       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
193       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
194     } else {
195       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
196       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
197       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
198       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
199       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
200     }
201   } else {
202     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
203     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
204     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
205     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
206   }
207   PetscFunctionReturn(0);
208 }
209 
210 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
211 {
212   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
213   PetscErrorCode ierr;
214 
215   PetscFunctionBegin;
216   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
217   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
218   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
219   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
220   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
221   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
222   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
223   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
224   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
225   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
226   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
227   ierr = PetscFree(N->data);CHKERRQ(ierr);
228   PetscFunctionReturn(0);
229 }
230 
231 /*@
232    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
233 
234    Collective on Mat
235 
236    Input Parameters:
237 +  A - matrix that we will extract a submatrix of
238 .  isrow - rows to be present in the submatrix
239 -  iscol - columns to be present in the submatrix
240 
241    Output Parameters:
242 .  newmat - new matrix
243 
244    Level: developer
245 
246    Notes:
247    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
248 
249 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
250 @*/
251 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
252 {
253   Vec            left,right;
254   PetscInt       m,n;
255   Mat            N;
256   Mat_SubVirtual *Na;
257   PetscErrorCode ierr;
258 
259   PetscFunctionBegin;
260   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
261   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
262   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
263   PetscValidPointer(newmat,4);
264   *newmat = 0;
265 
266   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
267   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
268   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
269   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
270   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
271 
272   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
273   N->data   = (void*)Na;
274   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
275   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
276   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
277   Na->A     = A;
278   Na->isrow = isrow;
279   Na->iscol = iscol;
280   Na->scale = 1.0;
281 
282   N->ops->destroy          = MatDestroy_SubMatrix;
283   N->ops->mult             = MatMult_SubMatrix;
284   N->ops->multadd          = MatMultAdd_SubMatrix;
285   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
286   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
287   N->ops->scale            = MatScale_SubMatrix;
288   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
289 
290   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
291   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
292   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
293 
294   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
295   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
296   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
297   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
298   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
299   ierr = VecSetUp(left);CHKERRQ(ierr);
300   ierr = VecSetUp(right);CHKERRQ(ierr);
301   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
302   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
303   ierr = VecDestroy(&left);CHKERRQ(ierr);
304   ierr = VecDestroy(&right);CHKERRQ(ierr);
305 
306   N->assembled = PETSC_TRUE;
307 
308   ierr = MatSetUp(N);CHKERRQ(ierr);
309 
310   *newmat      = N;
311   PetscFunctionReturn(0);
312 }
313 
314 
315 /*@
316    MatSubMatrixVirtualUpdate - Updates a submatrix
317 
318    Collective on Mat
319 
320    Input Parameters:
321 +  N - submatrix to update
322 .  A - full matrix in the submatrix
323 .  isrow - rows in the update (same as the first time the submatrix was created)
324 -  iscol - columns in the update (same as the first time the submatrix was created)
325 
326    Level: developer
327 
328    Notes:
329    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
330 
331 .seealso: MatCreateSubMatrixVirtual()
332 @*/
333 PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
334 {
335   PetscErrorCode ierr;
336   PetscBool      flg;
337   Mat_SubVirtual *Na;
338 
339   PetscFunctionBegin;
340   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
341   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
342   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
343   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
344   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
345   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
346 
347   Na   = (Mat_SubVirtual*)N->data;
348   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
349   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
350   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
351   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
352 
353   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
354   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
355   Na->A = A;
356 
357   Na->scale = 1.0;
358   ierr      = VecDestroy(&Na->left);CHKERRQ(ierr);
359   ierr      = VecDestroy(&Na->right);CHKERRQ(ierr);
360   PetscFunctionReturn(0);
361 }
362