xref: /petsc/src/mat/impls/submat/submat.c (revision 2d30e087755efd99e28fdfe792ffbeb2ee1ea928)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
6   Vec        lwork, rwork;   /* work vectors inside the scatters */
7   Vec        lwork2, rwork2; /* work vectors inside the scatters */
8   VecScatter lrestrict, rprolong;
9   Mat        A;
10 } Mat_SubVirtual;
11 
12 static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a) {
13   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatScale(Na->A, a));
17   PetscFunctionReturn(0);
18 }
19 
20 static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a) {
21   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
22 
23   PetscFunctionBegin;
24   PetscCall(MatShift(Na->A, a));
25   PetscFunctionReturn(0);
26 }
27 
28 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right) {
29   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
30 
31   PetscFunctionBegin;
32   if (right) {
33     PetscCall(VecZeroEntries(Na->rwork));
34     PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
35     PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
36   }
37   if (left) {
38     PetscCall(VecZeroEntries(Na->lwork));
39     PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
40     PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
41   }
42   PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL));
43   PetscFunctionReturn(0);
44 }
45 
46 static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d) {
47   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
48 
49   PetscFunctionBegin;
50   PetscCall(MatGetDiagonal(Na->A, Na->rwork));
51   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
52   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
53   PetscFunctionReturn(0);
54 }
55 
56 static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y) {
57   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
58 
59   PetscFunctionBegin;
60   PetscCall(VecZeroEntries(Na->rwork));
61   PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
62   PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
63   PetscCall(MatMult(Na->A, Na->rwork, Na->lwork));
64   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
65   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
66   PetscFunctionReturn(0);
67 }
68 
69 static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) {
70   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
71 
72   PetscFunctionBegin;
73   PetscCall(VecZeroEntries(Na->rwork));
74   PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
75   PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
76   if (v1 == v2) {
77     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork));
78   } else if (v2 == v3) {
79     PetscCall(VecZeroEntries(Na->lwork));
80     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
81     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
82     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork));
83   } else {
84     if (!Na->lwork2) {
85       PetscCall(VecDuplicate(Na->lwork, &Na->lwork2));
86     } else {
87       PetscCall(VecZeroEntries(Na->lwork2));
88     }
89     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
90     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
91     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork));
92   }
93   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
94   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
95   PetscFunctionReturn(0);
96 }
97 
98 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y) {
99   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
100 
101   PetscFunctionBegin;
102   PetscCall(VecZeroEntries(Na->lwork));
103   PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
104   PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
105   PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork));
106   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
107   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
108   PetscFunctionReturn(0);
109 }
110 
111 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3) {
112   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
113 
114   PetscFunctionBegin;
115   PetscCall(VecZeroEntries(Na->lwork));
116   PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
117   PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
118   if (v1 == v2) {
119     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork));
120   } else if (v2 == v3) {
121     PetscCall(VecZeroEntries(Na->rwork));
122     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
123     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
124     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork));
125   } else {
126     if (!Na->rwork2) {
127       PetscCall(VecDuplicate(Na->rwork, &Na->rwork2));
128     } else {
129       PetscCall(VecZeroEntries(Na->rwork2));
130     }
131     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
132     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
133     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork));
134   }
135   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
136   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
137   PetscFunctionReturn(0);
138 }
139 
140 static PetscErrorCode MatDestroy_SubMatrix(Mat N) {
141   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
142 
143   PetscFunctionBegin;
144   PetscCall(ISDestroy(&Na->isrow));
145   PetscCall(ISDestroy(&Na->iscol));
146   PetscCall(VecDestroy(&Na->lwork));
147   PetscCall(VecDestroy(&Na->rwork));
148   PetscCall(VecDestroy(&Na->lwork2));
149   PetscCall(VecDestroy(&Na->rwork2));
150   PetscCall(VecScatterDestroy(&Na->lrestrict));
151   PetscCall(VecScatterDestroy(&Na->rprolong));
152   PetscCall(MatDestroy(&Na->A));
153   PetscCall(PetscFree(N->data));
154   PetscFunctionReturn(0);
155 }
156 
157 /*@
158    MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix
159 
160    Collective on A
161 
162    Input Parameters:
163 +  A - matrix that we will extract a submatrix of
164 .  isrow - rows to be present in the submatrix
165 -  iscol - columns to be present in the submatrix
166 
167    Output Parameters:
168 .  newmat - new matrix
169 
170    Level: developer
171 
172    Note:
173    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
174 
175    Developer Note:
176    The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` should likely be changed
177 
178 .seealso: `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
179 @*/
180 PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat) {
181   Vec             left, right;
182   PetscInt        m, n;
183   Mat             N;
184   Mat_SubVirtual *Na;
185 
186   PetscFunctionBegin;
187   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
188   PetscValidHeaderSpecific(isrow, IS_CLASSID, 2);
189   PetscValidHeaderSpecific(iscol, IS_CLASSID, 3);
190   PetscValidPointer(newmat, 4);
191   *newmat = NULL;
192 
193   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N));
194   PetscCall(ISGetLocalSize(isrow, &m));
195   PetscCall(ISGetLocalSize(iscol, &n));
196   PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE));
197   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX));
198 
199   PetscCall(PetscNewLog(N, &Na));
200   N->data = (void *)Na;
201 
202   PetscCall(PetscObjectReference((PetscObject)isrow));
203   PetscCall(PetscObjectReference((PetscObject)iscol));
204   Na->isrow = isrow;
205   Na->iscol = iscol;
206 
207   PetscCall(PetscFree(N->defaultvectype));
208   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
209   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
210      the reference count of the context. This is a problem if A is already of type MATSHELL */
211   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
212 
213   N->ops->destroy          = MatDestroy_SubMatrix;
214   N->ops->mult             = MatMult_SubMatrix;
215   N->ops->multadd          = MatMultAdd_SubMatrix;
216   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
217   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
218   N->ops->scale            = MatScale_SubMatrix;
219   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
220   N->ops->shift            = MatShift_SubMatrix;
221   N->ops->convert          = MatConvert_Shell;
222   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;
223 
224   PetscCall(MatSetBlockSizesFromMats(N, A, A));
225   PetscCall(PetscLayoutSetUp(N->rmap));
226   PetscCall(PetscLayoutSetUp(N->cmap));
227 
228   PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork));
229   PetscCall(MatCreateVecs(N, &right, &left));
230   PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict));
231   PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong));
232   PetscCall(VecDestroy(&left));
233   PetscCall(VecDestroy(&right));
234   PetscCall(MatSetUp(N));
235 
236   N->assembled = PETSC_TRUE;
237   *newmat      = N;
238   PetscFunctionReturn(0);
239 }
240 
241 /*MC
242    MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix
243 
244   Level: advanced
245 
246   Developer Note:
247   This should be named `MATSUBMATRIXVIRTUAL`
248 
.seealso: `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
250 M*/
251 
252 /*@
253    MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix
254 
255    Collective on N
256 
257    Input Parameters:
258 +  N - submatrix to update
259 .  A - full matrix in the submatrix
260 .  isrow - rows in the update (same as the first time the submatrix was created)
261 -  iscol - columns in the update (same as the first time the submatrix was created)
262 
263    Level: developer
264 
265    Note:
266    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
267 
268 .seealso: MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
269 @*/
270 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol) {
271   PetscBool       flg;
272   Mat_SubVirtual *Na;
273 
274   PetscFunctionBegin;
275   PetscValidHeaderSpecific(N, MAT_CLASSID, 1);
276   PetscValidHeaderSpecific(A, MAT_CLASSID, 2);
277   PetscValidHeaderSpecific(isrow, IS_CLASSID, 3);
278   PetscValidHeaderSpecific(iscol, IS_CLASSID, 4);
279   PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg));
280   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type");
281 
282   Na = (Mat_SubVirtual *)N->data;
283   PetscCall(ISEqual(isrow, Na->isrow, &flg));
284   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices");
285   PetscCall(ISEqual(iscol, Na->iscol, &flg));
286   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices");
287 
288   PetscCall(PetscFree(N->defaultvectype));
289   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
290   PetscCall(MatDestroy(&Na->A));
291   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
292      the reference count of the context. This is a problem if A is already of type MATSHELL */
293   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
294   PetscFunctionReturn(0);
295 }
296