xref: /petsc/src/mat/impls/submat/submat.c (revision 697336901c45ac77e1fd620fe1fca906cf3f95c8)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
6   Vec        lwork, rwork;   /* work vectors inside the scatters */
7   Vec        lwork2, rwork2; /* work vectors inside the scatters */
8   VecScatter lrestrict, rprolong;
9   Mat        A;
10 } Mat_SubVirtual;
11 
12 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
13 {
14   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
15 
16   PetscFunctionBegin;
17   PetscCall(MatScale(Na->A, a));
18   PetscFunctionReturn(PETSC_SUCCESS);
19 }
20 
21 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
22 {
23   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
24 
25   PetscFunctionBegin;
26   PetscCall(MatShift(Na->A, a));
27   PetscFunctionReturn(PETSC_SUCCESS);
28 }
29 
30 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
31 {
32   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
33 
34   PetscFunctionBegin;
35   if (right) {
36     PetscCall(VecZeroEntries(Na->rwork));
37     PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
38     PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
39   }
40   if (left) {
41     PetscCall(VecZeroEntries(Na->lwork));
42     PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
43     PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
44   }
45   PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL));
46   PetscFunctionReturn(PETSC_SUCCESS);
47 }
48 
49 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
50 {
51   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
52 
53   PetscFunctionBegin;
54   PetscCall(MatGetDiagonal(Na->A, Na->rwork));
55   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
56   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
57   PetscFunctionReturn(PETSC_SUCCESS);
58 }
59 
60 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
61 {
62   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
63 
64   PetscFunctionBegin;
65   PetscCall(VecZeroEntries(Na->rwork));
66   PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
67   PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
68   PetscCall(MatMult(Na->A, Na->rwork, Na->lwork));
69   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
70   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
71   PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 
74 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
75 {
76   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
77 
78   PetscFunctionBegin;
79   PetscCall(VecZeroEntries(Na->rwork));
80   PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
81   PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
82   if (v1 == v2) {
83     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork));
84   } else if (v2 == v3) {
85     PetscCall(VecZeroEntries(Na->lwork));
86     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
87     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
88     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork));
89   } else {
90     if (!Na->lwork2) {
91       PetscCall(VecDuplicate(Na->lwork, &Na->lwork2));
92     } else {
93       PetscCall(VecZeroEntries(Na->lwork2));
94     }
95     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
96     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
97     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork));
98   }
99   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
100   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
101   PetscFunctionReturn(PETSC_SUCCESS);
102 }
103 
104 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
105 {
106   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
107 
108   PetscFunctionBegin;
109   PetscCall(VecZeroEntries(Na->lwork));
110   PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
111   PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
112   PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork));
113   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
114   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
115   PetscFunctionReturn(PETSC_SUCCESS);
116 }
117 
118 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
119 {
120   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
121 
122   PetscFunctionBegin;
123   PetscCall(VecZeroEntries(Na->lwork));
124   PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
125   PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
126   if (v1 == v2) {
127     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork));
128   } else if (v2 == v3) {
129     PetscCall(VecZeroEntries(Na->rwork));
130     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
131     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
132     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork));
133   } else {
134     if (!Na->rwork2) {
135       PetscCall(VecDuplicate(Na->rwork, &Na->rwork2));
136     } else {
137       PetscCall(VecZeroEntries(Na->rwork2));
138     }
139     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
140     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
141     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork));
142   }
143   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
144   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
145   PetscFunctionReturn(PETSC_SUCCESS);
146 }
147 
148 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
149 {
150   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
151 
152   PetscFunctionBegin;
153   PetscCall(ISDestroy(&Na->isrow));
154   PetscCall(ISDestroy(&Na->iscol));
155   PetscCall(VecDestroy(&Na->lwork));
156   PetscCall(VecDestroy(&Na->rwork));
157   PetscCall(VecDestroy(&Na->lwork2));
158   PetscCall(VecDestroy(&Na->rwork2));
159   PetscCall(VecScatterDestroy(&Na->lrestrict));
160   PetscCall(VecScatterDestroy(&Na->rprolong));
161   PetscCall(MatDestroy(&Na->A));
162   PetscCall(PetscFree(N->data));
163   PetscFunctionReturn(PETSC_SUCCESS);
164 }
165 
166 /*@
167   MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix
168 
169   Collective
170 
171   Input Parameters:
172 + A     - matrix that we will extract a submatrix of
173 . isrow - rows to be present in the submatrix
174 - iscol - columns to be present in the submatrix
175 
176   Output Parameter:
177 . newmat - new matrix
178 
179   Level: developer
180 
181   Note:
182   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
183 
184 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
185 @*/
186 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
187 {
188   Vec             left, right;
189   PetscInt        m, n;
190   Mat             N;
191   Mat_SubVirtual *Na;
192 
193   PetscFunctionBegin;
194   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
195   PetscValidHeaderSpecific(isrow, IS_CLASSID, 2);
196   PetscValidHeaderSpecific(iscol, IS_CLASSID, 3);
197   PetscValidPointer(newmat, 4);
198   *newmat = NULL;
199 
200   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N));
201   PetscCall(ISGetLocalSize(isrow, &m));
202   PetscCall(ISGetLocalSize(iscol, &n));
203   PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE));
204   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX));
205 
206   PetscCall(PetscNew(&Na));
207   N->data = (void *)Na;
208 
209   PetscCall(PetscObjectReference((PetscObject)isrow));
210   PetscCall(PetscObjectReference((PetscObject)iscol));
211   Na->isrow = isrow;
212   Na->iscol = iscol;
213 
214   PetscCall(PetscFree(N->defaultvectype));
215   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
216   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
217      the reference count of the context. This is a problem if A is already of type MATSHELL */
218   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
219 
220   N->ops->destroy          = MatDestroy_SubMatrix;
221   N->ops->mult             = MatMult_SubMatrix;
222   N->ops->multadd          = MatMultAdd_SubMatrix;
223   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
224   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
225   N->ops->scale            = MatScale_SubMatrix;
226   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
227   N->ops->shift            = MatShift_SubMatrix;
228   N->ops->convert          = MatConvert_Shell;
229   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;
230 
231   PetscCall(MatSetBlockSizesFromMats(N, A, A));
232   PetscCall(PetscLayoutSetUp(N->rmap));
233   PetscCall(PetscLayoutSetUp(N->cmap));
234 
235   PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork));
236   PetscCall(MatCreateVecs(N, &right, &left));
237   PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict));
238   PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong));
239   PetscCall(VecDestroy(&left));
240   PetscCall(VecDestroy(&right));
241   PetscCall(MatSetUp(N));
242 
243   N->assembled = PETSC_TRUE;
244   *newmat      = N;
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
248 /*MC
249    MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix
250 
251   Level: advanced
252 
253    Developer Note:
   The `MatType` is `MATSUBMATRIX`, but the associated routines have `SubMatrixVirtual` in their names; the `MatType` name should likely be changed to
   `MATSUBMATRIXVIRTUAL`.
256 
.seealso: [](ch_matrices), `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
258 M*/
259 
260 /*@
261   MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix
262 
263   Collective
264 
265   Input Parameters:
266 + N     - submatrix to update
267 . A     - full matrix in the submatrix
268 . isrow - rows in the update (same as the first time the submatrix was created)
269 - iscol - columns in the update (same as the first time the submatrix was created)
270 
271   Level: developer
272 
273   Note:
274   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
275 
276 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
277 @*/
278 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
279 {
280   PetscBool       flg;
281   Mat_SubVirtual *Na;
282 
283   PetscFunctionBegin;
284   PetscValidHeaderSpecific(N, MAT_CLASSID, 1);
285   PetscValidHeaderSpecific(A, MAT_CLASSID, 2);
286   PetscValidHeaderSpecific(isrow, IS_CLASSID, 3);
287   PetscValidHeaderSpecific(iscol, IS_CLASSID, 4);
288   PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg));
289   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type");
290 
291   Na = (Mat_SubVirtual *)N->data;
292   PetscCall(ISEqual(isrow, Na->isrow, &flg));
293   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices");
294   PetscCall(ISEqual(iscol, Na->iscol, &flg));
295   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices");
296 
297   PetscCall(PetscFree(N->defaultvectype));
298   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
299   PetscCall(MatDestroy(&Na->A));
300   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
301      the reference count of the context. This is a problem if A is already of type MATSHELL */
302   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
303   PetscFunctionReturn(PETSC_SUCCESS);
304 }
305