1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/ 2 3 const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL}; 4 5 typedef struct _Mat_CompositeLink *Mat_CompositeLink; 6 struct _Mat_CompositeLink { 7 Mat mat; 8 Vec work; 9 Mat_CompositeLink next, prev; 10 }; 11 12 typedef struct { 13 MatCompositeType type; 14 Mat_CompositeLink head, tail; 15 Vec work; 16 PetscInt nmat; 17 PetscBool merge; 18 MatCompositeMergeType mergetype; 19 MatStructure structure; 20 21 PetscScalar *scalings; 22 PetscBool merge_mvctx; /* Whether need to merge mvctx of component matrices */ 23 Vec *lvecs; /* [nmat] Basically, they are Mvctx->lvec of each component matrix */ 24 PetscScalar *larray; /* [len] Data arrays of lvecs[] are stored consecutively in larray */ 25 PetscInt len; /* Length of larray[] */ 26 Vec gvec; /* Union of lvecs[] without duplicated entries */ 27 PetscInt *location; /* A map that maps entries in garray[] to larray[] */ 28 VecScatter Mvctx; 29 } Mat_Composite; 30 31 static PetscErrorCode MatDestroy_Composite(Mat mat) 32 { 33 Mat_Composite *shell; 34 Mat_CompositeLink next, oldnext; 35 PetscInt i; 36 37 PetscFunctionBegin; 38 PetscCall(MatShellGetContext(mat, &shell)); 39 next = shell->head; 40 while (next) { 41 PetscCall(MatDestroy(&next->mat)); 42 if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work)); 43 oldnext = next; 44 next = next->next; 45 PetscCall(PetscFree(oldnext)); 46 } 47 PetscCall(VecDestroy(&shell->work)); 48 49 if (shell->Mvctx) { 50 for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i])); 51 PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs)); 52 PetscCall(PetscFree(shell->larray)); 53 PetscCall(VecDestroy(&shell->gvec)); 54 PetscCall(VecScatterDestroy(&shell->Mvctx)); 55 } 56 57 PetscCall(PetscFree(shell->scalings)); 58 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL)); 59 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL)); 60 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL)); 61 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL)); 62 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL)); 63 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL)); 64 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL)); 65 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL)); 66 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL)); 67 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL)); 68 PetscCall(PetscFree(shell)); 69 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable() 70 PetscFunctionReturn(PETSC_SUCCESS); 71 } 72 73 static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y) 74 { 75 Mat_Composite *shell; 76 Mat_CompositeLink next; 77 Vec out; 78 79 PetscFunctionBegin; 80 PetscCall(MatShellGetContext(A, &shell)); 81 next = shell->head; 82 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 83 while (next->next) { 84 if (!next->work) { /* should reuse previous work if the same size */ 85 PetscCall(MatCreateVecs(next->mat, NULL, &next->work)); 86 } 87 out = next->work; 88 PetscCall(MatMult(next->mat, x, out)); 89 x = out; 90 next = next->next; 91 } 92 PetscCall(MatMult(next->mat, x, y)); 93 if (shell->scalings) { 94 PetscScalar scale = 1.0; 95 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 96 PetscCall(VecScale(y, scale)); 97 } 98 PetscFunctionReturn(PETSC_SUCCESS); 99 } 100 101 static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y) 102 { 103 Mat_Composite *shell; 104 Mat_CompositeLink tail; 105 Vec out; 106 107 PetscFunctionBegin; 108 PetscCall(MatShellGetContext(A, &shell)); 109 tail = shell->tail; 110 PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 111 while (tail->prev) { 112 if (!tail->prev->work) { /* should reuse previous work if the same size */ 113 PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work)); 114 } 115 out = tail->prev->work; 116 PetscCall(MatMultTranspose(tail->mat, x, out)); 117 x = out; 118 tail = tail->prev; 119 } 120 PetscCall(MatMultTranspose(tail->mat, x, y)); 121 if (shell->scalings) { 122 PetscScalar scale = 1.0; 123 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 124 PetscCall(VecScale(y, scale)); 125 } 126 PetscFunctionReturn(PETSC_SUCCESS); 127 } 128 129 static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y) 130 { 131 Mat_Composite *shell; 132 Mat_CompositeLink cur; 133 Vec y2, xin; 134 Mat A, B; 135 PetscInt i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot; 136 const PetscScalar *vals; 137 const PetscInt *garray; 138 IS ix, iy; 139 PetscBool match; 140 141 PetscFunctionBegin; 142 PetscCall(MatShellGetContext(mat, &shell)); 143 cur = shell->head; 144 PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 145 146 /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time 147 we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and 148 it is legal to merge Mvctx, because all component matrices have the same size. 149 */ 150 if (shell->merge_mvctx && !shell->Mvctx) { 151 /* Currently only implemented for MATMPIAIJ */ 152 for (cur = shell->head; cur; cur = cur->next) { 153 PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match)); 154 if (!match) { 155 shell->merge_mvctx = PETSC_FALSE; 156 goto skip_merge_mvctx; 157 } 158 } 159 160 /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */ 161 tot = 0; 162 for (cur = shell->head; cur; cur = cur->next) { 163 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL)); 164 PetscCall(MatGetLocalSize(B, NULL, &n)); 165 tot += n; 166 } 167 PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs)); 168 shell->len = tot; 169 170 /* Go through matrices second time to sort off-diag columns and remove dups */ 171 PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to petsc and free the other */ 172 PetscCall(PetscMalloc1(tot, &buf)); 173 nuniq = 0; /* Number of unique nonzero columns */ 174 for (cur = shell->head; cur; cur = cur->next) { 175 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray)); 176 PetscCall(MatGetLocalSize(B, NULL, &n)); 177 /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */ 178 i = j = k = 0; 179 while (i < n && j < nuniq) { 180 if (garray[i] < gindices[j]) buf[k++] = garray[i++]; 181 else if (garray[i] > gindices[j]) buf[k++] = gindices[j++]; 182 else { 183 buf[k++] = garray[i++]; 184 j++; 185 } 186 } 187 /* Copy leftover in garray[] or gindices[] */ 188 if (i < n) { 189 PetscCall(PetscArraycpy(buf + k, garray + i, n - i)); 190 nuniq = k + n - i; 191 } else if (j < nuniq) { 192 PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j)); 193 nuniq = k + nuniq - j; 194 } else nuniq = k; 195 /* Swap gindices and buf to merge garray of the next matrix */ 196 tmp = gindices; 197 gindices = buf; 198 buf = tmp; 199 } 200 PetscCall(PetscFree(buf)); 201 202 /* Go through matrices third time to build a map from gindices[] to garray[] */ 203 tot = 0; 204 for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */ 205 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray)); 206 PetscCall(MatGetLocalSize(B, NULL, &n)); 207 PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j])); 208 /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */ 209 lo = 0; 210 for (i = 0; i < n; i++) { 211 hi = nuniq; 212 while (hi - lo > 1) { 213 mid = lo + (hi - lo) / 2; 214 if (garray[i] < gindices[mid]) hi = mid; 215 else lo = mid; 216 } 217 shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */ 218 lo++; /* Since garray[i+1] > garray[i], we can safely advance lo */ 219 } 220 tot += n; 221 } 222 223 /* Build merged Mvctx */ 224 PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix)); 225 PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy)); 226 PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin)); 227 PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec)); 228 PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx)); 229 PetscCall(VecDestroy(&xin)); 230 PetscCall(ISDestroy(&ix)); 231 PetscCall(ISDestroy(&iy)); 232 } 233 234 skip_merge_mvctx: 235 PetscCall(VecSet(y, 0)); 236 if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work))); 237 y2 = ((Mat_Shell *)mat->data)->left_work; 238 239 if (shell->Mvctx) { /* Have a merged Mvctx */ 240 /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do 241 in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to 242 overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach. 243 */ 244 PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD)); 245 PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD)); 246 247 PetscCall(VecGetArrayRead(shell->gvec, &vals)); 248 for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]]; 249 PetscCall(VecRestoreArrayRead(shell->gvec, &vals)); 250 251 for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */ 252 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL)); 253 PetscUseTypeMethod(A, mult, x, y2); 254 PetscCall(MatGetLocalSize(B, NULL, &n)); 255 PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot])); 256 PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2); 257 PetscCall(VecResetArray(shell->lvecs[i])); 258 PetscCall(VecAXPY(y, (shell->scalings ? shell->scalings[i] : 1.0), y2)); 259 tot += n; 260 } 261 } else { 262 if (shell->scalings) { 263 for (cur = shell->head, i = 0; cur; cur = cur->next, i++) { 264 PetscCall(MatMult(cur->mat, x, y2)); 265 PetscCall(VecAXPY(y, shell->scalings[i], y2)); 266 } 267 } else { 268 for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y)); 269 } 270 } 271 PetscFunctionReturn(PETSC_SUCCESS); 272 } 273 274 static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y) 275 { 276 Mat_Composite *shell; 277 Mat_CompositeLink next; 278 Vec y2 = NULL; 279 PetscInt i; 280 281 PetscFunctionBegin; 282 PetscCall(MatShellGetContext(A, &shell)); 283 next = shell->head; 284 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 285 286 PetscCall(MatMultTranspose(next->mat, x, y)); 287 if (shell->scalings) { 288 PetscCall(VecScale(y, shell->scalings[0])); 289 if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work))); 290 y2 = ((Mat_Shell *)A->data)->right_work; 291 } 292 i = 1; 293 while ((next = next->next)) { 294 if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y)); 295 else { 296 PetscCall(MatMultTranspose(next->mat, x, y2)); 297 PetscCall(VecAXPY(y, shell->scalings[i++], y2)); 298 } 299 } 300 PetscFunctionReturn(PETSC_SUCCESS); 301 } 302 303 static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v) 304 { 305 Mat_Composite *shell; 306 Mat_CompositeLink next; 307 PetscInt i; 308 309 PetscFunctionBegin; 310 PetscCall(MatShellGetContext(A, &shell)); 311 next = shell->head; 312 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 313 PetscCall(MatGetDiagonal(next->mat, v)); 314 if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0])); 315 316 if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work)); 317 i = 1; 318 while ((next = next->next)) { 319 PetscCall(MatGetDiagonal(next->mat, shell->work)); 320 PetscCall(VecAXPY(v, (shell->scalings ? shell->scalings[i++] : 1.0), shell->work)); 321 } 322 PetscFunctionReturn(PETSC_SUCCESS); 323 } 324 325 static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t) 326 { 327 Mat_Composite *shell; 328 329 PetscFunctionBegin; 330 PetscCall(MatShellGetContext(Y, &shell)); 331 if (shell->merge) PetscCall(MatCompositeMerge(Y)); 332 else PetscCall(MatAssemblyEnd_Shell(Y, t)); 333 PetscFunctionReturn(PETSC_SUCCESS); 334 } 335 336 static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems *PetscOptionsObject) 337 { 338 Mat_Composite *a; 339 340 PetscFunctionBegin; 341 PetscCall(MatShellGetContext(A, &a)); 342 PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options"); 343 PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL)); 344 PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL)); 345 PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL)); 346 PetscOptionsHeadEnd(); 347 PetscFunctionReturn(PETSC_SUCCESS); 348 } 349 350 /*@ 351 MatCreateComposite - Creates a matrix as the sum or product of one or more matrices 352 353 Collective 354 355 Input Parameters: 356 + comm - MPI communicator 357 . nmat - number of matrices to put in 358 - mats - the matrices 359 360 Output Parameter: 361 . mat - the matrix 362 363 Options Database Keys: 364 + -mat_composite_merge - merge in `MatAssemblyEnd()` 365 . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices 366 - -mat_composite_merge_type - set merge direction 367 368 Level: advanced 369 370 Note: 371 Alternative construction 372 .vb 373 MatCreate(comm,&mat); 374 MatSetSizes(mat,m,n,M,N); 375 MatSetType(mat,MATCOMPOSITE); 376 MatCompositeAddMat(mat,mats[0]); 377 .... 378 MatCompositeAddMat(mat,mats[nmat-1]); 379 MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY); 380 MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY); 381 .ve 382 383 For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0] 384 385 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`, 386 `MATCOMPOSITE`, `MatCompositeType` 387 @*/ 388 PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat) 389 { 390 PetscInt m, n, M, N, i; 391 392 PetscFunctionBegin; 393 PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix"); 394 PetscAssertPointer(mat, 4); 395 396 PetscCall(MatGetLocalSize(mats[0], PETSC_IGNORE, &n)); 397 PetscCall(MatGetLocalSize(mats[nmat - 1], &m, PETSC_IGNORE)); 398 PetscCall(MatGetSize(mats[0], PETSC_IGNORE, &N)); 399 PetscCall(MatGetSize(mats[nmat - 1], &M, PETSC_IGNORE)); 400 PetscCall(MatCreate(comm, mat)); 401 PetscCall(MatSetSizes(*mat, m, n, M, N)); 402 PetscCall(MatSetType(*mat, MATCOMPOSITE)); 403 for (i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i])); 404 PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY)); 405 PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY)); 406 PetscFunctionReturn(PETSC_SUCCESS); 407 } 408 409 static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat) 410 { 411 Mat_Composite *shell; 412 Mat_CompositeLink ilink, next; 413 VecType vtype_mat, vtype_smat; 414 PetscBool match; 415 416 PetscFunctionBegin; 417 PetscCall(MatShellGetContext(mat, &shell)); 418 next = shell->head; 419 PetscCall(PetscNew(&ilink)); 420 ilink->next = NULL; 421 PetscCall(PetscObjectReference((PetscObject)smat)); 422 ilink->mat = smat; 423 424 if (!next) shell->head = ilink; 425 else { 426 while (next->next) next = next->next; 427 next->next = ilink; 428 ilink->prev = next; 429 } 430 shell->tail = ilink; 431 shell->nmat += 1; 432 433 /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type. 434 Otherwise, the default type should be "standard". */ 435 PetscCall(MatGetVecType(smat, &vtype_smat)); 436 if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat)); 437 else { 438 PetscCall(MatGetVecType(mat, &vtype_mat)); 439 PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match)); 440 if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD)); 441 } 442 443 /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */ 444 if (shell->scalings) { 445 PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings)); 446 shell->scalings[shell->nmat - 1] = 1.0; 447 } 448 PetscFunctionReturn(PETSC_SUCCESS); 449 } 450 451 /*@ 452 MatCompositeAddMat - Add another matrix to a composite matrix. 453 454 Collective 455 456 Input Parameters: 457 + mat - the composite matrix 458 - smat - the partial matrix 459 460 Level: advanced 461 462 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE` 463 @*/ 464 PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat) 465 { 466 PetscFunctionBegin; 467 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 468 PetscValidHeaderSpecific(smat, MAT_CLASSID, 2); 469 PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat)); 470 PetscFunctionReturn(PETSC_SUCCESS); 471 } 472 473 static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type) 474 { 475 Mat_Composite *b; 476 477 PetscFunctionBegin; 478 PetscCall(MatShellGetContext(mat, &b)); 479 b->type = type; 480 if (type == MAT_COMPOSITE_MULTIPLICATIVE) { 481 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL)); 482 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite_Multiplicative)); 483 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite_Multiplicative)); 484 b->merge_mvctx = PETSC_FALSE; 485 } else { 486 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite)); 487 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite)); 488 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite)); 489 } 490 PetscFunctionReturn(PETSC_SUCCESS); 491 } 492 493 /*@ 494 MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product. 495 496 Logically Collective 497 498 Input Parameters: 499 + mat - the composite matrix 500 - type - the `MatCompositeType` to use for the matrix 501 502 Level: advanced 503 504 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`, 505 `MatCompositeType` 506 @*/ 507 PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type) 508 { 509 PetscFunctionBegin; 510 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 511 PetscValidLogicalCollectiveEnum(mat, type, 2); 512 PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type)); 513 PetscFunctionReturn(PETSC_SUCCESS); 514 } 515 516 static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type) 517 { 518 Mat_Composite *shell; 519 520 PetscFunctionBegin; 521 PetscCall(MatShellGetContext(mat, &shell)); 522 *type = shell->type; 523 PetscFunctionReturn(PETSC_SUCCESS); 524 } 525 526 /*@ 527 MatCompositeGetType - Returns type of composite. 528 529 Not Collective 530 531 Input Parameter: 532 . mat - the composite matrix 533 534 Output Parameter: 535 . type - type of composite 536 537 Level: advanced 538 539 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType` 540 @*/ 541 PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type) 542 { 543 PetscFunctionBegin; 544 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 545 PetscAssertPointer(type, 2); 546 PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type)); 547 PetscFunctionReturn(PETSC_SUCCESS); 548 } 549 550 static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str) 551 { 552 Mat_Composite *shell; 553 554 PetscFunctionBegin; 555 PetscCall(MatShellGetContext(mat, &shell)); 556 shell->structure = str; 557 PetscFunctionReturn(PETSC_SUCCESS); 558 } 559 560 /*@ 561 MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix. 562 563 Not Collective 564 565 Input Parameters: 566 + mat - the composite matrix 567 - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN` 568 569 Level: advanced 570 571 Note: 572 Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix. 573 574 .seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()` `MatCompositeGetMatStructure()`, `MATCOMPOSITE` 575 @*/ 576 PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str) 577 { 578 PetscFunctionBegin; 579 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 580 PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str)); 581 PetscFunctionReturn(PETSC_SUCCESS); 582 } 583 584 static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str) 585 { 586 Mat_Composite *shell; 587 588 PetscFunctionBegin; 589 PetscCall(MatShellGetContext(mat, &shell)); 590 *str = shell->structure; 591 PetscFunctionReturn(PETSC_SUCCESS); 592 } 593 594 /*@ 595 MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix. 596 597 Not Collective 598 599 Input Parameter: 600 . mat - the composite matrix 601 602 Output Parameter: 603 . str - structure of the matrices 604 605 Level: advanced 606 607 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE` 608 @*/ 609 PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str) 610 { 611 PetscFunctionBegin; 612 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 613 PetscAssertPointer(str, 2); 614 PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str)); 615 PetscFunctionReturn(PETSC_SUCCESS); 616 } 617 618 static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type) 619 { 620 Mat_Composite *shell; 621 622 PetscFunctionBegin; 623 PetscCall(MatShellGetContext(mat, &shell)); 624 shell->mergetype = type; 625 PetscFunctionReturn(PETSC_SUCCESS); 626 } 627 628 /*@ 629 MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`. 630 631 Logically Collective 632 633 Input Parameters: 634 + mat - the composite matrix 635 - type - `MAT_COMPOSITE_MERGE RIGHT` (default) to start merge from right with the first added matrix (mat[0]), 636 `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1]) 637 638 Level: advanced 639 640 Note: 641 The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed. 642 If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0]))) 643 otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0]. 644 645 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE` 646 @*/ 647 PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type) 648 { 649 PetscFunctionBegin; 650 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 651 PetscValidLogicalCollectiveEnum(mat, type, 2); 652 PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type)); 653 PetscFunctionReturn(PETSC_SUCCESS); 654 } 655 656 static PetscErrorCode MatCompositeMerge_Composite(Mat mat) 657 { 658 Mat_Composite *shell; 659 Mat_CompositeLink next, prev; 660 Mat tmat, newmat; 661 Vec left, right, dshift; 662 PetscScalar scale, shift; 663 PetscInt i; 664 665 PetscFunctionBegin; 666 PetscCall(MatShellGetContext(mat, &shell)); 667 next = shell->head; 668 prev = shell->tail; 669 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 670 PetscCall(MatShellGetScalingShifts(mat, &shift, &scale, &dshift, &left, &right, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED)); 671 if (shell->type == MAT_COMPOSITE_ADDITIVE) { 672 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) { 673 i = 0; 674 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat)); 675 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++])); 676 while ((next = next->next)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i++] : 1.0), next->mat, shell->structure)); 677 } else { 678 i = shell->nmat - 1; 679 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat)); 680 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--])); 681 while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i--] : 1.0), prev->mat, shell->structure)); 682 } 683 } else { 684 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) { 685 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat)); 686 while ((next = next->next)) { 687 PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat)); 688 PetscCall(MatDestroy(&tmat)); 689 tmat = newmat; 690 } 691 } else { 692 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat)); 693 while ((prev = prev->prev)) { 694 PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat)); 695 PetscCall(MatDestroy(&tmat)); 696 tmat = newmat; 697 } 698 } 699 if (shell->scalings) { 700 for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 701 } 702 } 703 704 if (left) PetscCall(PetscObjectReference((PetscObject)left)); 705 if (right) PetscCall(PetscObjectReference((PetscObject)right)); 706 if (dshift) PetscCall(PetscObjectReference((PetscObject)dshift)); 707 708 PetscCall(MatHeaderReplace(mat, &tmat)); 709 710 PetscCall(MatDiagonalScale(mat, left, right)); 711 PetscCall(MatScale(mat, scale)); 712 PetscCall(MatShift(mat, shift)); 713 PetscCall(VecDestroy(&left)); 714 PetscCall(VecDestroy(&right)); 715 if (dshift) { 716 PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES)); 717 PetscCall(VecDestroy(&dshift)); 718 } 719 PetscFunctionReturn(PETSC_SUCCESS); 720 } 721 722 /*@ 723 MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix 724 by summing or computing the product of all the matrices inside the composite matrix. 725 726 Collective 727 728 Input Parameter: 729 . mat - the composite matrix 730 731 Options Database Keys: 732 + -mat_composite_merge - merge in `MatAssemblyEnd()` 733 - -mat_composite_merge_type - set merge direction 734 735 Level: advanced 736 737 Note: 738 The `MatType` of the resulting matrix will be the same as the `MatType` of the FIRST matrix in the composite matrix. 739 740 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE` 741 @*/ 742 PetscErrorCode MatCompositeMerge(Mat mat) 743 { 744 PetscFunctionBegin; 745 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 746 PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat)); 747 PetscFunctionReturn(PETSC_SUCCESS); 748 } 749 750 static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat) 751 { 752 Mat_Composite *shell; 753 754 PetscFunctionBegin; 755 PetscCall(MatShellGetContext(mat, &shell)); 756 *nmat = shell->nmat; 757 PetscFunctionReturn(PETSC_SUCCESS); 758 } 759 760 /*@ 761 MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix. 762 763 Not Collective 764 765 Input Parameter: 766 . mat - the composite matrix 767 768 Output Parameter: 769 . nmat - number of matrices in the composite matrix 770 771 Level: advanced 772 773 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE` 774 @*/ 775 PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat) 776 { 777 PetscFunctionBegin; 778 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 779 PetscAssertPointer(nmat, 2); 780 PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat)); 781 PetscFunctionReturn(PETSC_SUCCESS); 782 } 783 784 static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai) 785 { 786 Mat_Composite *shell; 787 Mat_CompositeLink ilink; 788 PetscInt k; 789 790 PetscFunctionBegin; 791 PetscCall(MatShellGetContext(mat, &shell)); 792 PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat); 793 ilink = shell->head; 794 for (k = 0; k < i; k++) ilink = ilink->next; 795 *Ai = ilink->mat; 796 PetscFunctionReturn(PETSC_SUCCESS); 797 } 798 799 /*@ 800 MatCompositeGetMat - Returns the ith matrix from the composite matrix. 801 802 Logically Collective 803 804 Input Parameters: 805 + mat - the composite matrix 806 - i - the number of requested matrix 807 808 Output Parameter: 809 . Ai - ith matrix in composite 810 811 Level: advanced 812 813 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE` 814 @*/ 815 PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai) 816 { 817 PetscFunctionBegin; 818 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 819 PetscValidLogicalCollectiveInt(mat, i, 2); 820 PetscAssertPointer(Ai, 3); 821 PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai)); 822 PetscFunctionReturn(PETSC_SUCCESS); 823 } 824 825 static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings) 826 { 827 Mat_Composite *shell; 828 PetscInt nmat; 829 830 PetscFunctionBegin; 831 PetscCall(MatShellGetContext(mat, &shell)); 832 PetscCall(MatCompositeGetNumberMat(mat, &nmat)); 833 if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings)); 834 PetscCall(PetscArraycpy(shell->scalings, scalings, nmat)); 835 PetscFunctionReturn(PETSC_SUCCESS); 836 } 837 838 /*@ 839 MatCompositeSetScalings - Sets separate scaling factors for component matrices. 840 841 Logically Collective 842 843 Input Parameters: 844 + mat - the composite matrix 845 - scalings - array of scaling factors with scalings[i] being factor of i-th matrix, for i in [0, nmat) 846 847 Level: advanced 848 849 .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE` 850 @*/ 851 PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings) 852 { 853 PetscFunctionBegin; 854 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 855 PetscAssertPointer(scalings, 2); 856 PetscValidLogicalCollectiveScalar(mat, *scalings, 2); 857 PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings)); 858 PetscFunctionReturn(PETSC_SUCCESS); 859 } 860 861 /*MC 862 MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices. 863 The matrices need to have a correct size and parallel layout for the sum or product to be valid. 864 865 Level: advanced 866 867 Note: 868 To use the product of the matrices call `MatCompositeSetType`(mat,`MAT_COMPOSITE_MULTIPLICATIVE`); 869 870 Developer Notes: 871 This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code 872 873 Users can not call `MatShellSetOperation()` operations on this class, there is some error checking for that incorrect usage 874 875 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`, 876 `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()` 877 M*/ 878 879 PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A) 880 { 881 Mat_Composite *b; 882 883 PetscFunctionBegin; 884 PetscCall(PetscNew(&b)); 885 886 b->type = MAT_COMPOSITE_ADDITIVE; 887 b->nmat = 0; 888 b->merge = PETSC_FALSE; 889 b->mergetype = MAT_COMPOSITE_MERGE_RIGHT; 890 b->structure = DIFFERENT_NONZERO_PATTERN; 891 b->merge_mvctx = PETSC_TRUE; 892 893 PetscCall(MatSetType(A, MATSHELL)); 894 PetscCall(MatShellSetContext(A, b)); 895 PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (void (*)(void))MatDestroy_Composite)); 896 PetscCall(MatShellSetOperation(A, MATOP_MULT, (void (*)(void))MatMult_Composite)); 897 PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite)); 898 PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite)); 899 PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (void (*)(void))MatAssemblyEnd_Composite)); 900 PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (void (*)(void))MatSetFromOptions_Composite)); 901 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite)); 902 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite)); 903 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite)); 904 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite)); 905 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite)); 906 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite)); 907 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite)); 908 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite)); 909 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite)); 910 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite)); 911 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable)); 912 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable)); 913 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable)); 914 PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE)); 915 PetscFunctionReturn(PETSC_SUCCESS); 916 } 917