1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2
3 const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL};
4
5 typedef struct _Mat_CompositeLink *Mat_CompositeLink;
6 struct _Mat_CompositeLink {
7 Mat mat;
8 Vec work;
9 Mat_CompositeLink next, prev;
10 };
11
12 typedef struct {
13 MatCompositeType type;
14 Mat_CompositeLink head, tail;
15 Vec work;
16 PetscInt nmat;
17 PetscBool merge;
18 MatCompositeMergeType mergetype;
19 MatStructure structure;
20
21 PetscScalar *scalings;
22 PetscBool merge_mvctx; /* Whether need to merge mvctx of component matrices */
23 Vec *lvecs; /* [nmat] Basically, they are Mvctx->lvec of each component matrix */
24 PetscScalar *larray; /* [len] Data arrays of lvecs[] are stored consecutively in larray */
25 PetscInt len; /* Length of larray[] */
26 Vec gvec; /* Union of lvecs[] without duplicated entries */
27 PetscInt *location; /* A map that maps entries in garray[] to larray[] */
28 VecScatter Mvctx;
29 } Mat_Composite;
30
MatDestroy_Composite(Mat mat)31 static PetscErrorCode MatDestroy_Composite(Mat mat)
32 {
33 Mat_Composite *shell;
34 Mat_CompositeLink next, oldnext;
35 PetscInt i;
36
37 PetscFunctionBegin;
38 PetscCall(MatShellGetContext(mat, &shell));
39 next = shell->head;
40 while (next) {
41 PetscCall(MatDestroy(&next->mat));
42 if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work));
43 oldnext = next;
44 next = next->next;
45 PetscCall(PetscFree(oldnext));
46 }
47 PetscCall(VecDestroy(&shell->work));
48
49 if (shell->Mvctx) {
50 for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i]));
51 PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs));
52 PetscCall(PetscFree(shell->larray));
53 PetscCall(VecDestroy(&shell->gvec));
54 PetscCall(VecScatterDestroy(&shell->Mvctx));
55 }
56
57 PetscCall(PetscFree(shell->scalings));
58 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL));
59 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL));
60 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL));
61 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL));
62 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL));
63 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL));
64 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL));
65 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL));
66 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL));
67 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL));
68 PetscCall(PetscFree(shell));
69 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable()
70 PetscFunctionReturn(PETSC_SUCCESS);
71 }
72
MatMult_Composite_Multiplicative(Mat A,Vec x,Vec y)73 static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y)
74 {
75 Mat_Composite *shell;
76 Mat_CompositeLink next;
77 Vec out;
78
79 PetscFunctionBegin;
80 PetscCall(MatShellGetContext(A, &shell));
81 next = shell->head;
82 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
83 while (next->next) {
84 if (!next->work) { /* should reuse previous work if the same size */
85 PetscCall(MatCreateVecs(next->mat, NULL, &next->work));
86 }
87 out = next->work;
88 PetscCall(MatMult(next->mat, x, out));
89 x = out;
90 next = next->next;
91 }
92 PetscCall(MatMult(next->mat, x, y));
93 if (shell->scalings) {
94 PetscScalar scale = 1.0;
95 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
96 PetscCall(VecScale(y, scale));
97 }
98 PetscFunctionReturn(PETSC_SUCCESS);
99 }
100
MatMultTranspose_Composite_Multiplicative(Mat A,Vec x,Vec y)101 static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y)
102 {
103 Mat_Composite *shell;
104 Mat_CompositeLink tail;
105 Vec out;
106
107 PetscFunctionBegin;
108 PetscCall(MatShellGetContext(A, &shell));
109 tail = shell->tail;
110 PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
111 while (tail->prev) {
112 if (!tail->prev->work) { /* should reuse previous work if the same size */
113 PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work));
114 }
115 out = tail->prev->work;
116 PetscCall(MatMultTranspose(tail->mat, x, out));
117 x = out;
118 tail = tail->prev;
119 }
120 PetscCall(MatMultTranspose(tail->mat, x, y));
121 if (shell->scalings) {
122 PetscScalar scale = 1.0;
123 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
124 PetscCall(VecScale(y, scale));
125 }
126 PetscFunctionReturn(PETSC_SUCCESS);
127 }
128
MatMult_Composite(Mat mat,Vec x,Vec y)129 static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y)
130 {
131 Mat_Composite *shell;
132 Mat_CompositeLink cur;
133 Vec y2, xin;
134 Mat A, B;
135 PetscInt i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot;
136 const PetscScalar *vals;
137 const PetscInt *garray;
138 IS ix, iy;
139 PetscBool match;
140
141 PetscFunctionBegin;
142 PetscCall(MatShellGetContext(mat, &shell));
143 cur = shell->head;
144 PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
145
146 /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time
147 we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and
148 it is legal to merge Mvctx, because all component matrices have the same size.
149 */
150 if (shell->merge_mvctx && !shell->Mvctx) {
151 /* Currently only implemented for MATMPIAIJ */
152 for (cur = shell->head; cur; cur = cur->next) {
153 PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match));
154 if (!match) {
155 shell->merge_mvctx = PETSC_FALSE;
156 goto skip_merge_mvctx;
157 }
158 }
159
160 /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */
161 tot = 0;
162 for (cur = shell->head; cur; cur = cur->next) {
163 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL));
164 PetscCall(MatGetLocalSize(B, NULL, &n));
165 tot += n;
166 }
167 PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs));
168 shell->len = tot;
169
170 /* Go through matrices second time to sort off-diag columns and remove dups */
171 PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to PETSc and free the other */
172 PetscCall(PetscMalloc1(tot, &buf));
173 nuniq = 0; /* Number of unique nonzero columns */
174 for (cur = shell->head; cur; cur = cur->next) {
175 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
176 PetscCall(MatGetLocalSize(B, NULL, &n));
177 /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */
178 i = j = k = 0;
179 while (i < n && j < nuniq) {
180 if (garray[i] < gindices[j]) buf[k++] = garray[i++];
181 else if (garray[i] > gindices[j]) buf[k++] = gindices[j++];
182 else {
183 buf[k++] = garray[i++];
184 j++;
185 }
186 }
187 /* Copy leftover in garray[] or gindices[] */
188 if (i < n) {
189 PetscCall(PetscArraycpy(buf + k, garray + i, n - i));
190 nuniq = k + n - i;
191 } else if (j < nuniq) {
192 PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j));
193 nuniq = k + nuniq - j;
194 } else nuniq = k;
195 /* Swap gindices and buf to merge garray of the next matrix */
196 tmp = gindices;
197 gindices = buf;
198 buf = tmp;
199 }
200 PetscCall(PetscFree(buf));
201
202 /* Go through matrices third time to build a map from gindices[] to garray[] */
203 tot = 0;
204 for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */
205 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
206 PetscCall(MatGetLocalSize(B, NULL, &n));
207 PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j]));
208 /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */
209 lo = 0;
210 for (i = 0; i < n; i++) {
211 hi = nuniq;
212 while (hi - lo > 1) {
213 mid = lo + (hi - lo) / 2;
214 if (garray[i] < gindices[mid]) hi = mid;
215 else lo = mid;
216 }
217 shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */
218 lo++; /* Since garray[i+1] > garray[i], we can safely advance lo */
219 }
220 tot += n;
221 }
222
223 /* Build merged Mvctx */
224 PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix));
225 PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy));
226 PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin));
227 PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec));
228 PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx));
229 PetscCall(VecDestroy(&xin));
230 PetscCall(ISDestroy(&ix));
231 PetscCall(ISDestroy(&iy));
232 }
233
234 skip_merge_mvctx:
235 PetscCall(VecSet(y, 0));
236 if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work)));
237 y2 = ((Mat_Shell *)mat->data)->left_work;
238
239 if (shell->Mvctx) { /* Have a merged Mvctx */
240 /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do
241 in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to
242 overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach.
243 */
244 PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
245 PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
246
247 PetscCall(VecGetArrayRead(shell->gvec, &vals));
248 for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]];
249 PetscCall(VecRestoreArrayRead(shell->gvec, &vals));
250
251 for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */
252 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL));
253 PetscUseTypeMethod(A, mult, x, y2);
254 PetscCall(MatGetLocalSize(B, NULL, &n));
255 PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot]));
256 PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2);
257 PetscCall(VecResetArray(shell->lvecs[i]));
258 PetscCall(VecAXPY(y, shell->scalings ? shell->scalings[i] : 1.0, y2));
259 tot += n;
260 }
261 } else {
262 if (shell->scalings) {
263 for (cur = shell->head, i = 0; cur; cur = cur->next, i++) {
264 PetscCall(MatMult(cur->mat, x, y2));
265 PetscCall(VecAXPY(y, shell->scalings[i], y2));
266 }
267 } else {
268 for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y));
269 }
270 }
271 PetscFunctionReturn(PETSC_SUCCESS);
272 }
273
MatMultTranspose_Composite(Mat A,Vec x,Vec y)274 static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y)
275 {
276 Mat_Composite *shell;
277 Mat_CompositeLink next;
278 Vec y2 = NULL;
279 PetscInt i;
280
281 PetscFunctionBegin;
282 PetscCall(MatShellGetContext(A, &shell));
283 next = shell->head;
284 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
285
286 PetscCall(MatMultTranspose(next->mat, x, y));
287 if (shell->scalings) {
288 PetscCall(VecScale(y, shell->scalings[0]));
289 if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work)));
290 y2 = ((Mat_Shell *)A->data)->right_work;
291 }
292 i = 1;
293 while ((next = next->next)) {
294 if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y));
295 else {
296 PetscCall(MatMultTranspose(next->mat, x, y2));
297 PetscCall(VecAXPY(y, shell->scalings[i++], y2));
298 }
299 }
300 PetscFunctionReturn(PETSC_SUCCESS);
301 }
302
MatGetDiagonal_Composite(Mat A,Vec v)303 static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v)
304 {
305 Mat_Composite *shell;
306 Mat_CompositeLink next;
307 PetscInt i;
308
309 PetscFunctionBegin;
310 PetscCall(MatShellGetContext(A, &shell));
311 next = shell->head;
312 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
313 PetscCall(MatGetDiagonal(next->mat, v));
314 if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0]));
315
316 if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work));
317 i = 1;
318 while ((next = next->next)) {
319 PetscCall(MatGetDiagonal(next->mat, shell->work));
320 PetscCall(VecAXPY(v, shell->scalings ? shell->scalings[i++] : 1.0, shell->work));
321 }
322 PetscFunctionReturn(PETSC_SUCCESS);
323 }
324
MatAssemblyEnd_Composite(Mat Y,MatAssemblyType t)325 static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t)
326 {
327 Mat_Composite *shell;
328
329 PetscFunctionBegin;
330 PetscCall(MatShellGetContext(Y, &shell));
331 if (shell->merge) PetscCall(MatCompositeMerge(Y));
332 else PetscCall(MatAssemblyEnd_Shell(Y, t));
333 PetscFunctionReturn(PETSC_SUCCESS);
334 }
335
MatSetFromOptions_Composite(Mat A,PetscOptionItems PetscOptionsObject)336 static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems PetscOptionsObject)
337 {
338 Mat_Composite *a;
339
340 PetscFunctionBegin;
341 PetscCall(MatShellGetContext(A, &a));
342 PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options");
343 PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL));
344 PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL));
345 PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL));
346 PetscOptionsHeadEnd();
347 PetscFunctionReturn(PETSC_SUCCESS);
348 }
349
350 /*@
351 MatCreateComposite - Creates a matrix as the sum or product of one or more matrices
352
353 Collective
354
355 Input Parameters:
356 + comm - MPI communicator
357 . nmat - number of matrices to put in
358 - mats - the matrices
359
360 Output Parameter:
361 . mat - the matrix
362
363 Options Database Keys:
364 + -mat_composite_merge - merge in `MatAssemblyEnd()`
365 . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices
366 - -mat_composite_merge_type - set merge direction
367
368 Level: advanced
369
370 Note:
371 Alternative construction
372 .vb
373 MatCreate(comm,&mat);
374 MatSetSizes(mat,m,n,M,N);
375 MatSetType(mat,MATCOMPOSITE);
376 MatCompositeAddMat(mat,mats[0]);
377 ....
378 MatCompositeAddMat(mat,mats[nmat-1]);
379 MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
380 MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
381 .ve
382
  For the multiplicative form the product is mats[nmat-1]*mats[nmat-2]*...*mats[0]
384
385 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`,
386 `MATCOMPOSITE`, `MatCompositeType`
387 @*/
MatCreateComposite(MPI_Comm comm,PetscInt nmat,const Mat * mats,Mat * mat)388 PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat)
389 {
390 PetscFunctionBegin;
391 PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix");
392 PetscAssertPointer(mat, 4);
393 PetscCall(MatCreate(comm, mat));
394 PetscCall(MatSetType(*mat, MATCOMPOSITE));
395 for (PetscInt i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i]));
396 PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY));
397 PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY));
398 PetscFunctionReturn(PETSC_SUCCESS);
399 }
400
MatCompositeAddMat_Composite(Mat mat,Mat smat)401 static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat)
402 {
403 Mat_Composite *shell;
404 Mat_CompositeLink ilink, next;
405 VecType vtype_mat, vtype_smat;
406 PetscBool match;
407
408 PetscFunctionBegin;
409 PetscCall(MatShellGetContext(mat, &shell));
410 next = shell->head;
411 PetscCall(PetscNew(&ilink));
412 ilink->next = NULL;
413 PetscCall(PetscObjectReference((PetscObject)smat));
414 ilink->mat = smat;
415
416 if (!next) shell->head = ilink;
417 else {
418 while (next->next) next = next->next;
419 next->next = ilink;
420 ilink->prev = next;
421 }
422 shell->tail = ilink;
423 shell->nmat += 1;
424
425 /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type.
426 Otherwise, the default type should be "standard". */
427 PetscCall(MatGetVecType(smat, &vtype_smat));
428 if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat));
429 else {
430 PetscCall(MatGetVecType(mat, &vtype_mat));
431 PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match));
432 if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD));
433 }
434
435 /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */
436 if (shell->scalings) {
437 PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings));
438 shell->scalings[shell->nmat - 1] = 1.0;
439 }
440
441 /* The composite matrix requires PetscLayouts for its rows and columns; we copy these from the constituent partial matrices. */
442 if (shell->nmat == 1) PetscCall(PetscLayoutReference(smat->cmap, &mat->cmap));
443 PetscCall(PetscLayoutReference(smat->rmap, &mat->rmap));
444 PetscFunctionReturn(PETSC_SUCCESS);
445 }
446
447 /*@
448 MatCompositeAddMat - Add another matrix to a composite matrix.
449
450 Collective
451
452 Input Parameters:
453 + mat - the composite matrix
454 - smat - the partial matrix
455
456 Level: advanced
457
458 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
459 @*/
MatCompositeAddMat(Mat mat,Mat smat)460 PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat)
461 {
462 PetscFunctionBegin;
463 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
464 PetscValidHeaderSpecific(smat, MAT_CLASSID, 2);
465 PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat));
466 PetscFunctionReturn(PETSC_SUCCESS);
467 }
468
MatCompositeSetType_Composite(Mat mat,MatCompositeType type)469 static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type)
470 {
471 Mat_Composite *b;
472
473 PetscFunctionBegin;
474 PetscCall(MatShellGetContext(mat, &b));
475 b->type = type;
476 if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
477 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL));
478 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite_Multiplicative));
479 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite_Multiplicative));
480 b->merge_mvctx = PETSC_FALSE;
481 } else {
482 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_Composite));
483 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite));
484 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite));
485 }
486 PetscFunctionReturn(PETSC_SUCCESS);
487 }
488
489 /*@
490 MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product.
491
492 Logically Collective
493
494 Input Parameters:
495 + mat - the composite matrix
496 - type - the `MatCompositeType` to use for the matrix
497
498 Level: advanced
499
500 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`,
501 `MatCompositeType`
502 @*/
MatCompositeSetType(Mat mat,MatCompositeType type)503 PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type)
504 {
505 PetscFunctionBegin;
506 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
507 PetscValidLogicalCollectiveEnum(mat, type, 2);
508 PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type));
509 PetscFunctionReturn(PETSC_SUCCESS);
510 }
511
MatCompositeGetType_Composite(Mat mat,MatCompositeType * type)512 static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type)
513 {
514 Mat_Composite *shell;
515
516 PetscFunctionBegin;
517 PetscCall(MatShellGetContext(mat, &shell));
518 *type = shell->type;
519 PetscFunctionReturn(PETSC_SUCCESS);
520 }
521
522 /*@
523 MatCompositeGetType - Returns type of composite.
524
525 Not Collective
526
527 Input Parameter:
528 . mat - the composite matrix
529
530 Output Parameter:
531 . type - type of composite
532
533 Level: advanced
534
535 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType`
536 @*/
MatCompositeGetType(Mat mat,MatCompositeType * type)537 PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type)
538 {
539 PetscFunctionBegin;
540 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
541 PetscAssertPointer(type, 2);
542 PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type));
543 PetscFunctionReturn(PETSC_SUCCESS);
544 }
545
MatCompositeSetMatStructure_Composite(Mat mat,MatStructure str)546 static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str)
547 {
548 Mat_Composite *shell;
549
550 PetscFunctionBegin;
551 PetscCall(MatShellGetContext(mat, &shell));
552 shell->structure = str;
553 PetscFunctionReturn(PETSC_SUCCESS);
554 }
555
556 /*@
557 MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix.
558
559 Not Collective
560
561 Input Parameters:
562 + mat - the composite matrix
563 - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN`
564
565 Level: advanced
566
567 Note:
568 Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix.
569
.seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()`, `MatCompositeGetMatStructure()`, `MATCOMPOSITE`
571 @*/
MatCompositeSetMatStructure(Mat mat,MatStructure str)572 PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str)
573 {
574 PetscFunctionBegin;
575 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
576 PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str));
577 PetscFunctionReturn(PETSC_SUCCESS);
578 }
579
MatCompositeGetMatStructure_Composite(Mat mat,MatStructure * str)580 static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str)
581 {
582 Mat_Composite *shell;
583
584 PetscFunctionBegin;
585 PetscCall(MatShellGetContext(mat, &shell));
586 *str = shell->structure;
587 PetscFunctionReturn(PETSC_SUCCESS);
588 }
589
590 /*@
591 MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix.
592
593 Not Collective
594
595 Input Parameter:
596 . mat - the composite matrix
597
598 Output Parameter:
599 . str - structure of the matrices
600
601 Level: advanced
602
603 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE`
604 @*/
MatCompositeGetMatStructure(Mat mat,MatStructure * str)605 PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str)
606 {
607 PetscFunctionBegin;
608 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
609 PetscAssertPointer(str, 2);
610 PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str));
611 PetscFunctionReturn(PETSC_SUCCESS);
612 }
613
MatCompositeSetMergeType_Composite(Mat mat,MatCompositeMergeType type)614 static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type)
615 {
616 Mat_Composite *shell;
617
618 PetscFunctionBegin;
619 PetscCall(MatShellGetContext(mat, &shell));
620 shell->mergetype = type;
621 PetscFunctionReturn(PETSC_SUCCESS);
622 }
623
624 /*@
625 MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`.
626
627 Logically Collective
628
629 Input Parameters:
630 + mat - the composite matrix
- type - `MAT_COMPOSITE_MERGE_RIGHT` (default) to start merge from right with the first added matrix (mat[0]),
          `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1])
633
634 Level: advanced
635
636 Note:
637 The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed.
638 If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0])))
639 otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0].
640
641 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE`
642 @*/
MatCompositeSetMergeType(Mat mat,MatCompositeMergeType type)643 PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type)
644 {
645 PetscFunctionBegin;
646 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
647 PetscValidLogicalCollectiveEnum(mat, type, 2);
648 PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type));
649 PetscFunctionReturn(PETSC_SUCCESS);
650 }
651
MatCompositeMerge_Composite(Mat mat)652 static PetscErrorCode MatCompositeMerge_Composite(Mat mat)
653 {
654 Mat_Composite *shell;
655 Mat_CompositeLink next, prev;
656 Mat tmat, newmat;
657 Vec left, right, dshift;
658 PetscScalar scale, shift;
659 PetscInt i;
660
661 PetscFunctionBegin;
662 PetscCall(MatShellGetContext(mat, &shell));
663 next = shell->head;
664 prev = shell->tail;
665 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
666 PetscCall(MatShellGetScalingShifts(mat, &shift, &scale, &dshift, &left, &right, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
667 if (shell->type == MAT_COMPOSITE_ADDITIVE) {
668 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
669 i = 0;
670 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
671 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++]));
672 while ((next = next->next)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i++] : 1.0, next->mat, shell->structure));
673 } else {
674 i = shell->nmat - 1;
675 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
676 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--]));
677 while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i--] : 1.0, prev->mat, shell->structure));
678 }
679 } else {
680 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
681 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
682 while ((next = next->next)) {
683 PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
684 PetscCall(MatDestroy(&tmat));
685 tmat = newmat;
686 }
687 } else {
688 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
689 while ((prev = prev->prev)) {
690 PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
691 PetscCall(MatDestroy(&tmat));
692 tmat = newmat;
693 }
694 }
695 if (shell->scalings) {
696 for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
697 }
698 }
699
700 if (left) PetscCall(PetscObjectReference((PetscObject)left));
701 if (right) PetscCall(PetscObjectReference((PetscObject)right));
702 if (dshift) PetscCall(PetscObjectReference((PetscObject)dshift));
703
704 PetscCall(MatHeaderReplace(mat, &tmat));
705
706 PetscCall(MatDiagonalScale(mat, left, right));
707 PetscCall(MatScale(mat, scale));
708 PetscCall(MatShift(mat, shift));
709 PetscCall(VecDestroy(&left));
710 PetscCall(VecDestroy(&right));
711 if (dshift) {
712 PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES));
713 PetscCall(VecDestroy(&dshift));
714 }
715 PetscFunctionReturn(PETSC_SUCCESS);
716 }
717
718 /*@
719 MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
720 by summing or computing the product of all the matrices inside the composite matrix.
721
722 Collective
723
724 Input Parameter:
725 . mat - the composite matrix
726
727 Options Database Keys:
728 + -mat_composite_merge - merge in `MatAssemblyEnd()`
729 - -mat_composite_merge_type - set merge direction
730
731 Level: advanced
732
733 Note:
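  The merge starts from a duplicate of the FIRST matrix in the composite matrix when the merge type is `MAT_COMPOSITE_MERGE_RIGHT` (the default), and of the LAST matrix when it is `MAT_COMPOSITE_MERGE_LEFT`; for an additive composite the `MatType` of the resulting matrix is that of this starting matrix.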
735
736 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE`
737 @*/
MatCompositeMerge(Mat mat)738 PetscErrorCode MatCompositeMerge(Mat mat)
739 {
740 PetscFunctionBegin;
741 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
742 PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat));
743 PetscFunctionReturn(PETSC_SUCCESS);
744 }
745
MatCompositeGetNumberMat_Composite(Mat mat,PetscInt * nmat)746 static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat)
747 {
748 Mat_Composite *shell;
749
750 PetscFunctionBegin;
751 PetscCall(MatShellGetContext(mat, &shell));
752 *nmat = shell->nmat;
753 PetscFunctionReturn(PETSC_SUCCESS);
754 }
755
756 /*@
757 MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix.
758
759 Not Collective
760
761 Input Parameter:
762 . mat - the composite matrix
763
764 Output Parameter:
765 . nmat - number of matrices in the composite matrix
766
767 Level: advanced
768
769 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
770 @*/
MatCompositeGetNumberMat(Mat mat,PetscInt * nmat)771 PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat)
772 {
773 PetscFunctionBegin;
774 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
775 PetscAssertPointer(nmat, 2);
776 PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat));
777 PetscFunctionReturn(PETSC_SUCCESS);
778 }
779
MatCompositeGetMat_Composite(Mat mat,PetscInt i,Mat * Ai)780 static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai)
781 {
782 Mat_Composite *shell;
783 Mat_CompositeLink ilink;
784 PetscInt k;
785
786 PetscFunctionBegin;
787 PetscCall(MatShellGetContext(mat, &shell));
788 PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat);
789 ilink = shell->head;
790 for (k = 0; k < i; k++) ilink = ilink->next;
791 *Ai = ilink->mat;
792 PetscFunctionReturn(PETSC_SUCCESS);
793 }
794
795 /*@
796 MatCompositeGetMat - Returns the ith matrix from the composite matrix.
797
798 Logically Collective
799
800 Input Parameters:
801 + mat - the composite matrix
- i - the 0-based index of the requested matrix, in the order the matrices were added
803
804 Output Parameter:
805 . Ai - ith matrix in composite
806
807 Level: advanced
808
809 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE`
810 @*/
PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  /* The index must agree across ranks, matching the "Logically Collective" contract */
  PetscValidLogicalCollectiveInt(mat, i, 2);
  PetscAssertPointer(Ai, 3);
  /* Range checking of i is left to the implementation composed under
     "MatCompositeGetMat_C" (MatCompositeGetMat_Composite for MATCOMPOSITE) */
  PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai));
  PetscFunctionReturn(PETSC_SUCCESS);
}
820
MatCompositeSetScalings_Composite(Mat mat,const PetscScalar * scalings)821 static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings)
822 {
823 Mat_Composite *shell;
824 PetscInt nmat;
825
826 PetscFunctionBegin;
827 PetscCall(MatShellGetContext(mat, &shell));
828 PetscCall(MatCompositeGetNumberMat(mat, &nmat));
829 if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings));
830 PetscCall(PetscArraycpy(shell->scalings, scalings, nmat));
831 PetscFunctionReturn(PETSC_SUCCESS);
832 }
833
834 /*@
835 MatCompositeSetScalings - Sets separate scaling factors for component matrices.
836
837 Logically Collective
838
839 Input Parameters:
840 + mat - the composite matrix
- scalings - array of nmat scaling factors (see `MatCompositeGetNumberMat()`), where scalings[i] is the factor applied to the i-th matrix, for i in [0, nmat)
842
843 Level: advanced
844
845 .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE`
846 @*/
PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscAssertPointer(scalings, 2);
  /* NOTE(review): only scalings[0] is passed to the logical-collectiveness check; the
     remaining entries are not verified here to be identical on all ranks */
  PetscValidLogicalCollectiveScalar(mat, *scalings, 2);
  /* Dispatch to the implementation composed under "MatCompositeSetScalings_C"; the
     MATCOMPOSITE implementation copies nmat entries from scalings */
  PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings));
  PetscFunctionReturn(PETSC_SUCCESS);
}
856
857 /*MC
858 MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices.
859 The matrices need to have a correct size and parallel layout for the sum or product to be valid.
860
861 Level: advanced
862
863 Note:
864 To use the product of the matrices call `MatCompositeSetType`(mat,`MAT_COMPOSITE_MULTIPLICATIVE`);
865
866 Developer Notes:
867 This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code
868
  Users cannot call `MatShellSetOperation()` on this class; there is some error checking to catch that incorrect usage
870
871 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`,
872 `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()`
873 M*/
874
/*
  MatCreate_Composite - Constructor for MATCOMPOSITE.

  Builds the composite on top of MATSHELL: it allocates the Mat_Composite context,
  installs it as the shell context, overrides the shell operations with the composite
  implementations, composes the MatComposite* methods queried by the public API,
  and finally renames the type to MATCOMPOSITE.
*/
PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
{
  Mat_Composite *b;

  PetscFunctionBegin;
  PetscCall(PetscNew(&b));

  /* Defaults: additive composite, no merge on assembly, right merge order,
     different nonzero patterns, and merged Mvctx scatters */
  b->type        = MAT_COMPOSITE_ADDITIVE;
  b->nmat        = 0;
  b->merge       = PETSC_FALSE;
  b->mergetype   = MAT_COMPOSITE_MERGE_RIGHT;
  b->structure   = DIFFERENT_NONZERO_PATTERN;
  b->merge_mvctx = PETSC_TRUE;

  /* The shell type must be set before the context and operations can be attached */
  PetscCall(MatSetType(A, MATSHELL));
  PetscCall(MatShellSetContext(A, b));
  PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (PetscErrorCodeFn *)MatDestroy_Composite));
  PetscCall(MatShellSetOperation(A, MATOP_MULT, (PetscErrorCodeFn *)MatMult_Composite));
  PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_Composite));
  PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_Composite));
  PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (PetscErrorCodeFn *)MatAssemblyEnd_Composite));
  PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (PetscErrorCodeFn *)MatSetFromOptions_Composite));
  /* Methods looked up by name (PetscUseMethod) from the public MatComposite* API */
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite));
  /* Installed after the context/operations above are set up: these replace the shell's
     setters so users cannot swap out the composite's internals (presumably by raising an
     error; the *_Immutable implementations are not visible here) */
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
  /* Report the object as MATCOMPOSITE rather than MATSHELL */
  PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE));
  PetscFunctionReturn(PETSC_SUCCESS);
}
913