xref: /petsc/src/mat/impls/composite/mcomposite.c (revision ccfb0f9f40a0131988d7995ed9679700dae2a75a)
1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2 
/* String table for MatCompositeMergeType, passed to PetscOptionsEnum() for -mat_composite_merge_type
   in MatSetFromOptions_Composite(). Layout follows the PETSc enum-string convention: the value names,
   then the enum type name, then the common prefix of the constants, then a NULL terminator.
   NOTE(review): the constants used in this file are MAT_COMPOSITE_MERGE_RIGHT / MAT_COMPOSITE_MERGE_LEFT;
   confirm that the order "left", "right" matches their numeric values in petscmat.h and that
   "MAT_COMPOSITE_" is the intended prefix. */
const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL};
4 
/* Node of the doubly linked list of component matrices of a MATCOMPOSITE, kept in the order the
   matrices were added with MatCompositeAddMat(). */
typedef struct _Mat_CompositeLink *Mat_CompositeLink;
struct _Mat_CompositeLink {
  Mat               mat;        /* component matrix; MatCompositeAddMat_Composite() takes a reference on it */
  Vec               work;       /* lazily created intermediate vector used by the multiplicative MatMult()/MatMultTranspose();
                                   MatDestroy_Composite() tolerates adjacent links sharing the same vector */
  Mat_CompositeLink next, prev; /* neighbouring links; next is NULL at the tail, prev is NULL at the head */
};
11 
/* Private context of a MATCOMPOSITE; it is the MATSHELL context retrieved with MatShellGetContext(). */
typedef struct {
  MatCompositeType      type;       /* additive or multiplicative; selects the MatMult()/MatMultTranspose() kernels and the merge algorithm */
  Mat_CompositeLink     head, tail; /* first and last component matrix (insertion order) */
  Vec                   work;       /* work vector for MatGetDiagonal_Composite(), created on first use */
  PetscInt              nmat;       /* number of component matrices in the list */
  PetscBool             merge;      /* if true, MatAssemblyEnd() replaces the composite by MatCompositeMerge() (-mat_composite_merge) */
  MatCompositeMergeType mergetype;  /* direction in which MatCompositeMerge() combines the components */
  MatStructure          structure;  /* nonzero-pattern relation passed to MatAXPY() when merging an additive composite */

  PetscScalar *scalings;    /* [nmat] per-matrix scaling factors, or NULL when none were set (all factors are then 1) */
  PetscBool    merge_mvctx; /* Whether need to merge mvctx of component matrices (additive MATMPIAIJ components only) */
  Vec         *lvecs;       /* [nmat] Basically, they are Mvctx->lvec of each component matrix; arrays are placed from larray[] at MatMult() time */
  PetscScalar *larray;      /* [len] Data arrays of lvecs[] are stored consecutively in larray */
  PetscInt     len;         /* Length of larray[] (sum over components of their off-diagonal column counts, duplicates included) */
  Vec          gvec;        /* Union of lvecs[] without duplicated entries (sorted merged column list) */
  PetscInt    *location;    /* [len] location[t] is the index in gvec of the t-th entry of the concatenated garray[]s, so larray[t] = gvec[location[t]] */
  VecScatter   Mvctx;       /* merged scatter from x into gvec; NULL until built by MatMult_Composite() */
} Mat_Composite;
30 
31 static PetscErrorCode MatDestroy_Composite(Mat mat)
32 {
33   Mat_Composite    *shell;
34   Mat_CompositeLink next, oldnext;
35   PetscInt          i;
36 
37   PetscFunctionBegin;
38   PetscCall(MatShellGetContext(mat, &shell));
39   next = shell->head;
40   while (next) {
41     PetscCall(MatDestroy(&next->mat));
42     if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work));
43     oldnext = next;
44     next    = next->next;
45     PetscCall(PetscFree(oldnext));
46   }
47   PetscCall(VecDestroy(&shell->work));
48 
49   if (shell->Mvctx) {
50     for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i]));
51     PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs));
52     PetscCall(PetscFree(shell->larray));
53     PetscCall(VecDestroy(&shell->gvec));
54     PetscCall(VecScatterDestroy(&shell->Mvctx));
55   }
56 
57   PetscCall(PetscFree(shell->scalings));
58   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL));
59   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL));
60   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL));
61   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL));
62   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL));
63   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL));
64   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL));
65   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL));
66   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL));
67   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL));
68   PetscCall(PetscFree(shell));
69   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable()
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y)
74 {
75   Mat_Composite    *shell;
76   Mat_CompositeLink next;
77   Vec               out;
78 
79   PetscFunctionBegin;
80   PetscCall(MatShellGetContext(A, &shell));
81   next = shell->head;
82   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
83   while (next->next) {
84     if (!next->work) { /* should reuse previous work if the same size */
85       PetscCall(MatCreateVecs(next->mat, NULL, &next->work));
86     }
87     out = next->work;
88     PetscCall(MatMult(next->mat, x, out));
89     x    = out;
90     next = next->next;
91   }
92   PetscCall(MatMult(next->mat, x, y));
93   if (shell->scalings) {
94     PetscScalar scale = 1.0;
95     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
96     PetscCall(VecScale(y, scale));
97   }
98   PetscFunctionReturn(PETSC_SUCCESS);
99 }
100 
101 static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y)
102 {
103   Mat_Composite    *shell;
104   Mat_CompositeLink tail;
105   Vec               out;
106 
107   PetscFunctionBegin;
108   PetscCall(MatShellGetContext(A, &shell));
109   tail = shell->tail;
110   PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
111   while (tail->prev) {
112     if (!tail->prev->work) { /* should reuse previous work if the same size */
113       PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work));
114     }
115     out = tail->prev->work;
116     PetscCall(MatMultTranspose(tail->mat, x, out));
117     x    = out;
118     tail = tail->prev;
119   }
120   PetscCall(MatMultTranspose(tail->mat, x, y));
121   if (shell->scalings) {
122     PetscScalar scale = 1.0;
123     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
124     PetscCall(VecScale(y, scale));
125   }
126   PetscFunctionReturn(PETSC_SUCCESS);
127 }
128 
/*
  MatMult_Composite - Computes y = sum_k s_k * A_k * x for an ADDITIVE composite, where s_k are the
  per-matrix scalings (s_k = 1 when no scalings were set).

  When merge_mvctx is set and every component is MATMPIAIJ, the first call builds one VecScatter
  (shell->Mvctx) that gathers the union of all components' off-process columns (their sorted garray[])
  into shell->gvec. Subsequent calls scatter x once, copy the gathered values into larray[] (which backs
  lvecs[k], the off-diagonal input of component k) through location[], and apply the local diagonal
  and off-diagonal SeqAIJ blocks directly. Otherwise each component is applied with MatMult()/MatMultAdd().
  NOTE(review): the communication saving presumes each component's own MatMult() would otherwise perform
  its own scatter of x; confirm against MatMult_MPIAIJ().
*/
static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y)
{
  Mat_Composite     *shell;
  Mat_CompositeLink  cur;
  Vec                y2, xin;
  Mat                A, B;
  PetscInt           i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot;
  const PetscScalar *vals;
  const PetscInt    *garray;
  IS                 ix, iy;
  PetscBool          match;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &shell));
  cur = shell->head;
  PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");

  /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time
     we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and
     it is legal to merge Mvctx, because all component matrices have the same size.
   */
  if (shell->merge_mvctx && !shell->Mvctx) {
    /* Currently only implemented for MATMPIAIJ; any other component type permanently disables merging */
    for (cur = shell->head; cur; cur = cur->next) {
      PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match));
      if (!match) {
        shell->merge_mvctx = PETSC_FALSE;
        goto skip_merge_mvctx;
      }
    }

    /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups).
       n is the local column count of the off-diagonal block B, used as the length of its garray[]. */
    tot = 0;
    for (cur = shell->head; cur; cur = cur->next) {
      PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL));
      PetscCall(MatGetLocalSize(B, NULL, &n));
      tot += n;
    }
    PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs));
    shell->len = tot;

    /* Go through matrices second time to sort off-diag columns and remove dups */
    PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to PETSc and free the other */
    PetscCall(PetscMalloc1(tot, &buf));
    nuniq = 0; /* Number of unique nonzero columns */
    for (cur = shell->head; cur; cur = cur->next) {
      PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
      PetscCall(MatGetLocalSize(B, NULL, &n));
      /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[]; equal entries are emitted once */
      i = j = k = 0;
      while (i < n && j < nuniq) {
        if (garray[i] < gindices[j]) buf[k++] = garray[i++];
        else if (garray[i] > gindices[j]) buf[k++] = gindices[j++];
        else {
          buf[k++] = garray[i++];
          j++;
        }
      }
      /* Copy leftover in garray[] or gindices[] */
      if (i < n) {
        PetscCall(PetscArraycpy(buf + k, garray + i, n - i));
        nuniq = k + n - i;
      } else if (j < nuniq) {
        PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j));
        nuniq = k + nuniq - j;
      } else nuniq = k;
      /* Swap gindices and buf to merge garray of the next matrix */
      tmp      = gindices;
      gindices = buf;
      buf      = tmp;
    }
    PetscCall(PetscFree(buf));

    /* Go through matrices third time to build location[]: for each garray[] entry (concatenated over all
       matrices) its index in gindices[]. Also create lvecs[j] without an array; larray is placed at MatMult() time. */
    tot = 0;
    for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */
      PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
      PetscCall(MatGetLocalSize(B, NULL, &n));
      PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j]));
      /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */
      lo = 0;
      for (i = 0; i < n; i++) {
        hi = nuniq;
        while (hi - lo > 1) {
          mid = lo + (hi - lo) / 2;
          if (garray[i] < gindices[mid]) hi = mid;
          else lo = mid;
        }
        shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */
        lo++;                          /* Since garray[i+1] > garray[i], we can safely advance lo */
      }
      tot += n;
    }

    /* Build merged Mvctx: gather x[gindices[t]] into gvec[t]. ix takes ownership of gindices.
       xin has no array (NULL); it only supplies the parallel layout of x for VecScatterCreate(). */
    PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix));
    PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy));
    PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin));
    PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec));
    PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx));
    PetscCall(VecDestroy(&xin));
    PetscCall(ISDestroy(&ix));
    PetscCall(ISDestroy(&iy));
  }

