1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/ 2 3 const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL}; 4 5 typedef struct _Mat_CompositeLink *Mat_CompositeLink; 6 struct _Mat_CompositeLink { 7 Mat mat; 8 Vec work; 9 Mat_CompositeLink next, prev; 10 }; 11 12 typedef struct { 13 MatCompositeType type; 14 Mat_CompositeLink head, tail; 15 Vec work; 16 PetscInt nmat; 17 PetscBool merge; 18 MatCompositeMergeType mergetype; 19 MatStructure structure; 20 21 PetscScalar *scalings; 22 PetscBool merge_mvctx; /* Whether need to merge mvctx of component matrices */ 23 Vec *lvecs; /* [nmat] Basically, they are Mvctx->lvec of each component matrix */ 24 PetscScalar *larray; /* [len] Data arrays of lvecs[] are stored consecutively in larray */ 25 PetscInt len; /* Length of larray[] */ 26 Vec gvec; /* Union of lvecs[] without duplicated entries */ 27 PetscInt *location; /* A map that maps entries in garray[] to larray[] */ 28 VecScatter Mvctx; 29 } Mat_Composite; 30 31 static PetscErrorCode MatDestroy_Composite(Mat mat) 32 { 33 Mat_Composite *shell; 34 Mat_CompositeLink next, oldnext; 35 PetscInt i; 36 37 PetscFunctionBegin; 38 PetscCall(MatShellGetContext(mat, &shell)); 39 next = shell->head; 40 while (next) { 41 PetscCall(MatDestroy(&next->mat)); 42 if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work)); 43 oldnext = next; 44 next = next->next; 45 PetscCall(PetscFree(oldnext)); 46 } 47 PetscCall(VecDestroy(&shell->work)); 48 49 if (shell->Mvctx) { 50 for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i])); 51 PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs)); 52 PetscCall(PetscFree(shell->larray)); 53 PetscCall(VecDestroy(&shell->gvec)); 54 PetscCall(VecScatterDestroy(&shell->Mvctx)); 55 } 56 57 PetscCall(PetscFree(shell->scalings)); 58 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL)); 59 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL)); 60 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL)); 61 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL)); 62 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL)); 63 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL)); 64 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL)); 65 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL)); 66 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL)); 67 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL)); 68 PetscCall(PetscFree(shell)); 69 PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable() 70 PetscFunctionReturn(PETSC_SUCCESS); 71 } 72 73 static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y) 74 { 75 Mat_Composite *shell; 76 Mat_CompositeLink next; 77 Vec out; 78 79 PetscFunctionBegin; 80 PetscCall(MatShellGetContext(A, &shell)); 81 next = shell->head; 82 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 83 while (next->next) { 84 if (!next->work) { /* should reuse previous work if the same size */ 85 PetscCall(MatCreateVecs(next->mat, NULL, &next->work)); 86 } 87 out = next->work; 88 PetscCall(MatMult(next->mat, x, out)); 89 x = out; 90 next = next->next; 91 } 92 PetscCall(MatMult(next->mat, x, y)); 93 if (shell->scalings) { 94 PetscScalar scale = 1.0; 95 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 96 PetscCall(VecScale(y, scale)); 97 } 98 PetscFunctionReturn(PETSC_SUCCESS); 99 } 100 101 static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y) 102 { 103 Mat_Composite *shell; 104 Mat_CompositeLink tail; 105 Vec out; 106 107 PetscFunctionBegin; 108 PetscCall(MatShellGetContext(A, &shell)); 109 tail = shell->tail; 110 PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 111 while (tail->prev) { 112 if (!tail->prev->work) { /* should reuse previous work if the same size */ 113 PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work)); 114 } 115 out = tail->prev->work; 116 PetscCall(MatMultTranspose(tail->mat, x, out)); 117 x = out; 118 tail = tail->prev; 119 } 120 PetscCall(MatMultTranspose(tail->mat, x, y)); 121 if (shell->scalings) { 122 PetscScalar scale = 1.0; 123 for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 124 PetscCall(VecScale(y, scale)); 125 } 126 PetscFunctionReturn(PETSC_SUCCESS); 127 } 128 129 static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y) 130 { 131 Mat_Composite *shell; 132 Mat_CompositeLink cur; 133 Vec y2, xin; 134 Mat A, B; 135 PetscInt i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot; 136 const PetscScalar *vals; 137 const PetscInt *garray; 138 IS ix, iy; 139 PetscBool match; 140 141 PetscFunctionBegin; 142 PetscCall(MatShellGetContext(mat, &shell)); 143 cur = shell->head; 144 PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 145 146 /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time 147 we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and 148 it is legal to merge Mvctx, because all component matrices have the same size. 149 */ 150 if (shell->merge_mvctx && !shell->Mvctx) { 151 /* Currently only implemented for MATMPIAIJ */ 152 for (cur = shell->head; cur; cur = cur->next) { 153 PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match)); 154 if (!match) { 155 shell->merge_mvctx = PETSC_FALSE; 156 goto skip_merge_mvctx; 157 } 158 } 159 160 /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */ 161 tot = 0; 162 for (cur = shell->head; cur; cur = cur->next) { 163 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL)); 164 PetscCall(MatGetLocalSize(B, NULL, &n)); 165 tot += n; 166 } 167 PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs)); 168 shell->len = tot; 169 170 /* Go through matrices second time to sort off-diag columns and remove dups */ 171 PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to petsc and free the other */ 172 PetscCall(PetscMalloc1(tot, &buf)); 173 nuniq = 0; /* Number of unique nonzero columns */ 174 for (cur = shell->head; cur; cur = cur->next) { 175 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray)); 176 PetscCall(MatGetLocalSize(B, NULL, &n)); 177 /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */ 178 i = j = k = 0; 179 while (i < n && j < nuniq) { 180 if (garray[i] < gindices[j]) buf[k++] = garray[i++]; 181 else if (garray[i] > gindices[j]) buf[k++] = gindices[j++]; 182 else { 183 buf[k++] = garray[i++]; 184 j++; 185 } 186 } 187 /* Copy leftover in garray[] or gindices[] */ 188 if (i < n) { 189 PetscCall(PetscArraycpy(buf + k, garray + i, n - i)); 190 nuniq = k + n - i; 191 } else if (j < nuniq) { 192 PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j)); 193 nuniq = k + nuniq - j; 194 } else nuniq = k; 195 /* Swap gindices and buf to merge garray of the next matrix */ 196 tmp = gindices; 197 gindices = buf; 198 buf = tmp; 199 } 200 PetscCall(PetscFree(buf)); 201 202 /* Go through matrices third time to build a map from gindices[] to garray[] */ 203 tot = 0; 204 for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */ 205 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray)); 206 PetscCall(MatGetLocalSize(B, NULL, &n)); 207 PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j])); 208 /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */ 209 lo = 0; 210 for (i = 0; i < n; i++) { 211 hi = nuniq; 212 while (hi - lo > 1) { 213 mid = lo + (hi - lo) / 2; 214 if (garray[i] < gindices[mid]) hi = mid; 215 else lo = mid; 216 } 217 shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */ 218 lo++; /* Since garray[i+1] > garray[i], we can safely advance lo */ 219 } 220 tot += n; 221 } 222 223 /* Build merged Mvctx */ 224 PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix)); 225 PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy)); 226 PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin)); 227 PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec)); 228 PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx)); 229 PetscCall(VecDestroy(&xin)); 230 PetscCall(ISDestroy(&ix)); 231 PetscCall(ISDestroy(&iy)); 232 } 233 234 skip_merge_mvctx: 235 PetscCall(VecSet(y, 0)); 236 if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work))); 237 y2 = ((Mat_Shell *)mat->data)->left_work; 238 239 if (shell->Mvctx) { /* Have a merged Mvctx */ 240 /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do 241 in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to 242 overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach. 243 */ 244 PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD)); 245 PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD)); 246 247 PetscCall(VecGetArrayRead(shell->gvec, &vals)); 248 for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]]; 249 PetscCall(VecRestoreArrayRead(shell->gvec, &vals)); 250 251 for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */ 252 PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL)); 253 PetscUseTypeMethod(A, mult, x, y2); 254 PetscCall(MatGetLocalSize(B, NULL, &n)); 255 PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot])); 256 PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2); 257 PetscCall(VecResetArray(shell->lvecs[i])); 258 PetscCall(VecAXPY(y, (shell->scalings ? shell->scalings[i] : 1.0), y2)); 259 tot += n; 260 } 261 } else { 262 if (shell->scalings) { 263 for (cur = shell->head, i = 0; cur; cur = cur->next, i++) { 264 PetscCall(MatMult(cur->mat, x, y2)); 265 PetscCall(VecAXPY(y, shell->scalings[i], y2)); 266 } 267 } else { 268 for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y)); 269 } 270 } 271 PetscFunctionReturn(PETSC_SUCCESS); 272 } 273 274 static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y) 275 { 276 Mat_Composite *shell; 277 Mat_CompositeLink next; 278 Vec y2 = NULL; 279 PetscInt i; 280 281 PetscFunctionBegin; 282 PetscCall(MatShellGetContext(A, &shell)); 283 next = shell->head; 284 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 285 286 PetscCall(MatMultTranspose(next->mat, x, y)); 287 if (shell->scalings) { 288 PetscCall(VecScale(y, shell->scalings[0])); 289 if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work))); 290 y2 = ((Mat_Shell *)A->data)->right_work; 291 } 292 i = 1; 293 while ((next = next->next)) { 294 if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y)); 295 else { 296 PetscCall(MatMultTranspose(next->mat, x, y2)); 297 PetscCall(VecAXPY(y, shell->scalings[i++], y2)); 298 } 299 } 300 PetscFunctionReturn(PETSC_SUCCESS); 301 } 302 303 static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v) 304 { 305 Mat_Composite *shell; 306 Mat_CompositeLink next; 307 PetscInt i; 308 309 PetscFunctionBegin; 310 PetscCall(MatShellGetContext(A, &shell)); 311 next = shell->head; 312 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 313 PetscCall(MatGetDiagonal(next->mat, v)); 314 if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0])); 315 316 if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work)); 317 i = 1; 318 while ((next = next->next)) { 319 PetscCall(MatGetDiagonal(next->mat, shell->work)); 320 PetscCall(VecAXPY(v, (shell->scalings ? shell->scalings[i++] : 1.0), shell->work)); 321 } 322 PetscFunctionReturn(PETSC_SUCCESS); 323 } 324 325 static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t) 326 { 327 Mat_Composite *shell; 328 329 PetscFunctionBegin; 330 PetscCall(MatShellGetContext(Y, &shell)); 331 if (shell->merge) PetscCall(MatCompositeMerge(Y)); 332 else PetscCall(MatAssemblyEnd_Shell(Y, t)); 333 PetscFunctionReturn(PETSC_SUCCESS); 334 } 335 336 static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems *PetscOptionsObject) 337 { 338 Mat_Composite *a; 339 340 PetscFunctionBegin; 341 PetscCall(MatShellGetContext(A, &a)); 342 PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options"); 343 PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL)); 344 PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL)); 345 PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL)); 346 PetscOptionsHeadEnd(); 347 PetscFunctionReturn(PETSC_SUCCESS); 348 } 349 350 /*@ 351 MatCreateComposite - Creates a matrix as the sum or product of one or more matrices 352 353 Collective 354 355 Input Parameters: 356 + comm - MPI communicator 357 . nmat - number of matrices to put in 358 - mats - the matrices 359 360 Output Parameter: 361 . mat - the matrix 362 363 Options Database Keys: 364 + -mat_composite_merge - merge in `MatAssemblyEnd()` 365 . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices 366 - -mat_composite_merge_type - set merge direction 367 368 Level: advanced 369 370 Note: 371 Alternative construction 372 .vb 373 MatCreate(comm,&mat); 374 MatSetSizes(mat,m,n,M,N); 375 MatSetType(mat,MATCOMPOSITE); 376 MatCompositeAddMat(mat,mats[0]); 377 .... 378 MatCompositeAddMat(mat,mats[nmat-1]); 379 MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY); 380 MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY); 381 .ve 382 383 For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0] 384 385 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`, 386 `MATCOMPOSITE`, `MatCompositeType` 387 @*/ 388 PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat) 389 { 390 PetscInt m, n, M, N, i; 391 392 PetscFunctionBegin; 393 PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix"); 394 PetscAssertPointer(mat, 4); 395 396 PetscCall(MatGetLocalSize(mats[0], PETSC_IGNORE, &n)); 397 PetscCall(MatGetLocalSize(mats[nmat - 1], &m, PETSC_IGNORE)); 398 PetscCall(MatGetSize(mats[0], PETSC_IGNORE, &N)); 399 PetscCall(MatGetSize(mats[nmat - 1], &M, PETSC_IGNORE)); 400 PetscCall(MatCreate(comm, mat)); 401 PetscCall(MatSetSizes(*mat, m, n, M, N)); 402 PetscCall(MatSetType(*mat, MATCOMPOSITE)); 403 for (i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i])); 404 PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY)); 405 PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY)); 406 PetscFunctionReturn(PETSC_SUCCESS); 407 } 408 409 static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat) 410 { 411 Mat_Composite *shell; 412 Mat_CompositeLink ilink, next; 413 VecType vtype_mat, vtype_smat; 414 PetscBool match; 415 416 PetscFunctionBegin; 417 PetscCall(MatShellGetContext(mat, &shell)); 418 next = shell->head; 419 PetscCall(PetscNew(&ilink)); 420 ilink->next = NULL; 421 PetscCall(PetscObjectReference((PetscObject)smat)); 422 ilink->mat = smat; 423 424 if (!next) shell->head = ilink; 425 else { 426 while (next->next) next = next->next; 427 next->next = ilink; 428 ilink->prev = next; 429 } 430 shell->tail = ilink; 431 shell->nmat += 1; 432 433 /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type. 434 Otherwise, the default type should be "standard". */ 435 PetscCall(MatGetVecType(smat, &vtype_smat)); 436 if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat)); 437 else { 438 PetscCall(MatGetVecType(mat, &vtype_mat)); 439 PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match)); 440 if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD)); 441 } 442 443 /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */ 444 if (shell->scalings) { 445 PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings)); 446 shell->scalings[shell->nmat - 1] = 1.0; 447 } 448 PetscFunctionReturn(PETSC_SUCCESS); 449 } 450 451 /*@ 452 MatCompositeAddMat - Add another matrix to a composite matrix. 453 454 Collective 455 456 Input Parameters: 457 + mat - the composite matrix 458 - smat - the partial matrix 459 460 Level: advanced 461 462 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE` 463 @*/ 464 PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat) 465 { 466 PetscFunctionBegin; 467 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 468 PetscValidHeaderSpecific(smat, MAT_CLASSID, 2); 469 PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat)); 470 PetscFunctionReturn(PETSC_SUCCESS); 471 } 472 473 static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type) 474 { 475 Mat_Composite *b; 476 477 PetscFunctionBegin; 478 PetscCall(MatShellGetContext(mat, &b)); 479 b->type = type; 480 if (type == MAT_COMPOSITE_MULTIPLICATIVE) { 481 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL)); 482 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite_Multiplicative)); 483 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite_Multiplicative)); 484 b->merge_mvctx = PETSC_FALSE; 485 } else { 486 PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite)); 487 PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite)); 488 PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite)); 489 } 490 PetscFunctionReturn(PETSC_SUCCESS); 491 } 492 493 /*@ 494 MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product. 495 496 Logically Collective 497 498 Input Parameters: 499 + mat - the composite matrix 500 - type - the `MatCompositeType` to use for the matrix 501 502 Level: advanced 503 504 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`, 505 `MatCompositeType` 506 @*/ 507 PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type) 508 { 509 PetscFunctionBegin; 510 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 511 PetscValidLogicalCollectiveEnum(mat, type, 2); 512 PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type)); 513 PetscFunctionReturn(PETSC_SUCCESS); 514 } 515 516 static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type) 517 { 518 Mat_Composite *shell; 519 520 PetscFunctionBegin; 521 PetscCall(MatShellGetContext(mat, &shell)); 522 *type = shell->type; 523 PetscFunctionReturn(PETSC_SUCCESS); 524 } 525 526 /*@ 527 MatCompositeGetType - Returns type of composite. 528 529 Not Collective 530 531 Input Parameter: 532 . mat - the composite matrix 533 534 Output Parameter: 535 . type - type of composite 536 537 Level: advanced 538 539 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType` 540 @*/ 541 PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type) 542 { 543 PetscFunctionBegin; 544 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 545 PetscAssertPointer(type, 2); 546 PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type)); 547 PetscFunctionReturn(PETSC_SUCCESS); 548 } 549 550 static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str) 551 { 552 Mat_Composite *shell; 553 554 PetscFunctionBegin; 555 PetscCall(MatShellGetContext(mat, &shell)); 556 shell->structure = str; 557 PetscFunctionReturn(PETSC_SUCCESS); 558 } 559 560 /*@ 561 MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix. 562 563 Not Collective 564 565 Input Parameters: 566 + mat - the composite matrix 567 - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN` 568 569 Level: advanced 570 571 Note: 572 Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix. 573 574 .seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()` `MatCompositeGetMatStructure()`, `MATCOMPOSITE` 575 @*/ 576 PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str) 577 { 578 PetscFunctionBegin; 579 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 580 PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str)); 581 PetscFunctionReturn(PETSC_SUCCESS); 582 } 583 584 static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str) 585 { 586 Mat_Composite *shell; 587 588 PetscFunctionBegin; 589 PetscCall(MatShellGetContext(mat, &shell)); 590 *str = shell->structure; 591 PetscFunctionReturn(PETSC_SUCCESS); 592 } 593 594 /*@ 595 MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix. 596 597 Not Collective 598 599 Input Parameter: 600 . mat - the composite matrix 601 602 Output Parameter: 603 . str - structure of the matrices 604 605 Level: advanced 606 607 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE` 608 @*/ 609 PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str) 610 { 611 PetscFunctionBegin; 612 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 613 PetscAssertPointer(str, 2); 614 PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str)); 615 PetscFunctionReturn(PETSC_SUCCESS); 616 } 617 618 static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type) 619 { 620 Mat_Composite *shell; 621 622 PetscFunctionBegin; 623 PetscCall(MatShellGetContext(mat, &shell)); 624 shell->mergetype = type; 625 PetscFunctionReturn(PETSC_SUCCESS); 626 } 627 628 /*@ 629 MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`. 630 631 Logically Collective 632 633 Input Parameters: 634 + mat - the composite matrix 635 - type - `MAT_COMPOSITE_MERGE RIGHT` (default) to start merge from right with the first added matrix (mat[0]), 636 `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1]) 637 638 Level: advanced 639 640 Note: 641 The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed. 642 If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0]))) 643 otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0]. 644 645 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE` 646 @*/ 647 PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type) 648 { 649 PetscFunctionBegin; 650 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 651 PetscValidLogicalCollectiveEnum(mat, type, 2); 652 PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type)); 653 PetscFunctionReturn(PETSC_SUCCESS); 654 } 655 656 static PetscErrorCode MatCompositeMerge_Composite(Mat mat) 657 { 658 Mat_Composite *shell; 659 Mat_CompositeLink next, prev; 660 Mat tmat, newmat; 661 Vec left, right, dshift; 662 PetscScalar scale, shift; 663 PetscInt i; 664 665 PetscFunctionBegin; 666 PetscCall(MatShellGetContext(mat, &shell)); 667 next = shell->head; 668 prev = shell->tail; 669 PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()"); 670 PetscCheck(!((Mat_Shell *)mat->data)->zrows && !((Mat_Shell *)mat->data)->zcols, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatZeroRows() or MatZeroRowsColumns() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatZeroRows()/MatZeroRowsColumns() after the merge 671 PetscCheck(!((Mat_Shell *)mat->data)->axpy, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatAXPY() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatAXPY() after the merge 672 scale = ((Mat_Shell *)mat->data)->vscale; 673 shift = ((Mat_Shell *)mat->data)->vshift; 674 if (shell->type == MAT_COMPOSITE_ADDITIVE) { 675 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) { 676 i = 0; 677 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat)); 678 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++])); 679 while ((next = next->next)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i++] : 1.0), next->mat, shell->structure)); 680 } else { 681 i = shell->nmat - 1; 682 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat)); 683 if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--])); 684 while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i--] : 1.0), prev->mat, shell->structure)); 685 } 686 } else { 687 if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) { 688 PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat)); 689 while ((next = next->next)) { 690 PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat)); 691 PetscCall(MatDestroy(&tmat)); 692 tmat = newmat; 693 } 694 } else { 695 PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat)); 696 while ((prev = prev->prev)) { 697 PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat)); 698 PetscCall(MatDestroy(&tmat)); 699 tmat = newmat; 700 } 701 } 702 if (shell->scalings) { 703 for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i]; 704 } 705 } 706 707 if ((left = ((Mat_Shell *)mat->data)->left)) PetscCall(PetscObjectReference((PetscObject)left)); 708 if ((right = ((Mat_Shell *)mat->data)->right)) PetscCall(PetscObjectReference((PetscObject)right)); 709 if ((dshift = ((Mat_Shell *)mat->data)->dshift)) PetscCall(PetscObjectReference((PetscObject)dshift)); 710 711 PetscCall(MatHeaderReplace(mat, &tmat)); 712 713 PetscCall(MatDiagonalScale(mat, left, right)); 714 PetscCall(MatScale(mat, scale)); 715 PetscCall(MatShift(mat, shift)); 716 PetscCall(VecDestroy(&left)); 717 PetscCall(VecDestroy(&right)); 718 if (dshift) { 719 PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES)); 720 PetscCall(VecDestroy(&dshift)); 721 } 722 PetscFunctionReturn(PETSC_SUCCESS); 723 } 724 725 /*@ 726 MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix 727 by summing or computing the product of all the matrices inside the composite matrix. 728 729 Collective 730 731 Input Parameter: 732 . mat - the composite matrix 733 734 Options Database Keys: 735 + -mat_composite_merge - merge in `MatAssemblyEnd()` 736 - -mat_composite_merge_type - set merge direction 737 738 Level: advanced 739 740 Note: 741 The `MatType` of the resulting matrix will be the same as the `MatType` of the FIRST matrix in the composite matrix. 742 743 .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE` 744 @*/ 745 PetscErrorCode MatCompositeMerge(Mat mat) 746 { 747 PetscFunctionBegin; 748 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 749 PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat)); 750 PetscFunctionReturn(PETSC_SUCCESS); 751 } 752 753 static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat) 754 { 755 Mat_Composite *shell; 756 757 PetscFunctionBegin; 758 PetscCall(MatShellGetContext(mat, &shell)); 759 *nmat = shell->nmat; 760 PetscFunctionReturn(PETSC_SUCCESS); 761 } 762 763 /*@ 764 MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix. 765 766 Not Collective 767 768 Input Parameter: 769 . mat - the composite matrix 770 771 Output Parameter: 772 . nmat - number of matrices in the composite matrix 773 774 Level: advanced 775 776 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE` 777 @*/ 778 PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat) 779 { 780 PetscFunctionBegin; 781 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 782 PetscAssertPointer(nmat, 2); 783 PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat)); 784 PetscFunctionReturn(PETSC_SUCCESS); 785 } 786 787 static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai) 788 { 789 Mat_Composite *shell; 790 Mat_CompositeLink ilink; 791 PetscInt k; 792 793 PetscFunctionBegin; 794 PetscCall(MatShellGetContext(mat, &shell)); 795 PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat); 796 ilink = shell->head; 797 for (k = 0; k < i; k++) ilink = ilink->next; 798 *Ai = ilink->mat; 799 PetscFunctionReturn(PETSC_SUCCESS); 800 } 801 802 /*@ 803 MatCompositeGetMat - Returns the ith matrix from the composite matrix. 804 805 Logically Collective 806 807 Input Parameters: 808 + mat - the composite matrix 809 - i - the number of requested matrix 810 811 Output Parameter: 812 . Ai - ith matrix in composite 813 814 Level: advanced 815 816 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE` 817 @*/ 818 PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai) 819 { 820 PetscFunctionBegin; 821 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 822 PetscValidLogicalCollectiveInt(mat, i, 2); 823 PetscAssertPointer(Ai, 3); 824 PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai)); 825 PetscFunctionReturn(PETSC_SUCCESS); 826 } 827 828 static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings) 829 { 830 Mat_Composite *shell; 831 PetscInt nmat; 832 833 PetscFunctionBegin; 834 PetscCall(MatShellGetContext(mat, &shell)); 835 PetscCall(MatCompositeGetNumberMat(mat, &nmat)); 836 if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings)); 837 PetscCall(PetscArraycpy(shell->scalings, scalings, nmat)); 838 PetscFunctionReturn(PETSC_SUCCESS); 839 } 840 841 /*@ 842 MatCompositeSetScalings - Sets separate scaling factors for component matrices. 843 844 Logically Collective 845 846 Input Parameters: 847 + mat - the composite matrix 848 - scalings - array of scaling factors with scalings[i] being factor of i-th matrix, for i in [0, nmat) 849 850 Level: advanced 851 852 .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE` 853 @*/ 854 PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings) 855 { 856 PetscFunctionBegin; 857 PetscValidHeaderSpecific(mat, MAT_CLASSID, 1); 858 PetscAssertPointer(scalings, 2); 859 PetscValidLogicalCollectiveScalar(mat, *scalings, 2); 860 PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings)); 861 PetscFunctionReturn(PETSC_SUCCESS); 862 } 863 864 /*MC 865 MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices. 866 The matrices need to have a correct size and parallel layout for the sum or product to be valid. 867 868 Level: advanced 869 870 Note: 871 To use the product of the matrices call `MatCompositeSetType`(mat,`MAT_COMPOSITE_MULTIPLICATIVE`); 872 873 Developer Notes: 874 This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code 875 876 Users can not call `MatShellSetOperation()` operations on this class, there is some error checking for that incorrect usage 877 878 .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`, 879 `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()` 880 M*/ 881 882 PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A) 883 { 884 Mat_Composite *b; 885 886 PetscFunctionBegin; 887 PetscCall(PetscNew(&b)); 888 889 b->type = MAT_COMPOSITE_ADDITIVE; 890 b->nmat = 0; 891 b->merge = PETSC_FALSE; 892 b->mergetype = MAT_COMPOSITE_MERGE_RIGHT; 893 b->structure = DIFFERENT_NONZERO_PATTERN; 894 b->merge_mvctx = PETSC_TRUE; 895 896 PetscCall(MatSetType(A, MATSHELL)); 897 PetscCall(MatShellSetContext(A, b)); 898 PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (void (*)(void))MatDestroy_Composite)); 899 PetscCall(MatShellSetOperation(A, MATOP_MULT, (void (*)(void))MatMult_Composite)); 900 PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite)); 901 PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite)); 902 PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (void (*)(void))MatAssemblyEnd_Composite)); 903 PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (void (*)(void))MatSetFromOptions_Composite)); 904 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite)); 905 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite)); 906 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite)); 907 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite)); 908 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite)); 909 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite)); 910 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite)); 911 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite)); 912 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite)); 913 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite)); 914 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable)); 915 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable)); 916 PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable)); 917 PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE)); 918 PetscFunctionReturn(PETSC_SUCCESS); 919 } 920