xref: /petsc/src/mat/impls/submat/submat.c (revision 98d129c30f3ee9fdddc40fdbc5a989b7be64f888)
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 
3 typedef struct {
4   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
5   Vec        lwork, rwork;   /* work vectors inside the scatters */
6   Vec        lwork2, rwork2; /* work vectors inside the scatters */
7   VecScatter lrestrict, rprolong;
8   Mat        A;
9 } Mat_SubVirtual;
10 
11 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
12 {
13   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatScale(Na->A, a));
17   PetscFunctionReturn(PETSC_SUCCESS);
18 }
19 
20 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
21 {
22   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
23 
24   PetscFunctionBegin;
25   PetscCall(MatShift(Na->A, a));
26   PetscFunctionReturn(PETSC_SUCCESS);
27 }
28 
29 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
30 {
31   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
32 
33   PetscFunctionBegin;
34   if (right) {
35     PetscCall(VecZeroEntries(Na->rwork));
36     PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
37     PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
38   }
39   if (left) {
40     PetscCall(VecZeroEntries(Na->lwork));
41     PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
42     PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
43   }
44   PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL));
45   PetscFunctionReturn(PETSC_SUCCESS);
46 }
47 
48 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
49 {
50   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
51 
52   PetscFunctionBegin;
53   PetscCall(MatGetDiagonal(Na->A, Na->rwork));
54   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
55   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
56   PetscFunctionReturn(PETSC_SUCCESS);
57 }
58 
59 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
60 {
61   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
62 
63   PetscFunctionBegin;
64   PetscCall(VecZeroEntries(Na->rwork));
65   PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
66   PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
67   PetscCall(MatMult(Na->A, Na->rwork, Na->lwork));
68   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
69   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
74 {
75   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
76 
77   PetscFunctionBegin;
78   PetscCall(VecZeroEntries(Na->rwork));
79   PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
80   PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
81   if (v1 == v2) {
82     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork));
83   } else if (v2 == v3) {
84     PetscCall(VecZeroEntries(Na->lwork));
85     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
86     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
87     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork));
88   } else {
89     if (!Na->lwork2) {
90       PetscCall(VecDuplicate(Na->lwork, &Na->lwork2));
91     } else {
92       PetscCall(VecZeroEntries(Na->lwork2));
93     }
94     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
95     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
96     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork));
97   }
98   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
99   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
100   PetscFunctionReturn(PETSC_SUCCESS);
101 }
102 
103 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
104 {
105   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
106 
107   PetscFunctionBegin;
108   PetscCall(VecZeroEntries(Na->lwork));
109   PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
110   PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
111   PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork));
112   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
113   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
114   PetscFunctionReturn(PETSC_SUCCESS);
115 }
116 
117 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
118 {
119   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
120 
121   PetscFunctionBegin;
122   PetscCall(VecZeroEntries(Na->lwork));
123   PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
124   PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
125   if (v1 == v2) {
126     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork));
127   } else if (v2 == v3) {
128     PetscCall(VecZeroEntries(Na->rwork));
129     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
130     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
131     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork));
132   } else {
133     if (!Na->rwork2) {
134       PetscCall(VecDuplicate(Na->rwork, &Na->rwork2));
135     } else {
136       PetscCall(VecZeroEntries(Na->rwork2));
137     }
138     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
139     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
140     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork));
141   }
142   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
143   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
144   PetscFunctionReturn(PETSC_SUCCESS);
145 }
146 
147 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
148 {
149   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
150 
151   PetscFunctionBegin;
152   PetscCall(ISDestroy(&Na->isrow));
153   PetscCall(ISDestroy(&Na->iscol));
154   PetscCall(VecDestroy(&Na->lwork));
155   PetscCall(VecDestroy(&Na->rwork));
156   PetscCall(VecDestroy(&Na->lwork2));
157   PetscCall(VecDestroy(&Na->rwork2));
158   PetscCall(VecScatterDestroy(&Na->lrestrict));
159   PetscCall(VecScatterDestroy(&Na->rprolong));
160   PetscCall(MatDestroy(&Na->A));
161   PetscCall(PetscFree(N->data));
162   PetscFunctionReturn(PETSC_SUCCESS);
163 }
164 
165 /*@
166   MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix
167 
168   Collective
169 
170   Input Parameters:
171 + A     - matrix that we will extract a submatrix of
172 . isrow - rows to be present in the submatrix
173 - iscol - columns to be present in the submatrix
174 
175   Output Parameter:
176 . newmat - new matrix
177 
178   Level: developer
179 
180   Note:
181   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
182 
183 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
184 @*/
185 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
186 {
187   Vec             left, right;
188   PetscInt        m, n;
189   Mat             N;
190   Mat_SubVirtual *Na;
191 
192   PetscFunctionBegin;
193   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
194   PetscValidHeaderSpecific(isrow, IS_CLASSID, 2);
195   PetscValidHeaderSpecific(iscol, IS_CLASSID, 3);
196   PetscAssertPointer(newmat, 4);
197   *newmat = NULL;
198 
199   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N));
200   PetscCall(ISGetLocalSize(isrow, &m));
201   PetscCall(ISGetLocalSize(iscol, &n));
202   PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE));
203   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX));
204 
205   PetscCall(PetscNew(&Na));
206   N->data = (void *)Na;
207 
208   PetscCall(PetscObjectReference((PetscObject)isrow));
209   PetscCall(PetscObjectReference((PetscObject)iscol));
210   Na->isrow = isrow;
211   Na->iscol = iscol;
212 
213   PetscCall(PetscFree(N->defaultvectype));
214   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
215   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
216      the reference count of the context. This is a problem if A is already of type MATSHELL */
217   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
218 
219   N->ops->destroy          = MatDestroy_SubMatrix;
220   N->ops->mult             = MatMult_SubMatrix;
221   N->ops->multadd          = MatMultAdd_SubMatrix;
222   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
223   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
224   N->ops->scale            = MatScale_SubMatrix;
225   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
226   N->ops->shift            = MatShift_SubMatrix;
227   N->ops->convert          = MatConvert_Shell;
228   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;
229 
230   PetscCall(MatSetBlockSizesFromMats(N, A, A));
231   PetscCall(PetscLayoutSetUp(N->rmap));
232   PetscCall(PetscLayoutSetUp(N->cmap));
233 
234   PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork));
235   PetscCall(MatCreateVecs(N, &right, &left));
236   PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict));
237   PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong));
238   PetscCall(VecDestroy(&left));
239   PetscCall(VecDestroy(&right));
240   PetscCall(MatSetUp(N));
241 
242   N->assembled = PETSC_TRUE;
243   *newmat      = N;
244   PetscFunctionReturn(PETSC_SUCCESS);
245 }
246 
247 /*MC
248    MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix
249 
250   Level: advanced
251 
252    Developer Note:
253    The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` name should likely be changed to
254    `MATSUBMATRIXVIRTUAL`
255 
.seealso: [](ch_matrices), `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
257 M*/
258 
259 /*@
260   MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix
261 
262   Collective
263 
264   Input Parameters:
265 + N     - submatrix to update
266 . A     - full matrix in the submatrix
267 . isrow - rows in the update (same as the first time the submatrix was created)
268 - iscol - columns in the update (same as the first time the submatrix was created)
269 
270   Level: developer
271 
272   Note:
273   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
274 
275 .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
276 @*/
277 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
278 {
279   PetscBool       flg;
280   Mat_SubVirtual *Na;
281 
282   PetscFunctionBegin;
283   PetscValidHeaderSpecific(N, MAT_CLASSID, 1);
284   PetscValidHeaderSpecific(A, MAT_CLASSID, 2);
285   PetscValidHeaderSpecific(isrow, IS_CLASSID, 3);
286   PetscValidHeaderSpecific(iscol, IS_CLASSID, 4);
287   PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg));
288   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type");
289 
290   Na = (Mat_SubVirtual *)N->data;
291   PetscCall(ISEqual(isrow, Na->isrow, &flg));
292   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices");
293   PetscCall(ISEqual(iscol, Na->iscol, &flg));
294   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices");
295 
296   PetscCall(PetscFree(N->defaultvectype));
297   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
298   PetscCall(MatDestroy(&Na->A));
299   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
300      the reference count of the context. This is a problem if A is already of type MATSHELL */
301   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
302   PetscFunctionReturn(PETSC_SUCCESS);
303 }
304