skip_merge_mvctx:
  PetscCall(VecSet(y, 0));
  /* y2 is the MATSHELL's left work vector (row layout, duplicated from y) and holds one term s_k A_k x at a time */
  if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work)));
  y2 = ((Mat_Shell *)mat->data)->left_work;

  if (shell->Mvctx) { /* Have a merged Mvctx */
    /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do
       in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to
       overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach.
     */
    PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));

    /* Expand the deduplicated gvec into each component's own (possibly overlapping) off-diagonal input */
    PetscCall(VecGetArrayRead(shell->gvec, &vals));
    for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]];
    PetscCall(VecRestoreArrayRead(shell->gvec, &vals));

    for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */
      PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL));
      PetscUseTypeMethod(A, mult, x, y2);
      PetscCall(MatGetLocalSize(B, NULL, &n));
      /* lvecs[i] views larray[tot, tot+n) only for the duration of the off-diagonal product */
      PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot]));
      PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2);
      PetscCall(VecResetArray(shell->lvecs[i]));
      PetscCall(VecAXPY(y, shell->scalings ? shell->scalings[i] : 1.0, y2));
      tot += n;
    }
  } else {
    if (shell->scalings) {
      for (cur = shell->head, i = 0; cur; cur = cur->next, i++) {
        PetscCall(MatMult(cur->mat, x, y2));
        PetscCall(VecAXPY(y, shell->scalings[i], y2));
      }
    } else {
      /* No scalings: accumulate directly into y */
      for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
273 
274 static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y)
275 {
276   Mat_Composite    *shell;
277   Mat_CompositeLink next;
278   Vec               y2 = NULL;
279   PetscInt          i;
280 
281   PetscFunctionBegin;
282   PetscCall(MatShellGetContext(A, &shell));
283   next = shell->head;
284   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
285 
286   PetscCall(MatMultTranspose(next->mat, x, y));
287   if (shell->scalings) {
288     PetscCall(VecScale(y, shell->scalings[0]));
289     if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work)));
290     y2 = ((Mat_Shell *)A->data)->right_work;
291   }
292   i = 1;
293   while ((next = next->next)) {
294     if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y));
295     else {
296       PetscCall(MatMultTranspose(next->mat, x, y2));
297       PetscCall(VecAXPY(y, shell->scalings[i++], y2));
298     }
299   }
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v)
304 {
305   Mat_Composite    *shell;
306   Mat_CompositeLink next;
307   PetscInt          i;
308 
309   PetscFunctionBegin;
310   PetscCall(MatShellGetContext(A, &shell));
311   next = shell->head;
312   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
313   PetscCall(MatGetDiagonal(next->mat, v));
314   if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0]));
315 
316   if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work));
317   i = 1;
318   while ((next = next->next)) {
319     PetscCall(MatGetDiagonal(next->mat, shell->work));
320     PetscCall(VecAXPY(v, shell->scalings ? shell->scalings[i++] : 1.0, shell->work));
321   }
322   PetscFunctionReturn(PETSC_SUCCESS);
323 }
324 
325 static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t)
326 {
327   Mat_Composite *shell;
328 
329   PetscFunctionBegin;
330   PetscCall(MatShellGetContext(Y, &shell));
331   if (shell->merge) PetscCall(MatCompositeMerge(Y));
332   else PetscCall(MatAssemblyEnd_Shell(Y, t));
333   PetscFunctionReturn(PETSC_SUCCESS);
334 }
335 
336 static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems PetscOptionsObject)
337 {
338   Mat_Composite *a;
339 
340   PetscFunctionBegin;
341   PetscCall(MatShellGetContext(A, &a));
342   PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options");
343   PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL));
344   PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL));
345   PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL));
346   PetscOptionsHeadEnd();
347   PetscFunctionReturn(PETSC_SUCCESS);
348 }
349 
350 /*@
351   MatCreateComposite - Creates a matrix as the sum or product of one or more matrices
352 
353   Collective
354 
355   Input Parameters:
356 + comm - MPI communicator
357 . nmat - number of matrices to put in
358 - mats - the matrices
359 
360   Output Parameter:
361 . mat - the matrix
362 
363   Options Database Keys:
364 + -mat_composite_merge       - merge in `MatAssemblyEnd()`
365 . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices
366 - -mat_composite_merge_type  - set merge direction
367 
368   Level: advanced
369 
370   Note:
371   Alternative construction
372 .vb
373        MatCreate(comm,&mat);
374        MatSetSizes(mat,m,n,M,N);
375        MatSetType(mat,MATCOMPOSITE);
376        MatCompositeAddMat(mat,mats[0]);
377        ....
378        MatCompositeAddMat(mat,mats[nmat-1]);
379        MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
380        MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
381 .ve
382 
383   For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0]
384 
385 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`,
386           `MATCOMPOSITE`, `MatCompositeType`
387 @*/
388 PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat)
389 {
390   PetscFunctionBegin;
391   PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix");
392   PetscAssertPointer(mat, 4);
393   PetscCall(MatCreate(comm, mat));
394   PetscCall(MatSetType(*mat, MATCOMPOSITE));
395   for (PetscInt i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i]));
396   PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY));
397   PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY));
398   PetscFunctionReturn(PETSC_SUCCESS);
399 }
400 
401 static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat)
402 {
403   Mat_Composite    *shell;
404   Mat_CompositeLink ilink, next;
405   VecType           vtype_mat, vtype_smat;
406   PetscBool         match;
407 
408   PetscFunctionBegin;
409   PetscCall(MatShellGetContext(mat, &shell));
410   next = shell->head;
411   PetscCall(PetscNew(&ilink));
412   ilink->next = NULL;
413   PetscCall(PetscObjectReference((PetscObject)smat));
414   ilink->mat = smat;
415 
416   if (!next) shell->head = ilink;
417   else {
418     while (next->next) next = next->next;
419     next->next  = ilink;
420     ilink->prev = next;
421   }
422   shell->tail = ilink;
423   shell->nmat += 1;
424 
425   /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type.
426      Otherwise, the default type should be "standard". */
427   PetscCall(MatGetVecType(smat, &vtype_smat));
428   if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat));
429   else {
430     PetscCall(MatGetVecType(mat, &vtype_mat));
431     PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match));
432     if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD));
433   }
434 
435   /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */
436   if (shell->scalings) {
437     PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings));
438     shell->scalings[shell->nmat - 1] = 1.0;
439   }
440 
441   /* The composite matrix requires PetscLayouts for its rows and columns; we copy these from the constituent partial matrices. */
442   if (shell->nmat == 1) PetscCall(PetscLayoutReference(smat->cmap, &mat->cmap));
443   PetscCall(PetscLayoutReference(smat->rmap, &mat->rmap));
444   PetscFunctionReturn(PETSC_SUCCESS);
445 }
446 
/*@
  MatCompositeAddMat - Add another matrix to a composite matrix.

  Collective

  Input Parameters:
+ mat  - the composite matrix
- smat - the partial matrix

  Level: advanced

  Notes:
  `smat` is appended after the matrices already in `mat`, and a reference is taken on it.
  The row layout of `mat` is taken from the most recently added matrix and the column layout from the first one.
  If the added matrices do not all have the same default vector type, the composite falls back to `VECSTANDARD`.

.seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscValidHeaderSpecific(smat, MAT_CLASSID, 2);
  PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
468 
469 static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type)
470 {
471   Mat_Composite *b;
472 
473   PetscFunctionBegin;
474   PetscCall(MatShellGetContext(mat, &b));
475   b->type = type;
476   if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
477     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL));
478     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite_Multiplicative));
479     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite_Multiplicative));
480     b->merge_mvctx = PETSC_FALSE;
481   } else {
482     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_Composite));
483     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite));
484     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite));
485   }
486   PetscFunctionReturn(PETSC_SUCCESS);
487 }
488 
/*@
  MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product.

  Logically Collective

  Input Parameters:
+ mat  - the composite matrix
- type - the `MatCompositeType` to use for the matrix

  Level: advanced

  Note:
  With `MAT_COMPOSITE_MULTIPLICATIVE`, `MatGetDiagonal()` is not provided and merging of the component
  scatter contexts (-mat_composite_merge_mvctx) is turned off.

.seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`,
          `MatCompositeType`
@*/
PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscValidLogicalCollectiveEnum(mat, type, 2);
  PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
