xref: /petsc/src/mat/impls/composite/mcomposite.c (revision 55e7fe800d976e85ed2b5cd8bfdef564daa37bd9)
1 
2 #include <petsc/private/matimpl.h>        /*I "petscmat.h" I*/
3 
4 typedef struct _Mat_CompositeLink *Mat_CompositeLink;
5 struct _Mat_CompositeLink {
6   Mat               mat;
7   Vec               work;
8   Mat_CompositeLink next,prev;
9 };
10 
11 typedef struct {
12   MatCompositeType  type;
13   Mat_CompositeLink head,tail;
14   Vec               work;
15   PetscScalar       scale;        /* scale factor supplied with MatScale() */
16   Vec               left,right;   /* left and right diagonal scaling provided with MatDiagonalScale() */
17   Vec               leftwork,rightwork;
18 } Mat_Composite;
19 
20 PetscErrorCode MatDestroy_Composite(Mat mat)
21 {
22   PetscErrorCode    ierr;
23   Mat_Composite     *shell = (Mat_Composite*)mat->data;
24   Mat_CompositeLink next   = shell->head,oldnext;
25 
26   PetscFunctionBegin;
27   while (next) {
28     ierr = MatDestroy(&next->mat);CHKERRQ(ierr);
29     if (next->work && (!next->next || next->work != next->next->work)) {
30       ierr = VecDestroy(&next->work);CHKERRQ(ierr);
31     }
32     oldnext = next;
33     next    = next->next;
34     ierr    = PetscFree(oldnext);CHKERRQ(ierr);
35   }
36   ierr = VecDestroy(&shell->work);CHKERRQ(ierr);
37   ierr = VecDestroy(&shell->left);CHKERRQ(ierr);
38   ierr = VecDestroy(&shell->right);CHKERRQ(ierr);
39   ierr = VecDestroy(&shell->leftwork);CHKERRQ(ierr);
40   ierr = VecDestroy(&shell->rightwork);CHKERRQ(ierr);
41   ierr = PetscFree(mat->data);CHKERRQ(ierr);
42   PetscFunctionReturn(0);
43 }
44 
45 PetscErrorCode MatMult_Composite_Multiplicative(Mat A,Vec x,Vec y)
46 {
47   Mat_Composite     *shell = (Mat_Composite*)A->data;
48   Mat_CompositeLink next   = shell->head;
49   PetscErrorCode    ierr;
50   Vec               in,out;
51 
52   PetscFunctionBegin;
53   if (!next) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
54   in = x;
55   if (shell->right) {
56     if (!shell->rightwork) {
57       ierr = VecDuplicate(shell->right,&shell->rightwork);CHKERRQ(ierr);
58     }
59     ierr = VecPointwiseMult(shell->rightwork,shell->right,in);CHKERRQ(ierr);
60     in   = shell->rightwork;
61   }
62   while (next->next) {
63     if (!next->work) { /* should reuse previous work if the same size */
64       ierr = MatCreateVecs(next->mat,NULL,&next->work);CHKERRQ(ierr);
65     }
66     out  = next->work;
67     ierr = MatMult(next->mat,in,out);CHKERRQ(ierr);
68     in   = out;
69     next = next->next;
70   }
71   ierr = MatMult(next->mat,in,y);CHKERRQ(ierr);
72   if (shell->left) {
73     ierr = VecPointwiseMult(y,shell->left,y);CHKERRQ(ierr);
74   }
75   ierr = VecScale(y,shell->scale);CHKERRQ(ierr);
76   PetscFunctionReturn(0);
77 }
78 
79 PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A,Vec x,Vec y)
80 {
81   Mat_Composite     *shell = (Mat_Composite*)A->data;
82   Mat_CompositeLink tail   = shell->tail;
83   PetscErrorCode    ierr;
84   Vec               in,out;
85 
86   PetscFunctionBegin;
87   if (!tail) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
88   in = x;
89   if (shell->left) {
90     if (!shell->leftwork) {
91       ierr = VecDuplicate(shell->left,&shell->leftwork);CHKERRQ(ierr);
92     }
93     ierr = VecPointwiseMult(shell->leftwork,shell->left,in);CHKERRQ(ierr);
94     in   = shell->leftwork;
95   }
96   while (tail->prev) {
97     if (!tail->prev->work) { /* should reuse previous work if the same size */
98       ierr = MatCreateVecs(tail->mat,NULL,&tail->prev->work);CHKERRQ(ierr);
99     }
100     out  = tail->prev->work;
101     ierr = MatMultTranspose(tail->mat,in,out);CHKERRQ(ierr);
102     in   = out;
103     tail = tail->prev;
104   }
105   ierr = MatMultTranspose(tail->mat,in,y);CHKERRQ(ierr);
106   if (shell->right) {
107     ierr = VecPointwiseMult(y,shell->right,y);CHKERRQ(ierr);
108   }
109   ierr = VecScale(y,shell->scale);CHKERRQ(ierr);
110   PetscFunctionReturn(0);
111 }
112 
113 PetscErrorCode MatMult_Composite(Mat A,Vec x,Vec y)
114 {
115   Mat_Composite     *shell = (Mat_Composite*)A->data;
116   Mat_CompositeLink next   = shell->head;
117   PetscErrorCode    ierr;
118   Vec               in;
119 
120   PetscFunctionBegin;
121   if (!next) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
122   in = x;
123   if (shell->right) {
124     if (!shell->rightwork) {
125       ierr = VecDuplicate(shell->right,&shell->rightwork);CHKERRQ(ierr);
126     }
127     ierr = VecPointwiseMult(shell->rightwork,shell->right,in);CHKERRQ(ierr);
128     in   = shell->rightwork;
129   }
130   ierr = MatMult(next->mat,in,y);CHKERRQ(ierr);
131   while ((next = next->next)) {
132     ierr = MatMultAdd(next->mat,in,y,y);CHKERRQ(ierr);
133   }
134   if (shell->left) {
135     ierr = VecPointwiseMult(y,shell->left,y);CHKERRQ(ierr);
136   }
137   ierr = VecScale(y,shell->scale);CHKERRQ(ierr);
138   PetscFunctionReturn(0);
139 }
140 
141 PetscErrorCode MatMultTranspose_Composite(Mat A,Vec x,Vec y)
142 {
143   Mat_Composite     *shell = (Mat_Composite*)A->data;
144   Mat_CompositeLink next   = shell->head;
145   PetscErrorCode    ierr;
146   Vec               in;
147 
148   PetscFunctionBegin;
149   if (!next) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
150   in = x;
151   if (shell->left) {
152     if (!shell->leftwork) {
153       ierr = VecDuplicate(shell->left,&shell->leftwork);CHKERRQ(ierr);
154     }
155     ierr = VecPointwiseMult(shell->leftwork,shell->left,in);CHKERRQ(ierr);
156     in   = shell->leftwork;
157   }
158   ierr = MatMultTranspose(next->mat,in,y);CHKERRQ(ierr);
159   while ((next = next->next)) {
160     ierr = MatMultTransposeAdd(next->mat,in,y,y);CHKERRQ(ierr);
161   }
162   if (shell->right) {
163     ierr = VecPointwiseMult(y,shell->right,y);CHKERRQ(ierr);
164   }
165   ierr = VecScale(y,shell->scale);CHKERRQ(ierr);
166   PetscFunctionReturn(0);
167 }
168 
169 PetscErrorCode MatGetDiagonal_Composite(Mat A,Vec v)
170 {
171   Mat_Composite     *shell = (Mat_Composite*)A->data;
172   Mat_CompositeLink next   = shell->head;
173   PetscErrorCode    ierr;
174 
175   PetscFunctionBegin;
176   if (!next) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
177   if (shell->right || shell->left) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"Cannot get diagonal if left or right scaling");
178 
179   ierr = MatGetDiagonal(next->mat,v);CHKERRQ(ierr);
180   if (next->next && !shell->work) {
181     ierr = VecDuplicate(v,&shell->work);CHKERRQ(ierr);
182   }
183   while ((next = next->next)) {
184     ierr = MatGetDiagonal(next->mat,shell->work);CHKERRQ(ierr);
185     ierr = VecAXPY(v,1.0,shell->work);CHKERRQ(ierr);
186   }
187   ierr = VecScale(v,shell->scale);CHKERRQ(ierr);
188   PetscFunctionReturn(0);
189 }
190 
191 PetscErrorCode MatAssemblyEnd_Composite(Mat Y,MatAssemblyType t)
192 {
193   PetscErrorCode ierr;
194   PetscBool      flg = PETSC_FALSE;
195 
196   PetscFunctionBegin;
197   ierr = PetscOptionsGetBool(((PetscObject)Y)->options,((PetscObject)Y)->prefix,"-mat_composite_merge",&flg,NULL);CHKERRQ(ierr);
198   if (flg) {
199     ierr = MatCompositeMerge(Y);CHKERRQ(ierr);
200   }
201   PetscFunctionReturn(0);
202 }
203 
204 PetscErrorCode MatScale_Composite(Mat inA,PetscScalar alpha)
205 {
206   Mat_Composite *a = (Mat_Composite*)inA->data;
207 
208   PetscFunctionBegin;
209   a->scale *= alpha;
210   PetscFunctionReturn(0);
211 }
212 
213 PetscErrorCode MatDiagonalScale_Composite(Mat inA,Vec left,Vec right)
214 {
215   Mat_Composite  *a = (Mat_Composite*)inA->data;
216   PetscErrorCode ierr;
217 
218   PetscFunctionBegin;
219   if (left) {
220     if (!a->left) {
221       ierr = VecDuplicate(left,&a->left);CHKERRQ(ierr);
222       ierr = VecCopy(left,a->left);CHKERRQ(ierr);
223     } else {
224       ierr = VecPointwiseMult(a->left,left,a->left);CHKERRQ(ierr);
225     }
226   }
227   if (right) {
228     if (!a->right) {
229       ierr = VecDuplicate(right,&a->right);CHKERRQ(ierr);
230       ierr = VecCopy(right,a->right);CHKERRQ(ierr);
231     } else {
232       ierr = VecPointwiseMult(a->right,right,a->right);CHKERRQ(ierr);
233     }
234   }
235   PetscFunctionReturn(0);
236 }
237 
/*
   Operations table copied into A->ops by MatCreate_Composite().  Entries are
   positional: each slot must stay at the index of the matching member of
   struct _MatOps, and the numbered index markers below help keep count.
   A 0 entry leaves that operation unimplemented for MATCOMPOSITE.
   MatCompositeSetType() later overwrites the mult, multtranspose and
   getdiagonal slots according to the composite type.
*/
238 static struct _MatOps MatOps_Values = {0,
239                                        0,
240                                        0,
241                                        MatMult_Composite, /* mult: additive form; see MatCompositeSetType() */
242                                        0,
243                                 /*  5*/ MatMultTranspose_Composite, /* multtranspose: additive form */
244                                        0,
245                                        0,
246                                        0,
247                                        0,
248                                 /* 10*/ 0,
249                                        0,
250                                        0,
251                                        0,
252                                        0,
253                                 /* 15*/ 0,
254                                        0,
255                                        MatGetDiagonal_Composite, /* getdiagonal: cleared for the multiplicative form */
256                                        MatDiagonalScale_Composite,
257                                        0,
258                                 /* 20*/ 0,
259                                        MatAssemblyEnd_Composite,
260                                        0,
261                                        0,
262                                /* 24*/ 0,
263                                        0,
264                                        0,
265                                        0,
266                                        0,
267                                /* 29*/ 0,
268                                        0,
269                                        0,
270                                        0,
271                                        0,
272                                /* 34*/ 0,
273                                        0,
274                                        0,
275                                        0,
276                                        0,
277                                /* 39*/ 0,
278                                        0,
279                                        0,
280                                        0,
281                                        0,
282                                /* 44*/ 0,
283                                        MatScale_Composite,
284                                        MatShift_Basic,
285                                        0,
286                                        0,
287                                /* 49*/ 0,
288                                        0,
289                                        0,
290                                        0,
291                                        0,
292                                /* 54*/ 0,
293                                        0,
294                                        0,
295                                        0,
296                                        0,
297                                /* 59*/ 0,
298                                        MatDestroy_Composite,
299                                        0,
300                                        0,
301                                        0,
302                                /* 64*/ 0,
303                                        0,
304                                        0,
305                                        0,
306                                        0,
307                                /* 69*/ 0,
308                                        0,
309                                        0,
310                                        0,
311                                        0,
312                                /* 74*/ 0,
313                                        0,
314                                        0,
315                                        0,
316                                        0,
317                                /* 79*/ 0,
318                                        0,
319                                        0,
320                                        0,
321                                        0,
322                                /* 84*/ 0,
323                                        0,
324                                        0,
325                                        0,
326                                        0,
327                                /* 89*/ 0,
328                                        0,
329                                        0,
330                                        0,
331                                        0,
332                                /* 94*/ 0,
333                                        0,
334                                        0,
335                                        0,
336                                        0,
337                                 /*99*/ 0,
338                                        0,
339                                        0,
340                                        0,
341                                        0,
342                                /*104*/ 0,
343                                        0,
344                                        0,
345                                        0,
346                                        0,
347                                /*109*/ 0,
348                                        0,
349                                        0,
350                                        0,
351                                        0,
352                                /*114*/ 0,
353                                        0,
354                                        0,
355                                        0,
356                                        0,
357                                /*119*/ 0,
358                                        0,
359                                        0,
360                                        0,
361                                        0,
362                                /*124*/ 0,
363                                        0,
364                                        0,
365                                        0,
366                                        0,
367                                /*129*/ 0,
368                                        0,
369                                        0,
370                                        0,
371                                        0,
372                                /*134*/ 0,
373                                        0,
374                                        0,
375                                        0,
376                                        0,
377                                /*139*/ 0,
378                                        0,
379                                        0
380 };
381 
382 /*MC
383    MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices. For the sum, all matrices must have the same size and parallel layout; for the product, the sizes of consecutive factors must be compatible.
384 
385    Notes:
386     To use the product of the matrices, call MatCompositeSetType(mat,MAT_COMPOSITE_MULTIPLICATIVE).
387 
388   Level: advanced
389 
390 .seealso: MatCreateComposite(), MatCompositeAddMat(), MatSetType(), MatCompositeMerge(), MatCompositeSetType(), MatCompositeType
391 M*/
392 
393 PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
394 {
395   Mat_Composite  *b;
396   PetscErrorCode ierr;
397 
398   PetscFunctionBegin;
399   ierr    = PetscNewLog(A,&b);CHKERRQ(ierr);
400   A->data = (void*)b;
401   ierr    = PetscMemcpy(A->ops,&MatOps_Values,sizeof(struct _MatOps));CHKERRQ(ierr);
402 
403   ierr = PetscLayoutSetUp(A->rmap);CHKERRQ(ierr);
404   ierr = PetscLayoutSetUp(A->cmap);CHKERRQ(ierr);
405 
406   A->assembled    = PETSC_TRUE;
407   A->preallocated = PETSC_TRUE;
408   b->type         = MAT_COMPOSITE_ADDITIVE;
409   b->scale        = 1.0;
410   ierr            = PetscObjectChangeTypeName((PetscObject)A,MATCOMPOSITE);CHKERRQ(ierr);
411   PetscFunctionReturn(0);
412 }
413 
414 /*@
415    MatCreateComposite - Creates a matrix as the sum of one or more matrices
416 
417   Collective on MPI_Comm
418 
419    Input Parameters:
420 +  comm - MPI communicator
421 .  nmat - number of matrices to put in
422 -  mats - the matrices
423 
424    Output Parameter:
425 .  mat - the matrix
426 
427    Level: advanced
428 
429    Notes:
430      Alternative construction
431 $       MatCreate(comm,&mat);
432 $       MatSetSizes(mat,m,n,M,N);
433 $       MatSetType(mat,MATCOMPOSITE);
434 $       MatCompositeAddMat(mat,mats[0]);
435 $       ....
436 $       MatCompositeAddMat(mat,mats[nmat-1]);
437 $       MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
438 $       MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
439 
440      For the multiplicative form (see MatCompositeSetType()) the product is mat[nmat-1]*mat[nmat-2]*...*mat[0]
441 
442 .seealso: MatDestroy(), MatMult(), MatCompositeAddMat(), MatCompositeMerge(), MatCompositeSetType(), MatCompositeType
443 
444 @*/
445 PetscErrorCode  MatCreateComposite(MPI_Comm comm,PetscInt nmat,const Mat *mats,Mat *mat)
446 {
447   PetscErrorCode ierr;
448   PetscInt       m,n,M,N,i;
449 
450   PetscFunctionBegin;
451   if (nmat < 1) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Must pass in at least one matrix");
452   PetscValidPointer(mat,3);
453 
454   ierr = MatGetLocalSize(mats[0],&m,&n);CHKERRQ(ierr);
455   ierr = MatGetSize(mats[0],&M,&N);CHKERRQ(ierr);
456   ierr = MatCreate(comm,mat);CHKERRQ(ierr);
457   ierr = MatSetSizes(*mat,m,n,M,N);CHKERRQ(ierr);
458   ierr = MatSetType(*mat,MATCOMPOSITE);CHKERRQ(ierr);
459   for (i=0; i<nmat; i++) {
460     ierr = MatCompositeAddMat(*mat,mats[i]);CHKERRQ(ierr);
461   }
462   ierr = MatAssemblyBegin(*mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
463   ierr = MatAssemblyEnd(*mat,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
464   PetscFunctionReturn(0);
465 }
466 
467 /*@
468     MatCompositeAddMat - add another matrix to a composite matrix
469 
470    Collective on Mat
471 
472     Input Parameters:
473 +   mat - the composite matrix
474 -   smat - the partial matrix
475 
476    Level: advanced
477 
478 .seealso: MatCreateComposite()
479 @*/
480 PetscErrorCode  MatCompositeAddMat(Mat mat,Mat smat)
481 {
482   Mat_Composite     *shell;
483   PetscErrorCode    ierr;
484   Mat_CompositeLink ilink,next;
485 
486   PetscFunctionBegin;
487   PetscValidHeaderSpecific(mat,MAT_CLASSID,1);
488   PetscValidHeaderSpecific(smat,MAT_CLASSID,2);
489   ierr        = PetscNewLog(mat,&ilink);CHKERRQ(ierr);
490   ilink->next = 0;
491   ierr        = PetscObjectReference((PetscObject)smat);CHKERRQ(ierr);
492   ilink->mat  = smat;
493 
494   shell = (Mat_Composite*)mat->data;
495   next  = shell->head;
496   if (!next) shell->head = ilink;
497   else {
498     while (next->next) {
499       next = next->next;
500     }
501     next->next  = ilink;
502     ilink->prev = next;
503   }
504   shell->tail = ilink;
505   PetscFunctionReturn(0);
506 }
507 
508 /*@
509    MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product
510 
511   Collective on MPI_Comm
512 
513    Input Parameters:
514 +  mat - the composite matrix
515 -  type - MAT_COMPOSITE_ADDITIVE or MAT_COMPOSITE_MULTIPLICATIVE
516 
517    Level: advanced
518 
519    Notes:
520       With MAT_COMPOSITE_MULTIPLICATIVE the matrix is the product mat[nmat-1]*...*mat[0] and MatGetDiagonal()
521     is not available; any other type makes the matrix the sum of its matrices.
522 
523 .seealso: MatDestroy(), MatMult(), MatCompositeAddMat(), MatCreateComposite(), MATCOMPOSITE
524 
525 @*/
526 PetscErrorCode  MatCompositeSetType(Mat mat,MatCompositeType type)
527 {
528   Mat_Composite  *b = (Mat_Composite*)mat->data;
529   PetscBool      flg;
530   PetscErrorCode ierr;
531 
532   PetscFunctionBegin;
533   ierr = PetscObjectTypeCompare((PetscObject)mat,MATCOMPOSITE,&flg);CHKERRQ(ierr);
534   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Can only use with composite matrix");
535   if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
536     mat->ops->getdiagonal   = 0;
537     mat->ops->mult          = MatMult_Composite_Multiplicative;
538     mat->ops->multtranspose = MatMultTranspose_Composite_Multiplicative;
539     b->type                 = MAT_COMPOSITE_MULTIPLICATIVE;
540   } else {
541     mat->ops->getdiagonal   = MatGetDiagonal_Composite;
542     mat->ops->mult          = MatMult_Composite;
543     mat->ops->multtranspose = MatMultTranspose_Composite;
544     b->type                 = MAT_COMPOSITE_ADDITIVE;
545   }
546   PetscFunctionReturn(0);
547 }
548 
549 
550 /*@
551    MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
552      by summing (or, for MAT_COMPOSITE_MULTIPLICATIVE, multiplying) all the matrices inside the composite matrix.
553 
554   Collective on MPI_Comm
555 
556    Input Parameters:
557 .  mat - the composite matrix
558 
559 
560    Options Database:
561 .  -mat_composite_merge  (you must call MatAssemblyBegin()/MatAssemblyEnd() to have this checked)
562 
563    Level: advanced
564 
565    Notes:
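566       For MAT_COMPOSITE_ADDITIVE the MatType of the resulting matrix will be the same as the MatType of the FIRST
567     matrix in the composite matrix; for MAT_COMPOSITE_MULTIPLICATIVE it is the type produced by MatMatMult().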
568 
569 .seealso: MatDestroy(), MatMult(), MatCompositeAddMat(), MatCreateComposite(), MATCOMPOSITE
570 
571 @*/
572 PetscErrorCode  MatCompositeMerge(Mat mat)
573 {
574   Mat_Composite     *shell = (Mat_Composite*)mat->data;
575   Mat_CompositeLink next   = shell->head, prev = shell->tail;
576   PetscErrorCode    ierr;
577   Mat               tmat,newmat;
578   Vec               left,right;
579   PetscScalar       scale;
580 
581   PetscFunctionBegin;
582   if (!next) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONGSTATE,"Must provide at least one matrix with MatCompositeAddMat()");
583 
584   PetscFunctionBegin;
585   if (shell->type == MAT_COMPOSITE_ADDITIVE) {
586     ierr = MatDuplicate(next->mat,MAT_COPY_VALUES,&tmat);CHKERRQ(ierr);
587     while ((next = next->next)) {
588       ierr = MatAXPY(tmat,1.0,next->mat,DIFFERENT_NONZERO_PATTERN);CHKERRQ(ierr);
589     }
590   } else {
591     ierr = MatDuplicate(next->mat,MAT_COPY_VALUES,&tmat);CHKERRQ(ierr);
592     while ((prev = prev->prev)) {
593       ierr = MatMatMult(tmat,prev->mat,MAT_INITIAL_MATRIX,PETSC_DECIDE,&newmat);CHKERRQ(ierr);
594       ierr = MatDestroy(&tmat);CHKERRQ(ierr);
595       tmat = newmat;
596     }
597   }
598 
599   scale = shell->scale;
600   if ((left = shell->left)) {ierr = PetscObjectReference((PetscObject)left);CHKERRQ(ierr);}
601   if ((right = shell->right)) {ierr = PetscObjectReference((PetscObject)right);CHKERRQ(ierr);}
602 
603   ierr = MatHeaderReplace(mat,&tmat);CHKERRQ(ierr);
604 
605   ierr = MatDiagonalScale(mat,left,right);CHKERRQ(ierr);
606   ierr = MatScale(mat,scale);CHKERRQ(ierr);
607   ierr = VecDestroy(&left);CHKERRQ(ierr);
608   ierr = VecDestroy(&right);CHKERRQ(ierr);
609   PetscFunctionReturn(0);
610 }
611