xref: /petsc/src/mat/impls/scatter/mscatter.c (revision e6e75211d226c622f451867f53ce5d558649ff4f)
1 
2 /*
3    This provides a matrix that applies a VecScatter to a vector.
4 */
5 
6 #include <petsc/private/matimpl.h>        /*I "petscmat.h" I*/
7 #include <petsc/private/vecimpl.h>
8 
9 typedef struct {
10   VecScatter scatter;
11 } Mat_Scatter;
12 
13 #undef __FUNCT__
14 #define __FUNCT__ "MatScatterGetVecScatter"
15 /*@
16     MatScatterGetVecScatter - Returns the user-provided scatter set with MatScatterSetVecScatter()
17 
18     Not Collective, but not cannot use scatter if not used collectively on Mat
19 
20     Input Parameter:
21 .   mat - the matrix, should have been created with MatCreateScatter() or have type MATSCATTER
22 
23     Output Parameter:
24 .   scatter - the scatter context
25 
26     Level: intermediate
27 
28 .keywords: matrix, scatter, get
29 
30 .seealso: MatCreateScatter(), MatScatterSetVecScatter(), MATSCATTER
31 @*/
32 PetscErrorCode  MatScatterGetVecScatter(Mat mat,VecScatter *scatter)
33 {
34   Mat_Scatter *mscatter;
35 
36   PetscFunctionBegin;
37   PetscValidHeaderSpecific(mat,MAT_CLASSID,1);
38   PetscValidPointer(scatter,2);
39   mscatter = (Mat_Scatter*)mat->data;
40   *scatter = mscatter->scatter;
41   PetscFunctionReturn(0);
42 }
43 
44 #undef __FUNCT__
45 #define __FUNCT__ "MatDestroy_Scatter"
46 PetscErrorCode MatDestroy_Scatter(Mat mat)
47 {
48   PetscErrorCode ierr;
49   Mat_Scatter    *scatter = (Mat_Scatter*)mat->data;
50 
51   PetscFunctionBegin;
52   ierr = VecScatterDestroy(&scatter->scatter);CHKERRQ(ierr);
53   ierr = PetscFree(mat->data);CHKERRQ(ierr);
54   PetscFunctionReturn(0);
55 }
56 
57 #undef __FUNCT__
58 #define __FUNCT__ "MatMult_Scatter"
59 PetscErrorCode MatMult_Scatter(Mat A,Vec x,Vec y)
60 {
61   Mat_Scatter    *scatter = (Mat_Scatter*)A->data;
62   PetscErrorCode ierr;
63 
64   PetscFunctionBegin;
65   if (!scatter->scatter) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONGSTATE,"Need to first call MatScatterSetScatter()");
66   ierr = VecZeroEntries(y);CHKERRQ(ierr);
67   ierr = VecScatterBegin(scatter->scatter,x,y,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
68   ierr = VecScatterEnd(scatter->scatter,x,y,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
69   PetscFunctionReturn(0);
70 }
71 
72 #undef __FUNCT__
73 #define __FUNCT__ "MatMultAdd_Scatter"
74 PetscErrorCode MatMultAdd_Scatter(Mat A,Vec x,Vec y,Vec z)
75 {
76   Mat_Scatter    *scatter = (Mat_Scatter*)A->data;
77   PetscErrorCode ierr;
78 
79   PetscFunctionBegin;
80   if (!scatter->scatter) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONGSTATE,"Need to first call MatScatterSetScatter()");
81   if (z != y) {ierr = VecCopy(y,z);CHKERRQ(ierr);}
82   ierr = VecScatterBegin(scatter->scatter,x,z,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
83   ierr = VecScatterEnd(scatter->scatter,x,z,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
84   PetscFunctionReturn(0);
85 }
86 
87 #undef __FUNCT__
88 #define __FUNCT__ "MatMultTranspose_Scatter"
89 PetscErrorCode MatMultTranspose_Scatter(Mat A,Vec x,Vec y)
90 {
91   Mat_Scatter    *scatter = (Mat_Scatter*)A->data;
92   PetscErrorCode ierr;
93 
94   PetscFunctionBegin;
95   if (!scatter->scatter) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONGSTATE,"Need to first call MatScatterSetScatter()");
96   ierr = VecZeroEntries(y);CHKERRQ(ierr);
97   ierr = VecScatterBegin(scatter->scatter,x,y,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
98   ierr = VecScatterEnd(scatter->scatter,x,y,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
99   PetscFunctionReturn(0);
100 }
101 
102 #undef __FUNCT__
103 #define __FUNCT__ "MatMultTransposeAdd_Scatter"
104 PetscErrorCode MatMultTransposeAdd_Scatter(Mat A,Vec x,Vec y,Vec z)
105 {
106   Mat_Scatter    *scatter = (Mat_Scatter*)A->data;
107   PetscErrorCode ierr;
108 
109   PetscFunctionBegin;
110   if (!scatter->scatter) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONGSTATE,"Need to first call MatScatterSetScatter()");
111   if (z != y) {ierr = VecCopy(y,z);CHKERRQ(ierr);}
112   ierr = VecScatterBegin(scatter->scatter,x,z,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
113   ierr = VecScatterEnd(scatter->scatter,x,z,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
114   PetscFunctionReturn(0);
115 }
116 
117 static struct _MatOps MatOps_Values = {0,
118                                        0,
119                                        0,
120                                        MatMult_Scatter,
121                                /*  4*/ MatMultAdd_Scatter,
122                                        MatMultTranspose_Scatter,
123                                        MatMultTransposeAdd_Scatter,
124                                        0,
125                                        0,
126                                        0,
127                                /* 10*/ 0,
128                                        0,
129                                        0,
130                                        0,
131                                        0,
132                                /* 15*/ 0,
133                                        0,
134                                        0,
135                                        0,
136                                        0,
137                                /* 20*/ 0,
138                                        0,
139                                        0,
140                                        0,
141                                /* 24*/ 0,
142                                        0,
143                                        0,
144                                        0,
145                                        0,
146                                /* 29*/ 0,
147                                        0,
148                                        0,
149                                        0,
150                                        0,
151                                /* 34*/ 0,
152                                        0,
153                                        0,
154                                        0,
155                                        0,
156                                /* 39*/ 0,
157                                        0,
158                                        0,
159                                        0,
160                                        0,
161                                /* 44*/ 0,
162                                        0,
163                                        MatShift_Basic,
164                                        0,
165                                        0,
166                                /* 49*/ 0,
167                                        0,
168                                        0,
169                                        0,
170                                        0,
171                                /* 54*/ 0,
172                                        0,
173                                        0,
174                                        0,
175                                        0,
176                                /* 59*/ 0,
177                                        MatDestroy_Scatter,
178                                        0,
179                                        0,
180                                        0,
181                                /* 64*/ 0,
182                                        0,
183                                        0,
184                                        0,
185                                        0,
186                                /* 69*/ 0,
187                                        0,
188                                        0,
189                                        0,
190                                        0,
191                                /* 74*/ 0,
192                                        0,
193                                        0,
194                                        0,
195                                        0,
196                                /* 79*/ 0,
197                                        0,
198                                        0,
199                                        0,
200                                        0,
201                                /* 84*/ 0,
202                                        0,
203                                        0,
204                                        0,
205                                        0,
206                                /* 89*/ 0,
207                                        0,
208                                        0,
209                                        0,
210                                        0,
211                                /* 94*/ 0,
212                                        0,
213                                        0,
214                                        0,
215                                        0,
216                                 /*99*/ 0,
217                                        0,
218                                        0,
219                                        0,
220                                        0,
221                                /*104*/ 0,
222                                        0,
223                                        0,
224                                        0,
225                                        0,
226                                /*109*/ 0,
227                                        0,
228                                        0,
229                                        0,
230                                        0,
231                                /*114*/ 0,
232                                        0,
233                                        0,
234                                        0,
235                                        0,
236                                /*119*/ 0,
237                                        0,
238                                        0,
239                                        0,
240                                        0,
241                                /*124*/ 0,
242                                        0,
243                                        0,
244                                        0,
245                                        0,
246                                /*129*/ 0,
247                                        0,
248                                        0,
249                                        0,
250                                        0,
251                                /*134*/ 0,
252                                        0,
253                                        0,
254                                        0,
255                                        0,
256                                /*139*/ 0,
257                                        0,
258                                        0
259 };
260 
261 /*MC
262    MATSCATTER - MATSCATTER = "scatter" - A matrix type that simply applies a VecScatterBegin/End()
263 
264   Level: advanced
265 
266 .seealso: MatCreateScatter(), MatScatterSetVecScatter(), MatScatterGetVecScatter()
267 
268 M*/
269 
270 #undef __FUNCT__
271 #define __FUNCT__ "MatCreate_Scatter"
272 PETSC_EXTERN PetscErrorCode MatCreate_Scatter(Mat A)
273 {
274   Mat_Scatter    *b;
275   PetscErrorCode ierr;
276 
277   PetscFunctionBegin;
278   ierr = PetscMemcpy(A->ops,&MatOps_Values,sizeof(struct _MatOps));CHKERRQ(ierr);
279   ierr = PetscNewLog(A,&b);CHKERRQ(ierr);
280 
281   A->data = (void*)b;
282 
283   ierr = PetscLayoutSetUp(A->rmap);CHKERRQ(ierr);
284   ierr = PetscLayoutSetUp(A->cmap);CHKERRQ(ierr);
285 
286   A->assembled    = PETSC_TRUE;
287   A->preallocated = PETSC_FALSE;
288 
289   ierr = PetscObjectChangeTypeName((PetscObject)A,MATSCATTER);CHKERRQ(ierr);
290   PetscFunctionReturn(0);
291 }
292 
293 #undef __FUNCT__
294 #define __FUNCT__ "MatCreateScatter"
295 /*@C
296    MatCreateScatter - Creates a new matrix based on a VecScatter
297 
298   Collective on MPI_Comm
299 
300    Input Parameters:
301 +  comm - MPI communicator
302 -  scatter - a VecScatterContext
303 
304    Output Parameter:
305 .  A - the matrix
306 
307    Level: intermediate
308 
309    PETSc requires that matrices and vectors being used for certain
310    operations are partitioned accordingly.  For example, when
311    creating a scatter matrix, A, that supports parallel matrix-vector
312    products using MatMult(A,x,y) the user should set the number
313    of local matrix rows to be the number of local elements of the
314    corresponding result vector, y. Note that this is information is
315    required for use of the matrix interface routines, even though
316    the scatter matrix may not actually be physically partitioned.
317 
318 .keywords: matrix, scatter, create
319 
320 .seealso: MatScatterSetVecScatter(), MatScatterGetVecScatter(), MATSCATTER
321 @*/
322 PetscErrorCode  MatCreateScatter(MPI_Comm comm,VecScatter scatter,Mat *A)
323 {
324   PetscErrorCode ierr;
325 
326   PetscFunctionBegin;
327   ierr = MatCreate(comm,A);CHKERRQ(ierr);
328   ierr = MatSetSizes(*A,scatter->to_n,scatter->from_n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
329   ierr = MatSetType(*A,MATSCATTER);CHKERRQ(ierr);
330   ierr = MatScatterSetVecScatter(*A,scatter);CHKERRQ(ierr);
331   ierr = MatSetUp(*A);CHKERRQ(ierr);
332   PetscFunctionReturn(0);
333 }
334 
335 #undef __FUNCT__
336 #define __FUNCT__ "MatScatterSetVecScatter"
337 /*@
338     MatScatterSetVecScatter - sets that scatter that the matrix is to apply as its linear operator
339 
340    Collective on Mat
341 
342     Input Parameters:
343 +   mat - the scatter matrix
344 -   scatter - the scatter context create with VecScatterCreate()
345 
346    Level: advanced
347 
348 
349 .seealso: MatCreateScatter(), MATSCATTER
350 @*/
351 PetscErrorCode  MatScatterSetVecScatter(Mat mat,VecScatter scatter)
352 {
353   Mat_Scatter    *mscatter = (Mat_Scatter*)mat->data;
354   PetscErrorCode ierr;
355 
356   PetscFunctionBegin;
357   PetscValidHeaderSpecific(mat,MAT_CLASSID,1);
358   PetscValidHeaderSpecific(scatter,VEC_SCATTER_CLASSID,2);
359   PetscCheckSameComm((PetscObject)scatter,1,(PetscObject)mat,2);
360   if (mat->rmap->n != scatter->to_n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Number of local rows in matrix %D not equal local scatter size %D",mat->rmap->n,scatter->to_n);
361   if (mat->cmap->n != scatter->from_n) SETERRQ2(PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"Number of local columns in matrix %D not equal local scatter size %D",mat->cmap->n,scatter->from_n);
362 
363   ierr = PetscObjectReference((PetscObject)scatter);CHKERRQ(ierr);
364   ierr = VecScatterDestroy(&mscatter->scatter);CHKERRQ(ierr);
365 
366   mscatter->scatter = scatter;
367   PetscFunctionReturn(0);
368 }
369 
370 
371