511 
512 static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type)
513 {
514   Mat_Composite *shell;
515 
516   PetscFunctionBegin;
517   PetscCall(MatShellGetContext(mat, &shell));
518   *type = shell->type;
519   PetscFunctionReturn(PETSC_SUCCESS);
520 }
521 
/*@
  MatCompositeGetType - Returns type of composite.

  Not Collective

  Input Parameter:
. mat - the composite matrix

  Output Parameter:
. type - type of composite, `MAT_COMPOSITE_ADDITIVE` (sum) or `MAT_COMPOSITE_MULTIPLICATIVE` (product)

  Level: advanced

.seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType`
@*/
PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscAssertPointer(type, 2);
  PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
545 
546 static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str)
547 {
548   Mat_Composite *shell;
549 
550   PetscFunctionBegin;
551   PetscCall(MatShellGetContext(mat, &shell));
552   shell->structure = str;
553   PetscFunctionReturn(PETSC_SUCCESS);
554 }
555 
/*@
  MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix.

  Not Collective

  Input Parameters:
+ mat - the composite matrix
- str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN`

  Level: advanced

  Note:
  Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix,
  where it is passed to `MatAXPY()` when the components are summed.

.seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()`, `MatCompositeGetMatStructure()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str));
  PetscFunctionReturn(PETSC_SUCCESS);
}
579 
580 static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str)
581 {
582   Mat_Composite *shell;
583 
584   PetscFunctionBegin;
585   PetscCall(MatShellGetContext(mat, &shell));
586   *str = shell->structure;
587   PetscFunctionReturn(PETSC_SUCCESS);
588 }
589 
/*@
  MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix.

  Not Collective

  Input Parameter:
. mat - the composite matrix

  Output Parameter:
. str - structure of the matrices, as set with `MatCompositeSetMatStructure()`

  Level: advanced

.seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscAssertPointer(str, 2);
  PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str));
  PetscFunctionReturn(PETSC_SUCCESS);
}
613 
614 static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type)
615 {
616   Mat_Composite *shell;
617 
618   PetscFunctionBegin;
619   PetscCall(MatShellGetContext(mat, &shell));
620   shell->mergetype = type;
621   PetscFunctionReturn(PETSC_SUCCESS);
622 }
623 
/*@
  MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`.

  Logically Collective

  Input Parameters:
+ mat  - the composite matrix
- type - `MAT_COMPOSITE_MERGE_RIGHT` (default) to start merge from right with the first added matrix (mat[0]),
          `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1])

  Level: advanced

  Note:
  The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed.
  If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0])))
  otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0].
  The merge type also decides which matrix is duplicated to start the merge (mat[0] for right, mat[nmat-1] for left).

.seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscValidLogicalCollectiveEnum(mat, type, 2);
  PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
651 
/*
  MatCompositeMerge_Composite - Replaces the MATCOMPOSITE mat, in place, by an explicitly assembled matrix.

  ADDITIVE: tmat = sum_k s_k A_k, built by MatDuplicate() of one end of the list followed by MatAXPY()
            with shell->structure; the per-matrix scalings are applied term by term.
  MULTIPLICATIVE: tmat = A_{n-1} * ... * A_0 via repeated MatMatMult(); the product of the scalings is
            folded into the MATSHELL scale instead.
  The MATSHELL scaling/shift state read before the replace (scale, shift, left/right diagonal scaling,
  diagonal shift) is then re-applied to the merged matrix.
*/
static PetscErrorCode MatCompositeMerge_Composite(Mat mat)
{
  Mat_Composite    *shell;
  Mat_CompositeLink next, prev;
  Mat               tmat, newmat;
  Vec               left, right, dshift;
  PetscScalar       scale, shift;
  PetscInt          i;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &shell));
  next = shell->head;
  prev = shell->tail;
  PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
  /* NOTE(review): axpy/zero-rows/zero-columns state is passed as MAT_SHELL_NOT_ALLOWED; presumably this makes the
     call fail if such state is present on the shell — confirm in MatShellGetScalingShifts() */
  PetscCall(MatShellGetScalingShifts(mat, &shift, &scale, &dshift, &left, &right, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
  if (shell->type == MAT_COMPOSITE_ADDITIVE) {
    if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
      /* Start from the first added matrix and walk towards the tail */
      i = 0;
      PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
      if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++]));
      while ((next = next->next)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i++] : 1.0, next->mat, shell->structure));
    } else {
      /* Start from the last added matrix and walk towards the head */
      i = shell->nmat - 1;
      PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
      if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--]));
      while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i--] : 1.0, prev->mat, shell->structure));
    }
  } else {
    if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
      /* tmat <- A_k * tmat for k = 1..n-1: A_{n-1}*(...*(A_1*A_0)) */
      PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
      while ((next = next->next)) {
        PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
        PetscCall(MatDestroy(&tmat));
        tmat = newmat;
      }
    } else {
      /* tmat <- tmat * A_k for k = n-2..0: ((A_{n-1}*A_{n-2})*...)*A_0 */
      PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
      while ((prev = prev->prev)) {
        PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
        PetscCall(MatDestroy(&tmat));
        tmat = newmat;
      }
    }
    /* For a product, all per-matrix scalings collapse into a single factor */
    if (shell->scalings) {
      for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
    }
  }

  /* Keep left/right/dshift alive across MatHeaderReplace(); presumably replacing the header releases the shell
     context that owns them. The extra references are dropped again below. */
  if (left) PetscCall(PetscObjectReference((PetscObject)left));
  if (right) PetscCall(PetscObjectReference((PetscObject)right));
  if (dshift) PetscCall(PetscObjectReference((PetscObject)dshift));

  PetscCall(MatHeaderReplace(mat, &tmat));

  /* Re-apply the shell's scaling and shifts to the assembled matrix */
  PetscCall(MatDiagonalScale(mat, left, right));
  PetscCall(MatScale(mat, scale));
  PetscCall(MatShift(mat, shift));
  PetscCall(VecDestroy(&left));
  PetscCall(VecDestroy(&right));
  if (dshift) {
    PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES));
    PetscCall(VecDestroy(&dshift));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
