xref: /petsc/src/mat/impls/composite/mcomposite.c (revision 98d129c30f3ee9fdddc40fdbc5a989b7be64f888)
1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2 
3 const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL};
4 
5 typedef struct _Mat_CompositeLink *Mat_CompositeLink;
6 struct _Mat_CompositeLink {
7   Mat               mat;
8   Vec               work;
9   Mat_CompositeLink next, prev;
10 };
11 
12 typedef struct {
13   MatCompositeType      type;
14   Mat_CompositeLink     head, tail;
15   Vec                   work;
16   PetscInt              nmat;
17   PetscBool             merge;
18   MatCompositeMergeType mergetype;
19   MatStructure          structure;
20 
21   PetscScalar *scalings;
22   PetscBool    merge_mvctx; /* Whether need to merge mvctx of component matrices */
23   Vec         *lvecs;       /* [nmat] Basically, they are Mvctx->lvec of each component matrix */
24   PetscScalar *larray;      /* [len] Data arrays of lvecs[] are stored consecutively in larray */
25   PetscInt     len;         /* Length of larray[] */
26   Vec          gvec;        /* Union of lvecs[] without duplicated entries */
27   PetscInt    *location;    /* A map that maps entries in garray[] to larray[] */
28   VecScatter   Mvctx;
29 } Mat_Composite;
30 
31 static PetscErrorCode MatDestroy_Composite(Mat mat)
32 {
33   Mat_Composite    *shell;
34   Mat_CompositeLink next, oldnext;
35   PetscInt          i;
36 
37   PetscFunctionBegin;
38   PetscCall(MatShellGetContext(mat, &shell));
39   next = shell->head;
40   while (next) {
41     PetscCall(MatDestroy(&next->mat));
42     if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work));
43     oldnext = next;
44     next    = next->next;
45     PetscCall(PetscFree(oldnext));
46   }
47   PetscCall(VecDestroy(&shell->work));
48 
49   if (shell->Mvctx) {
50     for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i]));
51     PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs));
52     PetscCall(PetscFree(shell->larray));
53     PetscCall(VecDestroy(&shell->gvec));
54     PetscCall(VecScatterDestroy(&shell->Mvctx));
55   }
56 
57   PetscCall(PetscFree(shell->scalings));
58   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL));
59   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL));
60   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL));
61   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL));
62   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL));
63   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL));
64   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL));
65   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL));
66   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL));
67   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL));
68   PetscCall(PetscFree(shell));
69   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable()
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y)
74 {
75   Mat_Composite    *shell;
76   Mat_CompositeLink next;
77   Vec               out;
78 
79   PetscFunctionBegin;
80   PetscCall(MatShellGetContext(A, &shell));
81   next = shell->head;
82   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
83   while (next->next) {
84     if (!next->work) { /* should reuse previous work if the same size */
85       PetscCall(MatCreateVecs(next->mat, NULL, &next->work));
86     }
87     out = next->work;
88     PetscCall(MatMult(next->mat, x, out));
89     x    = out;
90     next = next->next;
91   }
92   PetscCall(MatMult(next->mat, x, y));
93   if (shell->scalings) {
94     PetscScalar scale = 1.0;
95     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
96     PetscCall(VecScale(y, scale));
97   }
98   PetscFunctionReturn(PETSC_SUCCESS);
99 }
100 
101 static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y)
102 {
103   Mat_Composite    *shell;
104   Mat_CompositeLink tail;
105   Vec               out;
106 
107   PetscFunctionBegin;
108   PetscCall(MatShellGetContext(A, &shell));
109   tail = shell->tail;
110   PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
111   while (tail->prev) {
112     if (!tail->prev->work) { /* should reuse previous work if the same size */
113       PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work));
114     }
115     out = tail->prev->work;
116     PetscCall(MatMultTranspose(tail->mat, x, out));
117     x    = out;
118     tail = tail->prev;
119   }
120   PetscCall(MatMultTranspose(tail->mat, x, y));
121   if (shell->scalings) {
122     PetscScalar scale = 1.0;
123     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
124     PetscCall(VecScale(y, scale));
125   }
126   PetscFunctionReturn(PETSC_SUCCESS);
127 }
128 
129 static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y)
130 {
131   Mat_Composite     *shell;
132   Mat_CompositeLink  cur;
133   Vec                y2, xin;
134   Mat                A, B;
135   PetscInt           i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot;
136   const PetscScalar *vals;
137   const PetscInt    *garray;
138   IS                 ix, iy;
139   PetscBool          match;
140 
141   PetscFunctionBegin;
142   PetscCall(MatShellGetContext(mat, &shell));
143   cur = shell->head;
144   PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
145 
146   /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time
147      we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and
148      it is legal to merge Mvctx, because all component matrices have the same size.
149    */
150   if (shell->merge_mvctx && !shell->Mvctx) {
151     /* Currently only implemented for MATMPIAIJ */
152     for (cur = shell->head; cur; cur = cur->next) {
153       PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match));
154       if (!match) {
155         shell->merge_mvctx = PETSC_FALSE;
156         goto skip_merge_mvctx;
157       }
158     }
159 
160     /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */
161     tot = 0;
162     for (cur = shell->head; cur; cur = cur->next) {
163       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL));
164       PetscCall(MatGetLocalSize(B, NULL, &n));
165       tot += n;
166     }
167     PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs));
168     shell->len = tot;
169 
170     /* Go through matrices second time to sort off-diag columns and remove dups */
171     PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to petsc and free the other */
172     PetscCall(PetscMalloc1(tot, &buf));
173     nuniq = 0; /* Number of unique nonzero columns */
174     for (cur = shell->head; cur; cur = cur->next) {
175       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
176       PetscCall(MatGetLocalSize(B, NULL, &n));
177       /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */
178       i = j = k = 0;
179       while (i < n && j < nuniq) {
180         if (garray[i] < gindices[j]) buf[k++] = garray[i++];
181         else if (garray[i] > gindices[j]) buf[k++] = gindices[j++];
182         else {
183           buf[k++] = garray[i++];
184           j++;
185         }
186       }
187       /* Copy leftover in garray[] or gindices[] */
188       if (i < n) {
189         PetscCall(PetscArraycpy(buf + k, garray + i, n - i));
190         nuniq = k + n - i;
191       } else if (j < nuniq) {
192         PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j));
193         nuniq = k + nuniq - j;
194       } else nuniq = k;
195       /* Swap gindices and buf to merge garray of the next matrix */
196       tmp      = gindices;
197       gindices = buf;
198       buf      = tmp;
199     }
200     PetscCall(PetscFree(buf));
201 
202     /* Go through matrices third time to build a map from gindices[] to garray[] */
203     tot = 0;
204     for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */
205       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
206       PetscCall(MatGetLocalSize(B, NULL, &n));
207       PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j]));
208       /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */
209       lo = 0;
210       for (i = 0; i < n; i++) {
211         hi = nuniq;
212         while (hi - lo > 1) {
213           mid = lo + (hi - lo) / 2;
214           if (garray[i] < gindices[mid]) hi = mid;
215           else lo = mid;
216         }
217         shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */
218         lo++;                          /* Since garray[i+1] > garray[i], we can safely advance lo */
219       }
220       tot += n;
221     }
222 
223     /* Build merged Mvctx */
224     PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix));
225     PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy));
226     PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin));
227     PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec));
228     PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx));
229     PetscCall(VecDestroy(&xin));
230     PetscCall(ISDestroy(&ix));
231     PetscCall(ISDestroy(&iy));
232   }
233 
234 skip_merge_mvctx:
235   PetscCall(VecSet(y, 0));
236   if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work)));
237   y2 = ((Mat_Shell *)mat->data)->left_work;
238 
239   if (shell->Mvctx) { /* Have a merged Mvctx */
240     /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do
241        in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to
242        overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach.
243      */
244     PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
245     PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
246 
247     PetscCall(VecGetArrayRead(shell->gvec, &vals));
248     for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]];
249     PetscCall(VecRestoreArrayRead(shell->gvec, &vals));
250 
251     for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */
252       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL));
253       PetscUseTypeMethod(A, mult, x, y2);
254       PetscCall(MatGetLocalSize(B, NULL, &n));
255       PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot]));
256       PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2);
257       PetscCall(VecResetArray(shell->lvecs[i]));
258       PetscCall(VecAXPY(y, (shell->scalings ? shell->scalings[i] : 1.0), y2));
259       tot += n;
260     }
261   } else {
262     if (shell->scalings) {
263       for (cur = shell->head, i = 0; cur; cur = cur->next, i++) {
264         PetscCall(MatMult(cur->mat, x, y2));
265         PetscCall(VecAXPY(y, shell->scalings[i], y2));
266       }
267     } else {
268       for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y));
269     }
270   }
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
274 static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y)
275 {
276   Mat_Composite    *shell;
277   Mat_CompositeLink next;
278   Vec               y2 = NULL;
279   PetscInt          i;
280 
281   PetscFunctionBegin;
282   PetscCall(MatShellGetContext(A, &shell));
283   next = shell->head;
284   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
285 
286   PetscCall(MatMultTranspose(next->mat, x, y));
287   if (shell->scalings) {
288     PetscCall(VecScale(y, shell->scalings[0]));
289     if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work)));
290     y2 = ((Mat_Shell *)A->data)->right_work;
291   }
292   i = 1;
293   while ((next = next->next)) {
294     if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y));
295     else {
296       PetscCall(MatMultTranspose(next->mat, x, y2));
297       PetscCall(VecAXPY(y, shell->scalings[i++], y2));
298     }
299   }
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v)
304 {
305   Mat_Composite    *shell;
306   Mat_CompositeLink next;
307   PetscInt          i;
308 
309   PetscFunctionBegin;
310   PetscCall(MatShellGetContext(A, &shell));
311   next = shell->head;
312   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
313   PetscCall(MatGetDiagonal(next->mat, v));
314   if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0]));
315 
316   if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work));
317   i = 1;
318   while ((next = next->next)) {
319     PetscCall(MatGetDiagonal(next->mat, shell->work));
320     PetscCall(VecAXPY(v, (shell->scalings ? shell->scalings[i++] : 1.0), shell->work));
321   }
322   PetscFunctionReturn(PETSC_SUCCESS);
323 }
324 
325 static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t)
326 {
327   Mat_Composite *shell;
328 
329   PetscFunctionBegin;
330   PetscCall(MatShellGetContext(Y, &shell));
331   if (shell->merge) PetscCall(MatCompositeMerge(Y));
332   else PetscCall(MatAssemblyEnd_Shell(Y, t));
333   PetscFunctionReturn(PETSC_SUCCESS);
334 }
335 
336 static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems *PetscOptionsObject)
337 {
338   Mat_Composite *a;
339 
340   PetscFunctionBegin;
341   PetscCall(MatShellGetContext(A, &a));
342   PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options");
343   PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL));
344   PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL));
345   PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL));
346   PetscOptionsHeadEnd();
347   PetscFunctionReturn(PETSC_SUCCESS);
348 }
349 
350 /*@
351   MatCreateComposite - Creates a matrix as the sum or product of one or more matrices
352 
353   Collective
354 
355   Input Parameters:
356 + comm - MPI communicator
357 . nmat - number of matrices to put in
358 - mats - the matrices
359 
360   Output Parameter:
361 . mat - the matrix
362 
363   Options Database Keys:
364 + -mat_composite_merge       - merge in `MatAssemblyEnd()`
365 . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices
366 - -mat_composite_merge_type  - set merge direction
367 
368   Level: advanced
369 
370   Note:
371   Alternative construction
372 .vb
373        MatCreate(comm,&mat);
374        MatSetSizes(mat,m,n,M,N);
375        MatSetType(mat,MATCOMPOSITE);
376        MatCompositeAddMat(mat,mats[0]);
377        ....
378        MatCompositeAddMat(mat,mats[nmat-1]);
379        MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
380        MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
381 .ve
382 
383   For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0]
384 
385 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`,
386           `MATCOMPOSITE`, `MatCompositeType`
387 @*/
388 PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat)
389 {
390   PetscInt m, n, M, N, i;
391 
392   PetscFunctionBegin;
393   PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix");
394   PetscAssertPointer(mat, 4);
395 
396   PetscCall(MatGetLocalSize(mats[0], PETSC_IGNORE, &n));
397   PetscCall(MatGetLocalSize(mats[nmat - 1], &m, PETSC_IGNORE));
398   PetscCall(MatGetSize(mats[0], PETSC_IGNORE, &N));
399   PetscCall(MatGetSize(mats[nmat - 1], &M, PETSC_IGNORE));
400   PetscCall(MatCreate(comm, mat));
401   PetscCall(MatSetSizes(*mat, m, n, M, N));
402   PetscCall(MatSetType(*mat, MATCOMPOSITE));
403   for (i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i]));
404   PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY));
405   PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY));
406   PetscFunctionReturn(PETSC_SUCCESS);
407 }
408 
409 static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat)
410 {
411   Mat_Composite    *shell;
412   Mat_CompositeLink ilink, next;
413   VecType           vtype_mat, vtype_smat;
414   PetscBool         match;
415 
416   PetscFunctionBegin;
417   PetscCall(MatShellGetContext(mat, &shell));
418   next = shell->head;
419   PetscCall(PetscNew(&ilink));
420   ilink->next = NULL;
421   PetscCall(PetscObjectReference((PetscObject)smat));
422   ilink->mat = smat;
423 
424   if (!next) shell->head = ilink;
425   else {
426     while (next->next) next = next->next;
427     next->next  = ilink;
428     ilink->prev = next;
429   }
430   shell->tail = ilink;
431   shell->nmat += 1;
432 
433   /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type.
434      Otherwise, the default type should be "standard". */
435   PetscCall(MatGetVecType(smat, &vtype_smat));
436   if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat));
437   else {
438     PetscCall(MatGetVecType(mat, &vtype_mat));
439     PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match));
440     if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD));
441   }
442 
443   /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */
444   if (shell->scalings) {
445     PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings));
446     shell->scalings[shell->nmat - 1] = 1.0;
447   }
448   PetscFunctionReturn(PETSC_SUCCESS);
449 }
450 
451 /*@
452   MatCompositeAddMat - Add another matrix to a composite matrix.
453 
454   Collective
455 
456   Input Parameters:
457 + mat  - the composite matrix
458 - smat - the partial matrix
459 
460   Level: advanced
461 
462 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
463 @*/
464 PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat)
465 {
466   PetscFunctionBegin;
467   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
468   PetscValidHeaderSpecific(smat, MAT_CLASSID, 2);
469   PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat));
470   PetscFunctionReturn(PETSC_SUCCESS);
471 }
472 
473 static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type)
474 {
475   Mat_Composite *b;
476 
477   PetscFunctionBegin;
478   PetscCall(MatShellGetContext(mat, &b));
479   b->type = type;
480   if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
481     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL));
482     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite_Multiplicative));
483     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite_Multiplicative));
484     b->merge_mvctx = PETSC_FALSE;
485   } else {
486     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
487     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite));
488     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
489   }
490   PetscFunctionReturn(PETSC_SUCCESS);
491 }
492 
493 /*@
494   MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product.
495 
496   Logically Collective
497 
498   Input Parameters:
499 + mat  - the composite matrix
500 - type - the `MatCompositeType` to use for the matrix
501 
502   Level: advanced
503 
504 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`,
505           `MatCompositeType`
506 @*/
507 PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type)
508 {
509   PetscFunctionBegin;
510   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
511   PetscValidLogicalCollectiveEnum(mat, type, 2);
512   PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type));
513   PetscFunctionReturn(PETSC_SUCCESS);
514 }
515 
516 static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type)
517 {
518   Mat_Composite *shell;
519 
520   PetscFunctionBegin;
521   PetscCall(MatShellGetContext(mat, &shell));
522   *type = shell->type;
523   PetscFunctionReturn(PETSC_SUCCESS);
524 }
525 
526 /*@
527   MatCompositeGetType - Returns type of composite.
528 
529   Not Collective
530 
531   Input Parameter:
532 . mat - the composite matrix
533 
534   Output Parameter:
535 . type - type of composite
536 
537   Level: advanced
538 
539 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType`
540 @*/
541 PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type)
542 {
543   PetscFunctionBegin;
544   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
545   PetscAssertPointer(type, 2);
546   PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type));
547   PetscFunctionReturn(PETSC_SUCCESS);
548 }
549 
550 static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str)
551 {
552   Mat_Composite *shell;
553 
554   PetscFunctionBegin;
555   PetscCall(MatShellGetContext(mat, &shell));
556   shell->structure = str;
557   PetscFunctionReturn(PETSC_SUCCESS);
558 }
559 
560 /*@
561   MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix.
562 
563   Not Collective
564 
565   Input Parameters:
566 + mat - the composite matrix
567 - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN`
568 
569   Level: advanced
570 
571   Note:
572   Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix.
573 
574 .seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()` `MatCompositeGetMatStructure()`, `MATCOMPOSITE`
575 @*/
576 PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str)
577 {
578   PetscFunctionBegin;
579   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
580   PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str));
581   PetscFunctionReturn(PETSC_SUCCESS);
582 }
583 
584 static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str)
585 {
586   Mat_Composite *shell;
587 
588   PetscFunctionBegin;
589   PetscCall(MatShellGetContext(mat, &shell));
590   *str = shell->structure;
591   PetscFunctionReturn(PETSC_SUCCESS);
592 }
593 
594 /*@
595   MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix.
596 
597   Not Collective
598 
599   Input Parameter:
600 . mat - the composite matrix
601 
602   Output Parameter:
603 . str - structure of the matrices
604 
605   Level: advanced
606 
607 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE`
608 @*/
609 PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str)
610 {
611   PetscFunctionBegin;
612   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
613   PetscAssertPointer(str, 2);
614   PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str));
615   PetscFunctionReturn(PETSC_SUCCESS);
616 }
617 
618 static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type)
619 {
620   Mat_Composite *shell;
621 
622   PetscFunctionBegin;
623   PetscCall(MatShellGetContext(mat, &shell));
624   shell->mergetype = type;
625   PetscFunctionReturn(PETSC_SUCCESS);
626 }
627 
628 /*@
629   MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`.
630 
631   Logically Collective
632 
633   Input Parameters:
634 + mat  - the composite matrix
635 - type - `MAT_COMPOSITE_MERGE RIGHT` (default) to start merge from right with the first added matrix (mat[0]),
636           `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1])
637 
638   Level: advanced
639 
640   Note:
641   The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed.
642   If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0])))
643   otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0].
644 
645 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE`
646 @*/
647 PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type)
648 {
649   PetscFunctionBegin;
650   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
651   PetscValidLogicalCollectiveEnum(mat, type, 2);
652   PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type));
653   PetscFunctionReturn(PETSC_SUCCESS);
654 }
655 
656 static PetscErrorCode MatCompositeMerge_Composite(Mat mat)
657 {
658   Mat_Composite    *shell;
659   Mat_CompositeLink next, prev;
660   Mat               tmat, newmat;
661   Vec               left, right, dshift;
662   PetscScalar       scale, shift;
663   PetscInt          i;
664 
665   PetscFunctionBegin;
666   PetscCall(MatShellGetContext(mat, &shell));
667   next = shell->head;
668   prev = shell->tail;
669   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
670   PetscCheck(!((Mat_Shell *)mat->data)->zrows && !((Mat_Shell *)mat->data)->zcols, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatZeroRows() or MatZeroRowsColumns() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatZeroRows()/MatZeroRowsColumns() after the merge
671   PetscCheck(!((Mat_Shell *)mat->data)->axpy, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatAXPY() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatAXPY() after the merge
672   scale = ((Mat_Shell *)mat->data)->vscale;
673   shift = ((Mat_Shell *)mat->data)->vshift;
674   if (shell->type == MAT_COMPOSITE_ADDITIVE) {
675     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
676       i = 0;
677       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
678       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++]));
679       while ((next = next->next)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i++] : 1.0), next->mat, shell->structure));
680     } else {
681       i = shell->nmat - 1;
682       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
683       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--]));
684       while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i--] : 1.0), prev->mat, shell->structure));
685     }
686   } else {
687     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
688       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
689       while ((next = next->next)) {
690         PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat));
691         PetscCall(MatDestroy(&tmat));
692         tmat = newmat;
693       }
694     } else {
695       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
696       while ((prev = prev->prev)) {
697         PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat));
698         PetscCall(MatDestroy(&tmat));
699         tmat = newmat;
700       }
701     }
702     if (shell->scalings) {
703       for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
704     }
705   }
706 
707   if ((left = ((Mat_Shell *)mat->data)->left)) PetscCall(PetscObjectReference((PetscObject)left));
708   if ((right = ((Mat_Shell *)mat->data)->right)) PetscCall(PetscObjectReference((PetscObject)right));
709   if ((dshift = ((Mat_Shell *)mat->data)->dshift)) PetscCall(PetscObjectReference((PetscObject)dshift));
710 
711   PetscCall(MatHeaderReplace(mat, &tmat));
712 
713   PetscCall(MatDiagonalScale(mat, left, right));
714   PetscCall(MatScale(mat, scale));
715   PetscCall(MatShift(mat, shift));
716   PetscCall(VecDestroy(&left));
717   PetscCall(VecDestroy(&right));
718   if (dshift) {
719     PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES));
720     PetscCall(VecDestroy(&dshift));
721   }
722   PetscFunctionReturn(PETSC_SUCCESS);
723 }
724 
725 /*@
726   MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
727   by summing or computing the product of all the matrices inside the composite matrix.
728 
729   Collective
730 
731   Input Parameter:
732 . mat - the composite matrix
733 
734   Options Database Keys:
735 + -mat_composite_merge      - merge in `MatAssemblyEnd()`
736 - -mat_composite_merge_type - set merge direction
737 
738   Level: advanced
739 
740   Note:
741   The `MatType` of the resulting matrix will be the same as the `MatType` of the FIRST matrix in the composite matrix.
742 
743 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE`
744 @*/
745 PetscErrorCode MatCompositeMerge(Mat mat)
746 {
747   PetscFunctionBegin;
748   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
749   PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat));
750   PetscFunctionReturn(PETSC_SUCCESS);
751 }
752 
753 static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat)
754 {
755   Mat_Composite *shell;
756 
757   PetscFunctionBegin;
758   PetscCall(MatShellGetContext(mat, &shell));
759   *nmat = shell->nmat;
760   PetscFunctionReturn(PETSC_SUCCESS);
761 }
762 
763 /*@
764   MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix.
765 
766   Not Collective
767 
768   Input Parameter:
769 . mat - the composite matrix
770 
771   Output Parameter:
772 . nmat - number of matrices in the composite matrix
773 
774   Level: advanced
775 
776 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
777 @*/
778 PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat)
779 {
780   PetscFunctionBegin;
781   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
782   PetscAssertPointer(nmat, 2);
783   PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat));
784   PetscFunctionReturn(PETSC_SUCCESS);
785 }
786 
787 static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai)
788 {
789   Mat_Composite    *shell;
790   Mat_CompositeLink ilink;
791   PetscInt          k;
792 
793   PetscFunctionBegin;
794   PetscCall(MatShellGetContext(mat, &shell));
795   PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat);
796   ilink = shell->head;
797   for (k = 0; k < i; k++) ilink = ilink->next;
798   *Ai = ilink->mat;
799   PetscFunctionReturn(PETSC_SUCCESS);
800 }
801 
802 /*@
803   MatCompositeGetMat - Returns the ith matrix from the composite matrix.
804 
805   Logically Collective
806 
807   Input Parameters:
808 + mat - the composite matrix
809 - i   - the number of requested matrix
810 
811   Output Parameter:
812 . Ai - ith matrix in composite
813 
814   Level: advanced
815 
816 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE`
817 @*/
818 PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai)
819 {
820   PetscFunctionBegin;
821   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
822   PetscValidLogicalCollectiveInt(mat, i, 2);
823   PetscAssertPointer(Ai, 3);
824   PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai));
825   PetscFunctionReturn(PETSC_SUCCESS);
826 }
827 
828 static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings)
829 {
830   Mat_Composite *shell;
831   PetscInt       nmat;
832 
833   PetscFunctionBegin;
834   PetscCall(MatShellGetContext(mat, &shell));
835   PetscCall(MatCompositeGetNumberMat(mat, &nmat));
836   if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings));
837   PetscCall(PetscArraycpy(shell->scalings, scalings, nmat));
838   PetscFunctionReturn(PETSC_SUCCESS);
839 }
840 
841 /*@
842   MatCompositeSetScalings - Sets separate scaling factors for component matrices.
843 
844   Logically Collective
845 
846   Input Parameters:
847 + mat      - the composite matrix
848 - scalings - array of scaling factors with scalings[i] being factor of i-th matrix, for i in [0, nmat)
849 
850   Level: advanced
851 
852 .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE`
853 @*/
854 PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings)
855 {
856   PetscFunctionBegin;
857   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
858   PetscAssertPointer(scalings, 2);
859   PetscValidLogicalCollectiveScalar(mat, *scalings, 2);
860   PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings));
861   PetscFunctionReturn(PETSC_SUCCESS);
862 }
863 
864 /*MC
865    MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices.
866     The matrices need to have a correct size and parallel layout for the sum or product to be valid.
867 
868   Level: advanced
869 
870    Note:
871    To use the product of the matrices call `MatCompositeSetType`(mat,`MAT_COMPOSITE_MULTIPLICATIVE`);
872 
873   Developer Notes:
874   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code
875 
876   Users can not call `MatShellSetOperation()` operations on this class, there is some error checking for that incorrect usage
877 
878 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`,
879           `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()`
880 M*/
881 
882 PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
883 {
884   Mat_Composite *b;
885 
886   PetscFunctionBegin;
887   PetscCall(PetscNew(&b));
888 
889   b->type        = MAT_COMPOSITE_ADDITIVE;
890   b->nmat        = 0;
891   b->merge       = PETSC_FALSE;
892   b->mergetype   = MAT_COMPOSITE_MERGE_RIGHT;
893   b->structure   = DIFFERENT_NONZERO_PATTERN;
894   b->merge_mvctx = PETSC_TRUE;
895 
896   PetscCall(MatSetType(A, MATSHELL));
897   PetscCall(MatShellSetContext(A, b));
898   PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (void (*)(void))MatDestroy_Composite));
899   PetscCall(MatShellSetOperation(A, MATOP_MULT, (void (*)(void))MatMult_Composite));
900   PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
901   PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
902   PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (void (*)(void))MatAssemblyEnd_Composite));
903   PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (void (*)(void))MatSetFromOptions_Composite));
904   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite));
905   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite));
906   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite));
907   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite));
908   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite));
909   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite));
910   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite));
911   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite));
912   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite));
913   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite));
914   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable));
915   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
916   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
917   PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE));
918   PetscFunctionReturn(PETSC_SUCCESS);
919 }
920