xref: /petsc/src/mat/impls/submat/submat.c (revision efe48dd8cf8b935accbbb9f4bcb20bc83865fa4d)
1 #define PETSCMAT_DLL
2 
3 #include "private/matimpl.h"          /*I "petscmat.h" I*/
4 
5 typedef struct {
6   IS isrow,iscol;               /* rows and columns in submatrix, only used to check consistency */
7   Vec left,right;               /* optional scaling */
8   Vec olwork,orwork;            /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
9   Vec lwork,rwork;              /* work vectors inside the scatters */
10   VecScatter lrestrict,rprolong;
11   Mat A;
12   PetscScalar scale;
13 } Mat_SubMatrix;
14 
15 #undef __FUNCT__
16 #define __FUNCT__ "PreScaleLeft"
17 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
18 {
19   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
20   PetscErrorCode ierr;
21 
22   PetscFunctionBegin;
23   if (!Na->left) {
24     *xx = x;
25   } else {
26     if (!Na->olwork) {
27       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
28     }
29     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
30     *xx = Na->olwork;
31   }
32   PetscFunctionReturn(0);
33 }
34 
35 #undef __FUNCT__
36 #define __FUNCT__ "PreScaleRight"
37 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
38 {
39   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
40   PetscErrorCode ierr;
41 
42   PetscFunctionBegin;
43   if (!Na->right) {
44     *xx = x;
45   } else {
46     if (!Na->orwork) {
47       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
48     }
49     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
50     *xx = Na->orwork;
51   }
52   PetscFunctionReturn(0);
53 }
54 
55 #undef __FUNCT__
56 #define __FUNCT__ "PostScaleLeft"
57 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
58 {
59   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
60   PetscErrorCode ierr;
61 
62   PetscFunctionBegin;
63   if (Na->left) {
64     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
65   }
66   PetscFunctionReturn(0);
67 }
68 
69 #undef __FUNCT__
70 #define __FUNCT__ "PostScaleRight"
71 static PetscErrorCode PostScaleRight(Mat N,Vec x)
72 {
73   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
74   PetscErrorCode ierr;
75 
76   PetscFunctionBegin;
77   if (Na->right) {
78     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
79   }
80   PetscFunctionReturn(0);
81 }
82 
83 #undef __FUNCT__
84 #define __FUNCT__ "MatScale_SubMatrix"
85 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
86 {
87   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
88 
89   PetscFunctionBegin;
90   Na->scale *= scale;
91   PetscFunctionReturn(0);
92 }
93 
94 #undef __FUNCT__
95 #define __FUNCT__ "MatDiagonalScale_SubMatrix"
96 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
97 {
98   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;
99   PetscErrorCode ierr;
100 
101   PetscFunctionBegin;
102   if (left) {
103     if (!Na->left) {
104       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
105       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
106     } else {
107       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
108     }
109   }
110   if (right) {
111     if (!Na->right) {
112       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
113       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
114     } else {
115       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
116     }
117   }
118   PetscFunctionReturn(0);
119 }
120 
121 #undef __FUNCT__
122 #define __FUNCT__ "MatMult_SubMatrix"
123 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
124 {
125   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
126   Vec             xx=0;
127   PetscErrorCode  ierr;
128 
129   PetscFunctionBegin;
130   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
131   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
132   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
133   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
134   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
135   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
136   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
137   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
138   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
139   PetscFunctionReturn(0);
140 }
141 
142 #undef __FUNCT__
143 #define __FUNCT__ "MatMultAdd_SubMatrix"
144 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
145 {
146   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
147   Vec             xx=0;
148   PetscErrorCode  ierr;
149 
150   PetscFunctionBegin;
151   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
152   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
153   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
154   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
155   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
156   if (v2 == v3) {
157     if (Na->scale == 1.0 && !Na->left) {
158       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
159       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
160     } else {
161       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
162       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
163       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
164       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
165       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
166     }
167   } else {
168     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
169     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
170     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
171     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
172   }
173   PetscFunctionReturn(0);
174 }
175 
176 #undef __FUNCT__
177 #define __FUNCT__ "MatMultTranspose_SubMatrix"
178 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
179 {
180   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
181   Vec             xx=0;
182   PetscErrorCode ierr;
183 
184   PetscFunctionBegin;
185   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
186   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
187   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
188   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
189   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
190   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
191   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
192   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
193   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
194   PetscFunctionReturn(0);
195 }
196 
197 #undef __FUNCT__
198 #define __FUNCT__ "MatMultTransposeAdd_SubMatrix"
199 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
200 {
201   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
202   Vec             xx =0;
203   PetscErrorCode ierr;
204 
205   PetscFunctionBegin;
206   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
207   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
208   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
209   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
210   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
211   if (v2 == v3) {
212     if (Na->scale == 1.0 && !Na->right) {
213       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
214       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
215     } else {
216       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
217       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
218       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
219       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
220       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
221     }
222   } else {
223     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
224     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
225     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
226     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
227   }
228   PetscFunctionReturn(0);
229 }
230 
231 #undef __FUNCT__
232 #define __FUNCT__ "MatDestroy_SubMatrix"
233 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
234 {
235   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
236   PetscErrorCode ierr;
237 
238   PetscFunctionBegin;
239   ierr = ISDestroy(Na->isrow);CHKERRQ(ierr);
240   ierr = ISDestroy(Na->iscol);CHKERRQ(ierr);
241   if (Na->left) {ierr = VecDestroy(Na->left);CHKERRQ(ierr);}
242   if (Na->right) {ierr = VecDestroy(Na->right);CHKERRQ(ierr);}
243   if (Na->olwork) {ierr = VecDestroy(Na->olwork);CHKERRQ(ierr);}
244   if (Na->orwork) {ierr = VecDestroy(Na->orwork);CHKERRQ(ierr);}
245   ierr = VecDestroy(Na->lwork);CHKERRQ(ierr);
246   ierr = VecDestroy(Na->rwork);CHKERRQ(ierr);
247   ierr = VecScatterDestroy(Na->lrestrict);CHKERRQ(ierr);
248   ierr = VecScatterDestroy(Na->rprolong);CHKERRQ(ierr);
249   ierr = MatDestroy(Na->A);CHKERRQ(ierr);
250   ierr = PetscFree(Na);CHKERRQ(ierr);
251   PetscFunctionReturn(0);
252 }
253 
254 #undef __FUNCT__
255 #define __FUNCT__ "MatCreateSubMatrix"
256 /*@
257    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix
258 
259    Collective on Mat
260 
261    Input Parameters:
262 +  A - matrix that we will extract a submatrix of
263 .  isrow - rows to be present in the submatrix
264 -  iscol - columns to be present in the submatrix
265 
266    Output Parameters:
267 .  newmat - new matrix
268 
269    Level: developer
270 
271    Notes:
272    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
273 
274 .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
275 @*/
276 PetscErrorCode PETSCMAT_DLLEXPORT MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
277 {
278   Vec            left,right;
279   PetscInt       m,n;
280   Mat            N;
281   Mat_SubMatrix *Na;
282   PetscErrorCode ierr;
283 
284   PetscFunctionBegin;
285   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
286   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
287   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
288   PetscValidPointer(newmat,4);
289   *newmat = 0;
290 
291   ierr = MatCreate(((PetscObject)A)->comm,&N);CHKERRQ(ierr);
292   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
293   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
294   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
295   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
296 
297   ierr = PetscNewLog(N,Mat_SubMatrix,&Na);CHKERRQ(ierr);
298   N->data   = (void*)Na;
299   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
300   ierr = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
301   ierr = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
302   Na->A     = A;
303   Na->isrow = isrow;
304   Na->iscol = iscol;
305   Na->scale = 1.0;
306 
307   N->ops->destroy          = MatDestroy_SubMatrix;
308   N->ops->mult             = MatMult_SubMatrix;
309   N->ops->multadd          = MatMultAdd_SubMatrix;
310   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
311   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
312   N->ops->scale            = MatScale_SubMatrix;
313   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
314 
315   N->assembled = PETSC_TRUE;
316 
317   ierr = PetscLayoutSetBlockSize(N->rmap,A->rmap->bs);CHKERRQ(ierr);
318   ierr = PetscLayoutSetBlockSize(N->cmap,A->cmap->bs);CHKERRQ(ierr);
319   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
320   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
321 
322   ierr = MatGetVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
323   ierr = VecCreate(((PetscObject)isrow)->comm,&left);CHKERRQ(ierr);
324   ierr = VecCreate(((PetscObject)iscol)->comm,&right);CHKERRQ(ierr);
325   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
326   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
327   ierr = VecSetUp(left);CHKERRQ(ierr);
328   ierr = VecSetUp(right);CHKERRQ(ierr);
329   ierr = VecScatterCreate(Na->lwork,isrow,left,PETSC_NULL,&Na->lrestrict);CHKERRQ(ierr);
330   ierr = VecScatterCreate(right,PETSC_NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
331   ierr = VecDestroy(left);CHKERRQ(ierr);
332   ierr = VecDestroy(right);CHKERRQ(ierr);
333 
334   *newmat = N;
335   PetscFunctionReturn(0);
336 }
337 
338 
339 #undef __FUNCT__
340 #define __FUNCT__ "MatSubMatrixUpdate"
341 /*@
342    MatSubMatrixUpdate - Updates a submatrix
343 
344    Collective on Mat
345 
346    Input Parameters:
347 +  N - submatrix to update
348 .  A - full matrix in the submatrix
349 .  isrow - rows in the update (same as the first time the submatrix was created)
350 -  iscol - columns in the update (same as the first time the submatrix was created)
351 
352    Level: developer
353 
354    Notes:
355    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.
356 
357 .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
358 @*/
359 PetscErrorCode PETSCMAT_DLLEXPORT MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
360 {
361   PetscErrorCode  ierr;
362   PetscBool       flg;
363   Mat_SubMatrix  *Na;
364 
365   PetscFunctionBegin;
366   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
367   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
368   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
369   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
370   ierr = PetscTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
371   if (!flg) SETERRQ(((PetscObject)A)->comm,PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
372 
373   Na = (Mat_SubMatrix*)N->data;
374   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
375   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
376   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
377   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
378 
379   ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
380   ierr = MatDestroy(Na->A);CHKERRQ(ierr);
381   Na->A = A;
382 
383   Na->scale = 1.0;
384   if (Na->left) {ierr = VecDestroy(Na->left);CHKERRQ(ierr);}
385   if (Na->right) {ierr = VecDestroy(Na->right);CHKERRQ(ierr);}
386   PetscFunctionReturn(0);
387 }
388