717 
/*@
  MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
  by summing or computing the product of all the matrices inside the composite matrix.

  Collective

  Input Parameter:
. mat - the composite matrix

  Options Database Keys:
+ -mat_composite_merge      - merge in `MatAssemblyEnd()`
- -mat_composite_merge_type - set merge direction

  Level: advanced

  Note:
  The merge starts from a `MatDuplicate()` of the FIRST added matrix for `MAT_COMPOSITE_MERGE_RIGHT` (the default)
  and of the LAST added matrix for `MAT_COMPOSITE_MERGE_LEFT`. For an additive composite the `MatType` of the resulting
  matrix is therefore that of the matrix the merge starts from; for a multiplicative composite it is whatever
  `MatMatMult()` produces.
  Any scaling or shifting previously applied to the composite (e.g. with `MatScale()`, `MatShift()`, `MatDiagonalScale()`)
  is re-applied to the merged matrix.

.seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeMerge(Mat mat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
745 
746 static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat)
747 {
748   Mat_Composite *shell;
749 
750   PetscFunctionBegin;
751   PetscCall(MatShellGetContext(mat, &shell));
752   *nmat = shell->nmat;
753   PetscFunctionReturn(PETSC_SUCCESS);
754 }
755 
/*@
  MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix.

  Not Collective

  Input Parameter:
. mat - the composite matrix

  Output Parameter:
. nmat - number of matrices in the composite matrix, i.e. the number added with `MatCompositeAddMat()`

  Level: advanced

.seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
@*/
PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscAssertPointer(nmat, 2);
  PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
