xref: /petsc/src/mat/impls/submat/submat.c (revision 0700a8246d308f50502909ba325e6169d3ee27eb)
1 #define PETSCMAT_DLL
2 
3 #include "private/matimpl.h"          /*I "petscmat.h" I*/
4 
5 typedef struct {
6   IS isrow,iscol;               /* rows and columns in submatrix, only used to check consistency */
7   Vec left,right;               /* optional scaling */
8   Vec olwork,orwork;            /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
9   Vec lwork,rwork;              /* work vectors inside the scatters */
10   VecScatter lrestrict,rprolong;
11   Mat A;
12   PetscScalar scale;
13 } Mat_SubMatrix;
14 
15 #undef __FUNCT__
16 #define __FUNCT__ "PreScaleLeft"
17 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
18 {
19   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
20   PetscErrorCode ierr;
21 
22   PetscFunctionBegin;
23   if (!Na->left) {
24     *xx = x;
25   } else {
26     if (!Na->olwork) {
27       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
28     }
29     ierr = VecPointwiseMult(Na->left,x,Na->olwork);CHKERRQ(ierr);
30     *xx = Na->olwork;
31   }
32   PetscFunctionReturn(0);
33 }
34 
35 #undef __FUNCT__
36 #define __FUNCT__ "PreScaleRight"
37 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
38 {
39   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
40   PetscErrorCode ierr;
41 
42   PetscFunctionBegin;
43   if (!Na->right) {
44     *xx = x;
45   } else {
46     if (!Na->orwork) {
47       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
48     }
49     ierr = VecPointwiseMult(Na->right,x,Na->orwork);CHKERRQ(ierr);
50     *xx = Na->orwork;
51   }
52   PetscFunctionReturn(0);
53 }
54 
55 #undef __FUNCT__
56 #define __FUNCT__ "PostScaleLeft"
57 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
58 {
59   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
60   PetscErrorCode ierr;
61 
62   PetscFunctionBegin;
63   if (Na->left) {
64     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
65   }
66   PetscFunctionReturn(0);
67 }
68 
69 #undef __FUNCT__
70 #define __FUNCT__ "PostScaleRight"
71 static PetscErrorCode PostScaleRight(Mat N,Vec x)
72 {
73   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
74   PetscErrorCode ierr;
75 
76   PetscFunctionBegin;
77   if (Na->right) {
78     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
79   }
80   PetscFunctionReturn(0);
81 }
82 
83 #undef __FUNCT__
84 #define __FUNCT__ "MatScale_SubMatrix"
85 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
86 {
87   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
88 
89   PetscFunctionBegin;
90   Na->scale *= scale;
91   PetscFunctionReturn(0);
92 }
93 
94 #undef __FUNCT__
95 #define __FUNCT__ "MatDiagonalScale_SubMatrix"
96 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
97 {
98   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
99   PetscErrorCode ierr;
100 
101   PetscFunctionBegin;
102   if (left) {
103     if (!Na->left) {
104       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
105       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
106     } else {
107       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
108     }
109   }
110   if (right) {
111     if (!Na->right) {
112       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
113       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
114     } else {
115       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
116     }
117   }
118   PetscFunctionReturn(0);
119 }
120 
121 #undef __FUNCT__
122 #define __FUNCT__ "MatMult_SubMatrix"
123 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
124 {
125   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
126   Vec             xx=0;
127   PetscErrorCode  ierr;
128 
129   PetscFunctionBegin;
130   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
131   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
132   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
133   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
134   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
135   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
136   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
137   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
138   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
139   PetscFunctionReturn(0);
140 }
141 
142 #undef __FUNCT__
143 #define __FUNCT__ "MatMultAdd_SubMatrix"
144 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
145 {
146   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
147   Vec             xx=0;
148   PetscErrorCode  ierr;
149 
150   PetscFunctionBegin;
151   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
152   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
153   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
154   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
155   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
156   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
157   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
158   ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
159   ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
160   PetscFunctionReturn(0);
161 }
162 
163 #undef __FUNCT__
164 #define __FUNCT__ "MatMultTranspose_SubMatrix"
165 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
166 {
167   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
168   Vec             xx=0;
169   PetscErrorCode ierr;
170 
171   PetscFunctionBegin;
172   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
173   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
174   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
175   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
176   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
177   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
178   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
179   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
180   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
181   PetscFunctionReturn(0);
182 }
183 
184 #undef __FUNCT__
185 #define __FUNCT__ "MatMultTransposeAdd_SubMatrix"
186 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
187 {
188   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
189   Vec             xx =0;
190   PetscErrorCode ierr;
191 
192   PetscFunctionBegin;
193   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
194   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
195   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
196   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
197   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
198   ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
199   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
200   ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
201   ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
202   PetscFunctionReturn(0);
203 }
204 
205 #undef __FUNCT__
206 #define __FUNCT__ "MatDestroy_SubMatrix"
207 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
208 {
209   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
210   PetscErrorCode ierr;
211 
212   PetscFunctionBegin;
213   ierr = ISDestroy(Na->isrow);CHKERRQ(ierr);
214   ierr = ISDestroy(Na->iscol);CHKERRQ(ierr);
215   if (Na->left) {ierr = VecDestroy(Na->left);CHKERRQ(ierr);}
216   if (Na->right) {ierr = VecDestroy(Na->right);CHKERRQ(ierr);}
217   if (Na->olwork) {ierr = VecDestroy(Na->olwork);CHKERRQ(ierr);}
218   if (Na->orwork) {ierr = VecDestroy(Na->orwork);CHKERRQ(ierr);}
219   ierr = VecDestroy(Na->lwork);CHKERRQ(ierr);
220   ierr = VecDestroy(Na->rwork);CHKERRQ(ierr);
221   ierr = VecScatterDestroy(Na->lrestrict);CHKERRQ(ierr);
222   ierr = VecScatterDestroy(Na->rprolong);CHKERRQ(ierr);
223   ierr = MatDestroy(Na->A);CHKERRQ(ierr);
224   ierr = PetscFree(Na);CHKERRQ(ierr);
225   PetscFunctionReturn(0);
226 }
227 
228 #undef __FUNCT__
229 #define __FUNCT__ "MatCreateSubMatrix"
230 /*@
231    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix
232 
233    Collective on Mat
234 
235    Input Parameters:
236 +  A - matrix that we will extract a submatrix of
237 .  isrow - rows to be present in the submatrix
238 -  iscol - columns to be present in the submatrix
239 
240    Output Parameters:
241 .  newmat - new matrix
242 
243    Level: developer
244 
245    Notes:
246    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
247 
248 .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
249 @*/
250 PetscErrorCode PETSCMAT_DLLEXPORT MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
251 {
252   Vec            left,right;
253   PetscInt       m,n;
254   Mat            N;
255   Mat_SubMatrix *Na;
256   PetscErrorCode ierr;
257 
258   PetscFunctionBegin;
259   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
260   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
261   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
262   PetscValidPointer(newmat,4);
263   *newmat = 0;
264 
265   ierr = MatCreate(((PetscObject)A)->comm,&N);CHKERRQ(ierr);
266   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
267   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
268   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
269   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
270 
271   ierr = PetscNewLog(N,Mat_SubMatrix,&Na);CHKERRQ(ierr);
272   N->data   = (void*)Na;
273   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
274   ierr = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
275   ierr = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
276   Na->A     = A;
277   Na->isrow = isrow;
278   Na->iscol = iscol;
279   Na->scale = 1.0;
280 
281   N->ops->destroy          = MatDestroy_SubMatrix;
282   N->ops->mult             = MatMult_SubMatrix;
283   N->ops->multadd          = MatMultAdd_SubMatrix;
284   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
285   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
286   N->ops->scale            = MatScale_SubMatrix;
287   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
288 
289   N->assembled = PETSC_TRUE;
290 
291   ierr = PetscLayoutSetBlockSize(N->rmap,A->rmap->bs);CHKERRQ(ierr);
292   ierr = PetscLayoutSetBlockSize(N->cmap,A->cmap->bs);CHKERRQ(ierr);
293   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
294   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
295 
296   ierr = MatGetVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
297   ierr = VecCreate(((PetscObject)isrow)->comm,&left);CHKERRQ(ierr);
298   ierr = VecCreate(((PetscObject)iscol)->comm,&right);CHKERRQ(ierr);
299   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
300   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
301   ierr = VecSetUp(left);CHKERRQ(ierr);
302   ierr = VecSetUp(right);CHKERRQ(ierr);
303   ierr = VecScatterCreate(Na->lwork,isrow,left,PETSC_NULL,&Na->lrestrict);CHKERRQ(ierr);
304   ierr = VecScatterCreate(right,PETSC_NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
305   ierr = VecDestroy(left);CHKERRQ(ierr);
306   ierr = VecDestroy(right);CHKERRQ(ierr);
307 
308   *newmat = N;
309   PetscFunctionReturn(0);
310 }
311 
312 
313 #undef __FUNCT__
314 #define __FUNCT__ "MatSubMatrixUpdate"
315 /*@
316    MatSubMatrixUpdate - Updates a submatrix
317 
318    Collective on Mat
319 
320    Input Parameters:
321 +  N - submatrix to update
322 .  A - full matrix in the submatrix
323 .  isrow - rows in the update (same as the first time the submatrix was created)
324 -  iscol - columns in the update (same as the first time the submatrix was created)
325 
326    Level: developer
327 
328    Notes:
329    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
330 
331 .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
332 @*/
333 PetscErrorCode PETSCMAT_DLLEXPORT MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
334 {
335   PetscErrorCode  ierr;
336   PetscTruth      flg;
337   Mat_SubMatrix  *Na;
338 
339   PetscFunctionBegin;
340   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
341   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
342   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
343   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
344   ierr = PetscTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
345   if (!flg) SETERRQ(PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
346 
347   Na = (Mat_SubMatrix*)N->data;
348   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
349   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
350   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
351   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
352 
353   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
354   ierr = MatDestroy(Na->A);CHKERRQ(ierr);
355   Na->A = A;
356 
357   Na->scale = 1.0;
358   if (Na->left) {ierr = VecDestroy(Na->left);CHKERRQ(ierr);}
359   if (Na->right) {ierr = VecDestroy(Na->right);CHKERRQ(ierr);}
360   PetscFunctionReturn(0);
361 }
362