xref: /petsc/src/mat/impls/scatter/mscatter.c (revision e91c04dfc8a52dee1965211bb1cc8e5bf775178f)
1 /*
2    This provides a matrix that applies a VecScatter to a vector.
3 */
4 
5 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
6 #include <petsc/private/vecimpl.h>
7 
/*
  Private data of a MATSCATTER matrix, stored in mat->data.
*/
8 typedef struct {
9   VecScatter scatter; /* applied forward by MatMult/MatMultAdd and in reverse by MatMultTranspose/MatMultTransposeAdd; NULL until MatScatterSetVecScatter() is called */
10 } Mat_Scatter;
11 
12 /*@
13   MatScatterGetVecScatter - Returns the user-provided scatter set with `MatScatterSetVecScatter()` in a `MATSCATTER` matrix
14 
15   Logically Collective
16 
17   Input Parameter:
18 . mat - the matrix, should have been created with MatCreateScatter() or have type `MATSCATTER`
19 
20   Output Parameter:
21 . scatter - the scatter context
22 
23   Level: intermediate
24 
25 .seealso: [](ch_matrices), `Mat`, `MATSCATTER`, `MatCreateScatter()`, `MatScatterSetVecScatter()`
26 @*/
27 PetscErrorCode MatScatterGetVecScatter(Mat mat, VecScatter *scatter)
28 {
29   Mat_Scatter *mscatter;
30 
31   PetscFunctionBegin;
32   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
33   PetscAssertPointer(scatter, 2);
34   mscatter = (Mat_Scatter *)mat->data;
35   *scatter = mscatter->scatter;
36   PetscFunctionReturn(PETSC_SUCCESS);
37 }
38 
39 static PetscErrorCode MatDestroy_Scatter(Mat mat)
40 {
41   Mat_Scatter *scatter = (Mat_Scatter *)mat->data;
42 
43   PetscFunctionBegin;
44   PetscCall(VecScatterDestroy(&scatter->scatter));
45   PetscCall(PetscFree(mat->data));
46   PetscFunctionReturn(PETSC_SUCCESS);
47 }
48 
49 static PetscErrorCode MatMult_Scatter(Mat A, Vec x, Vec y)
50 {
51   Mat_Scatter *scatter = (Mat_Scatter *)A->data;
52 
53   PetscFunctionBegin;
54   PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetScatter()");
55   PetscCall(VecZeroEntries(y));
56   PetscCall(VecScatterBegin(scatter->scatter, x, y, ADD_VALUES, SCATTER_FORWARD));
57   PetscCall(VecScatterEnd(scatter->scatter, x, y, ADD_VALUES, SCATTER_FORWARD));
58   PetscFunctionReturn(PETSC_SUCCESS);
59 }
60 
61 static PetscErrorCode MatMultAdd_Scatter(Mat A, Vec x, Vec y, Vec z)
62 {
63   Mat_Scatter *scatter = (Mat_Scatter *)A->data;
64 
65   PetscFunctionBegin;
66   PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetScatter()");
67   if (z != y) PetscCall(VecCopy(y, z));
68   PetscCall(VecScatterBegin(scatter->scatter, x, z, ADD_VALUES, SCATTER_FORWARD));
69   PetscCall(VecScatterEnd(scatter->scatter, x, z, ADD_VALUES, SCATTER_FORWARD));
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode MatMultTranspose_Scatter(Mat A, Vec x, Vec y)
74 {
75   Mat_Scatter *scatter = (Mat_Scatter *)A->data;
76 
77   PetscFunctionBegin;
78   PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetScatter()");
79   PetscCall(VecZeroEntries(y));
80   PetscCall(VecScatterBegin(scatter->scatter, x, y, ADD_VALUES, SCATTER_REVERSE));
81   PetscCall(VecScatterEnd(scatter->scatter, x, y, ADD_VALUES, SCATTER_REVERSE));
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
85 static PetscErrorCode MatMultTransposeAdd_Scatter(Mat A, Vec x, Vec y, Vec z)
86 {
87   Mat_Scatter *scatter = (Mat_Scatter *)A->data;
88 
89   PetscFunctionBegin;
90   PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetScatter()");
91   if (z != y) PetscCall(VecCopy(y, z));
92   PetscCall(VecScatterBegin(scatter->scatter, x, z, ADD_VALUES, SCATTER_REVERSE));
93   PetscCall(VecScatterEnd(scatter->scatter, x, z, ADD_VALUES, SCATTER_REVERSE));
94   PetscFunctionReturn(PETSC_SUCCESS);
95 }
96 
/*
  Operation table for MATSCATTER, copied wholesale into A->ops by MatCreate_Scatter().

  Entries are positional: they must line up with the member order of struct _MatOps,
  and each slot-number comment gives the index of the entry that follows it.
  The only non-NULL slots are the multiply family (3-6: mult, multadd, multtranspose,
  multtransposeadd), MatShift_Basic (46), and MatDestroy_Scatter (60). Every other
  operation is left NULL; presumably the generic Mat interface reports these as
  unsupported for this type (verify against the PetscUseTypeMethod callers).
*/
97 static struct _MatOps MatOps_Values = {NULL,
98                                        NULL,
99                                        NULL,
100                                        MatMult_Scatter,
101                                        /*  4*/ MatMultAdd_Scatter,
102                                        MatMultTranspose_Scatter,
103                                        MatMultTransposeAdd_Scatter,
104                                        NULL,
105                                        NULL,
106                                        NULL,
107                                        /* 10*/ NULL,
108                                        NULL,
109                                        NULL,
110                                        NULL,
111                                        NULL,
112                                        /* 15*/ NULL,
113                                        NULL,
114                                        NULL,
115                                        NULL,
116                                        NULL,
117                                        /* 20*/ NULL,
118                                        NULL,
119                                        NULL,
120                                        NULL,
121                                        /* 24*/ NULL,
122                                        NULL,
123                                        NULL,
124                                        NULL,
125                                        NULL,
126                                        /* 29*/ NULL,
127                                        NULL,
128                                        NULL,
129                                        NULL,
130                                        NULL,
131                                        /* 34*/ NULL,
132                                        NULL,
133                                        NULL,
134                                        NULL,
135                                        NULL,
136                                        /* 39*/ NULL,
137                                        NULL,
138                                        NULL,
139                                        NULL,
140                                        NULL,
141                                        /* 44*/ NULL,
142                                        NULL,
143                                        MatShift_Basic,
144                                        NULL,
145                                        NULL,
146                                        /* 49*/ NULL,
147                                        NULL,
148                                        NULL,
149                                        NULL,
150                                        NULL,
151                                        /* 54*/ NULL,
152                                        NULL,
153                                        NULL,
154                                        NULL,
155                                        NULL,
156                                        /* 59*/ NULL,
157                                        MatDestroy_Scatter,
158                                        NULL,
159                                        NULL,
160                                        NULL,
161                                        /* 64*/ NULL,
162                                        NULL,
163                                        NULL,
164                                        NULL,
165                                        NULL,
166                                        /* 69*/ NULL,
167                                        NULL,
168                                        NULL,
169                                        NULL,
170                                        NULL,
171                                        /* 74*/ NULL,
172                                        NULL,
173                                        NULL,
174                                        NULL,
175                                        NULL,
176                                        /* 79*/ NULL,
177                                        NULL,
178                                        NULL,
179                                        NULL,
180                                        NULL,
181                                        /* 84*/ NULL,
182                                        NULL,
183                                        NULL,
184                                        NULL,
185                                        NULL,
186                                        /* 89*/ NULL,
187                                        NULL,
188                                        NULL,
189                                        NULL,
190                                        NULL,
191                                        /* 94*/ NULL,
192                                        NULL,
193                                        NULL,
194                                        NULL,
195                                        NULL,
196                                        /*99*/ NULL,
197                                        NULL,
198                                        NULL,
199                                        NULL,
200                                        NULL,
201                                        /*104*/ NULL,
202                                        NULL,
203                                        NULL,
204                                        NULL,
205                                        NULL,
206                                        /*109*/ NULL,
207                                        NULL,
208                                        NULL,
209                                        NULL,
210                                        NULL,
211                                        /*114*/ NULL,
212                                        NULL,
213                                        NULL,
214                                        NULL,
215                                        NULL,
216                                        /*119*/ NULL,
217                                        NULL,
218                                        NULL,
219                                        NULL,
220                                        NULL,
221                                        /*124*/ NULL,
222                                        NULL,
223                                        NULL,
224                                        NULL,
225                                        NULL,
226                                        /*129*/ NULL,
227                                        NULL,
228                                        NULL,
229                                        NULL,
230                                        NULL,
231                                        /*134*/ NULL,
232                                        NULL,
233                                        NULL,
234                                        NULL,
235                                        NULL,
236                                        /*139*/ NULL,
237                                        NULL,
238                                        NULL,
239                                        NULL,
240                                        NULL,
241                                        /*144*/ NULL,
242                                        NULL,
243                                        NULL,
244                                        NULL,
245                                        NULL,
246                                        NULL,
247                                        /*150*/ NULL,
248                                        NULL,
249                                        NULL,
250                                        NULL,
251                                        NULL,
252                                        /*155*/ NULL,
253                                        NULL};
254 
255 /*MC
256    MATSCATTER - "scatter" - A matrix type that simply applies a `VecScatterBegin()` and `VecScatterEnd()` to perform `MatMult()` (and the reverse scatter to perform `MatMultTranspose()`)
257 
258   Level: advanced
259 
260 .seealso: [](ch_matrices), `Mat`, `MATSCATTER`, `MatCreateScatter()`, `MatScatterSetVecScatter()`, `MatScatterGetVecScatter()`
261 M*/
262 
263 PETSC_EXTERN PetscErrorCode MatCreate_Scatter(Mat A)
264 {
265   Mat_Scatter *b;
266 
267   PetscFunctionBegin;
268   A->ops[0] = MatOps_Values;
269   PetscCall(PetscNew(&b));
270 
271   A->data = (void *)b;
272 
273   PetscCall(PetscLayoutSetUp(A->rmap));
274   PetscCall(PetscLayoutSetUp(A->cmap));
275 
276   A->assembled    = PETSC_TRUE;
277   A->preallocated = PETSC_FALSE;
278 
279   PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSCATTER));
280   PetscFunctionReturn(PETSC_SUCCESS);
281 }
282 
283 #include <petsc/private/sfimpl.h>
284 /*@
285   MatCreateScatter - Creates a new matrix of `MatType` `MATSCATTER`, based on a VecScatter
286 
287   Collective
288 
289   Input Parameters:
290 + comm    - MPI communicator
291 - scatter - a `VecScatter`
292 
293   Output Parameter:
294 . A - the matrix
295 
296   Level: intermediate
297 
298   Notes:
299   PETSc requires that matrices and vectors being used for certain
300   operations are partitioned accordingly.  For example, when
301   creating a scatter matrix, A, that supports parallel matrix-vector
302   products using `MatMult`(A,x,y) the user should set the number
303   of local matrix rows to be the number of local elements of the
304   corresponding result vector, y. Note that this is information is
305   required for use of the matrix interface routines, even though
306   the scatter matrix may not actually be physically partitioned.
307 
308   Developer Notes:
309   This directly accesses information inside the `VecScatter` associated with the matrix-vector product
310   for this matrix. This is not desirable..
311 
312 .seealso: [](ch_matrices), `Mat`, `MatScatterSetVecScatter()`, `MatScatterGetVecScatter()`, `MATSCATTER`
313 @*/
314 PetscErrorCode MatCreateScatter(MPI_Comm comm, VecScatter scatter, Mat *A)
315 {
316   PetscFunctionBegin;
317   PetscCall(MatCreate(comm, A));
318   PetscCall(MatSetSizes(*A, scatter->vscat.to_n, scatter->vscat.from_n, PETSC_DETERMINE, PETSC_DETERMINE));
319   PetscCall(MatSetType(*A, MATSCATTER));
320   PetscCall(MatScatterSetVecScatter(*A, scatter));
321   PetscCall(MatSetUp(*A));
322   PetscFunctionReturn(PETSC_SUCCESS);
323 }
324 
325 /*@
326   MatScatterSetVecScatter - sets the scatter that the matrix is to apply as its linear operator in a `MATSCATTER`
327 
328   Logically Collective
329 
330   Input Parameters:
331 + mat     - the `MATSCATTER` matrix
332 - scatter - the scatter context create with `VecScatterCreate()`
333 
334   Level: advanced
335 
336 .seealso: [](ch_matrices), `Mat`, `MATSCATTER`, `MatCreateScatter()`
337 @*/
338 PetscErrorCode MatScatterSetVecScatter(Mat mat, VecScatter scatter)
339 {
340   Mat_Scatter *mscatter = (Mat_Scatter *)mat->data;
341 
342   PetscFunctionBegin;
343   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
344   PetscValidHeaderSpecific(scatter, PETSCSF_CLASSID, 2);
345   PetscCheckSameComm(scatter, 2, mat, 1);
346   PetscCheck(mat->rmap->n == scatter->vscat.to_n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Number of local rows in matrix %" PetscInt_FMT " not equal local scatter size %" PetscInt_FMT, mat->rmap->n, scatter->vscat.to_n);
347   PetscCheck(mat->cmap->n == scatter->vscat.from_n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Number of local columns in matrix %" PetscInt_FMT " not equal local scatter size %" PetscInt_FMT, mat->cmap->n, scatter->vscat.from_n);
348 
349   PetscCall(PetscObjectReference((PetscObject)scatter));
350   PetscCall(VecScatterDestroy(&mscatter->scatter));
351 
352   mscatter->scatter = scatter;
353   PetscFunctionReturn(PETSC_SUCCESS);
354 }
355