779 
780 static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai)
781 {
782   Mat_Composite    *shell;
783   Mat_CompositeLink ilink;
784   PetscInt          k;
785 
786   PetscFunctionBegin;
787   PetscCall(MatShellGetContext(mat, &shell));
788   PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat);
789   ilink = shell->head;
790   for (k = 0; k < i; k++) ilink = ilink->next;
791   *Ai = ilink->mat;
792   PetscFunctionReturn(PETSC_SUCCESS);
793 }
794 
795 /*@
796   MatCompositeGetMat - Returns the ith matrix from the composite matrix.
797 
798   Logically Collective
799 
800   Input Parameters:
801 + mat - the composite matrix
- i   - the 0-based index of the requested matrix, in [0, nmat) where nmat is given by `MatCompositeGetNumberMat()`
803 
804   Output Parameter:
. Ai - the i-th matrix in the composite (no new reference is taken on it)
806 
807   Level: advanced
808 
809 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE`
810 @*/
811 PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai)
812 {
813   PetscFunctionBegin;
814   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
815   PetscValidLogicalCollectiveInt(mat, i, 2);
816   PetscAssertPointer(Ai, 3);
817   PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai));
818   PetscFunctionReturn(PETSC_SUCCESS);
819 }
820 
821 static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings)
822 {
823   Mat_Composite *shell;
824   PetscInt       nmat;
825 
826   PetscFunctionBegin;
827   PetscCall(MatShellGetContext(mat, &shell));
828   PetscCall(MatCompositeGetNumberMat(mat, &nmat));
829   if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings));
830   PetscCall(PetscArraycpy(shell->scalings, scalings, nmat));
831   PetscFunctionReturn(PETSC_SUCCESS);
832 }
833 
834 /*@
835   MatCompositeSetScalings - Sets separate scaling factors for component matrices.
836 
837   Logically Collective
838 
839   Input Parameters:
840 + mat      - the composite matrix
- scalings - array of length nmat (see `MatCompositeGetNumberMat()`), where scalings[i] is the scaling factor of the i-th matrix, for i in [0, nmat)
842 
843   Level: advanced
844 
845 .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE`
846 @*/
847 PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings)
848 {
849   PetscFunctionBegin;
850   PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
851   PetscAssertPointer(scalings, 2);
852   PetscValidLogicalCollectiveScalar(mat, *scalings, 2);
853   PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings));
854   PetscFunctionReturn(PETSC_SUCCESS);
855 }
856 
857 /*MC
858    MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices.
859     The matrices need to have a correct size and parallel layout for the sum or product to be valid.
860 
861   Level: advanced
862 
863    Note:
   To use the product of the matrices, call `MatCompositeSetType(mat, MAT_COMPOSITE_MULTIPLICATIVE)`.
865 
866   Developer Notes:
867   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code
868 
  Users cannot call `MatShellSetOperation()` on this class; there is some error checking to catch that incorrect usage.
870 
871 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`,
872           `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()`
873 M*/
874 
875 PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
876 {
877   Mat_Composite *b;
878 
879   PetscFunctionBegin;
880   PetscCall(PetscNew(&b));
881 
882   b->type        = MAT_COMPOSITE_ADDITIVE;
883   b->nmat        = 0;
884   b->merge       = PETSC_FALSE;
885   b->mergetype   = MAT_COMPOSITE_MERGE_RIGHT;
886   b->structure   = DIFFERENT_NONZERO_PATTERN;
887   b->merge_mvctx = PETSC_TRUE;
888 
889   PetscCall(MatSetType(A, MATSHELL));
890   PetscCall(MatShellSetContext(A, b));
891   PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (PetscErrorCodeFn *)MatDestroy_Composite));
892   PetscCall(MatShellSetOperation(A, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite));
893   PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite));
894   PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_Composite));
895   PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (PetscErrorCodeFn *)MatAssemblyEnd_Composite));
896   PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (PetscErrorCodeFn *)MatSetFromOptions_Composite));
897   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite));
898   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite));
899   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite));
900   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite));
901   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite));
902   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite));
903   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite));
904   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite));
905   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite));
906   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite));
907   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable));
908   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
909   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
910   PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE));
911   PetscFunctionReturn(PETSC_SUCCESS);
912 }
913