xref: /petsc/src/mat/impls/submat/submat.c (revision 487a658c8b32ba712a1dc8280daad2fd70c1dcd9)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   Vec         dshift;
10   VecScatter  lrestrict,rprolong;
11   Mat         A;
12   PetscScalar vscale, axpy_vscale;
13   PetscScalar vshift, axpy_vshift;
14 } Mat_SubVirtual;
15 
16 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
17 {
18   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
19   PetscErrorCode ierr;
20 
21   PetscFunctionBegin;
22   if (!Na->left) {
23     *xx = x;
24   } else {
25     if (!Na->olwork) {
26       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
27     }
28     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
29     *xx  = Na->olwork;
30   }
31   PetscFunctionReturn(0);
32 }
33 
34 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
35 {
36   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
37   PetscErrorCode ierr;
38 
39   PetscFunctionBegin;
40   if (!Na->right) {
41     *xx = x;
42   } else {
43     if (!Na->orwork) {
44       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
45     }
46     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
47     *xx  = Na->orwork;
48   }
49   PetscFunctionReturn(0);
50 }
51 
52 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
53 {
54   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
55   PetscErrorCode ierr;
56 
57   PetscFunctionBegin;
58   if (Na->left) {
59     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
60   }
61   PetscFunctionReturn(0);
62 }
63 
64 static PetscErrorCode PostScaleRight(Mat N,Vec x)
65 {
66   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
67   PetscErrorCode ierr;
68 
69   PetscFunctionBegin;
70   if (Na->right) {
71     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
72   }
73   PetscFunctionReturn(0);
74 }
75 
76 /*
77          Y = vscale*Y + diag(dshift)*X + vshift*X
78 
79          On input Y already contains A*x
80 */
81 static PetscErrorCode MatSubmatShiftAndScale(Mat A,Vec X,Vec Y)
82 {
83   Mat_SubVirtual *Na = (Mat_SubVirtual*)A->data;
84   PetscErrorCode ierr;
85 
86   PetscFunctionBegin;
87   if (Na->dshift) {          /* get arrays because there is no VecPointwiseMultAdd() */
88     PetscInt          i,m;
89     const PetscScalar *x,*d;
90     PetscScalar       *y;
91     ierr = VecGetLocalSize(X,&m);CHKERRQ(ierr);
92     ierr = VecGetArrayRead(Na->dshift,&d);CHKERRQ(ierr);
93     ierr = VecGetArrayRead(X,&x);CHKERRQ(ierr);
94     ierr = VecGetArray(Y,&y);CHKERRQ(ierr);
95     for (i=0; i<m; i++) y[i] = Na->vscale*y[i] + d[i]*x[i];
96     ierr = VecRestoreArrayRead(Na->dshift,&d);CHKERRQ(ierr);
97     ierr = VecRestoreArrayRead(X,&x);CHKERRQ(ierr);
98     ierr = VecRestoreArray(Y,&y);CHKERRQ(ierr);
99   } else {
100     ierr = VecScale(Y,Na->vscale);CHKERRQ(ierr);
101   }
102   if (Na->vshift != 0.0) {ierr = VecAXPY(Y,Na->vshift,X);CHKERRQ(ierr);} /* if test is for non-square matrices */
103   PetscFunctionReturn(0);
104 }
105 
106 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar a)
107 {
108   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
109   PetscErrorCode ierr;
110 
111   PetscFunctionBegin;
112   Na->vscale *= a;
113   Na->vshift *= a;
114   if (Na->dshift) {
115     ierr = VecScale(Na->dshift,a);CHKERRQ(ierr);
116   }
117   Na->axpy_vscale *= a;
118   PetscFunctionReturn(0);
119 }
120 
121 static PetscErrorCode MatShift_SubMatrix(Mat N,PetscScalar a)
122 {
123   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
124   PetscErrorCode ierr;
125 
126   PetscFunctionBegin;
127   if (Na->left || Na->right) {
128     if (!Na->dshift) {
129       ierr = VecDuplicate(Na->left ? Na->left : Na->right, &Na->dshift);CHKERRQ(ierr);
130       ierr = VecSet(Na->dshift,a);CHKERRQ(ierr);
131     } else {
132       if (Na->left)  {ierr = VecPointwiseMult(Na->dshift,Na->dshift,Na->left);CHKERRQ(ierr);}
133       if (Na->right) {ierr = VecPointwiseMult(Na->dshift,Na->dshift,Na->right);CHKERRQ(ierr);}
134       ierr = VecShift(Na->dshift,a);CHKERRQ(ierr);
135     }
136     if (Na->left)  {ierr = VecPointwiseDivide(Na->dshift,Na->dshift,Na->left);CHKERRQ(ierr);}
137     if (Na->right) {ierr = VecPointwiseDivide(Na->dshift,Na->dshift,Na->right);CHKERRQ(ierr);}
138   } else Na->vshift += a;
139   PetscFunctionReturn(0);
140 }
141 
142 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
143 {
144   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
145   PetscErrorCode ierr;
146 
147   PetscFunctionBegin;
148   if (left) {
149     if (!Na->left) {
150       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
151       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
152     } else {
153       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
154     }
155   }
156   if (right) {
157     if (!Na->right) {
158       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
159       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
160     } else {
161       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
162     }
163   }
164   PetscFunctionReturn(0);
165 }
166 
167 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
168 {
169   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
170   Vec            xx  = 0;
171   PetscErrorCode ierr;
172 
173   PetscFunctionBegin;
174   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
175   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
176   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
177   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
178   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
179   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
180   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
181   ierr = MatSubmatShiftAndScale(N,xx,y);CHKERRQ(ierr);
182   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
183   PetscFunctionReturn(0);
184 }
185 
186 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
187 {
188   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
189   Vec            xx  = 0;
190   PetscErrorCode ierr;
191 
192   PetscFunctionBegin;
193   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
194   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
195   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
196   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
197   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
198   if (v2 == v3) {
199     if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
200     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
201     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
202     ierr = MatSubmatShiftAndScale(N,xx,Na->olwork);CHKERRQ(ierr);
203     ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
204     ierr = VecAXPY(v3,1.0,Na->olwork);CHKERRQ(ierr);
205   } else {
206     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
207     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
208     ierr = MatSubmatShiftAndScale(N,xx,v3);CHKERRQ(ierr);
209     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
210     ierr = VecAXPY(v3,1.0,v2);CHKERRQ(ierr);
211   }
212   PetscFunctionReturn(0);
213 }
214 
215 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
216 {
217   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
218   Vec            xx  = 0;
219   PetscErrorCode ierr;
220 
221   PetscFunctionBegin;
222   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
223   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
224   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
225   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
226   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
227   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
228   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
229   ierr = MatSubmatShiftAndScale(N,xx,y);CHKERRQ(ierr);
230   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
231   PetscFunctionReturn(0);
232 }
233 
234 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
235 {
236   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
237   Vec            xx  = 0;
238   PetscErrorCode ierr;
239 
240   PetscFunctionBegin;
241   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
242   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
243   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
244   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
245   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
246   if (v2 == v3) {
247     if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
248     ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
249     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
250     ierr = MatSubmatShiftAndScale(N,xx,Na->orwork);CHKERRQ(ierr);
251     ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
252     ierr = VecAXPY(v3,1.0,Na->orwork);CHKERRQ(ierr);
253   } else {
254     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
255     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
256     ierr = MatSubmatShiftAndScale(N,xx,v3);CHKERRQ(ierr);
257     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
258     ierr = VecAXPY(v3,1.0,v2);CHKERRQ(ierr);
259   }
260   PetscFunctionReturn(0);
261 }
262 
263 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
264 {
265   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
266   PetscErrorCode ierr;
267 
268   PetscFunctionBegin;
269   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
270   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
271   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
272   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
273   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
274   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
275   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
276   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
277   ierr = VecDestroy(&Na->dshift);CHKERRQ(ierr);
278   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
279   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
280   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
281   ierr = PetscFree(N->data);CHKERRQ(ierr);
282   PetscFunctionReturn(0);
283 }
284 
285 /*@
286    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
287 
288    Collective on Mat
289 
290    Input Parameters:
291 +  A - matrix that we will extract a submatrix of
292 .  isrow - rows to be present in the submatrix
293 -  iscol - columns to be present in the submatrix
294 
295    Output Parameters:
296 .  newmat - new matrix
297 
298    Level: developer
299 
300    Notes:
301    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
302 
303 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
304 @*/
305 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
306 {
307   Vec            left,right;
308   PetscInt       m,n;
309   Mat            N;
310   Mat_SubVirtual *Na;
311   PetscErrorCode ierr;
312 
313   PetscFunctionBegin;
314   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
315   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
316   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
317   PetscValidPointer(newmat,4);
318   *newmat = 0;
319 
320   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
321   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
322   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
323   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
324   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
325 
326   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
327   N->data   = (void*)Na;
328   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
329   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
330   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
331   Na->A     = A;
332   Na->isrow = isrow;
333   Na->iscol = iscol;
334   Na->vscale = 1.0;
335   Na->vshift = 0.0;
336 
337   N->ops->destroy          = MatDestroy_SubMatrix;
338   N->ops->mult             = MatMult_SubMatrix;
339   N->ops->multadd          = MatMultAdd_SubMatrix;
340   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
341   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
342   N->ops->scale            = MatScale_SubMatrix;
343   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
344   N->ops->shift            = MatShift_SubMatrix;
345 
346   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
347   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
348   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
349 
350   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
351   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
352   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
353   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
354   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
355   ierr = VecSetUp(left);CHKERRQ(ierr);
356   ierr = VecSetUp(right);CHKERRQ(ierr);
357   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
358   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
359   ierr = VecDestroy(&left);CHKERRQ(ierr);
360   ierr = VecDestroy(&right);CHKERRQ(ierr);
361 
362   N->assembled = PETSC_TRUE;
363 
364   ierr = MatSetUp(N);CHKERRQ(ierr);
365 
366   *newmat      = N;
367   PetscFunctionReturn(0);
368 }
369 
370 
371 /*@
372    MatSubMatrixVirtualUpdate - Updates a submatrix
373 
374    Collective on Mat
375 
376    Input Parameters:
377 +  N - submatrix to update
378 .  A - full matrix in the submatrix
379 .  isrow - rows in the update (same as the first time the submatrix was created)
380 -  iscol - columns in the update (same as the first time the submatrix was created)
381 
382    Level: developer
383 
384    Notes:
385    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
386 
387 .seealso: MatCreateSubMatrixVirtual()
388 @*/
389 PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
390 {
391   PetscErrorCode ierr;
392   PetscBool      flg;
393   Mat_SubVirtual *Na;
394 
395   PetscFunctionBegin;
396   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
397   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
398   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
399   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
400   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
401   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
402 
403   Na   = (Mat_SubVirtual*)N->data;
404   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
405   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
406   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
407   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
408 
409   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
410   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
411   Na->A = A;
412 
413   Na->vshift = 0.0;
414   Na->vscale = 1.0;
415   ierr       = VecDestroy(&Na->left);CHKERRQ(ierr);
416   ierr       = VecDestroy(&Na->right);CHKERRQ(ierr);
417   PetscFunctionReturn(0);
418 }
419