1 2 /* 3 Defines matrix-matrix product routines for pairs of MPIAIJ matrices 4 C = A^T * B 5 The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense(). 6 */ 7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 8 #include <../src/mat/impls/aij/mpi/mpiaij.h> 9 #include <../src/mat/impls/dense/mpi/mpidense.h> 10 11 PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) { 12 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 13 14 PetscFunctionBegin; 15 PetscCall(MatDestroy(&atb->mA)); 16 PetscCall(VecDestroy(&atb->bt)); 17 PetscCall(VecDestroy(&atb->ct)); 18 PetscCall(PetscFree(atb)); 19 PetscFunctionReturn(0); 20 } 21 22 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat); 23 24 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C) { 25 Mat_MatTransMatMult *atb; 26 PetscBool cisdense; 27 28 PetscFunctionBegin; 29 MatCheckProduct(C, 4); 30 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 31 32 /* create output dense matrix C = A^T*B */ 33 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N)); 34 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "")); 35 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 36 PetscCall(MatSetUp(C)); 37 38 /* create additional data structure for the product */ 39 PetscCall(PetscNew(&atb)); 40 if (B->cmap->N) { 41 PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA)); 42 if (!atb->mA->assembled) { 43 PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY)); 44 PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY)); 45 } 46 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 47 } 48 C->product->data = atb; 49 C->product->destroy = MatDestroy_MPIDense_MatTransMatMult; 50 51 C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense; 52 PetscFunctionReturn(0); 53 } 54 55 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C) { 56 const PetscScalar *Barray, *ctarray; 57 PetscScalar *Carray, *btarray; 58 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc; 59 Mat_MatTransMatMult *atb; 60 Vec bt, ct; 61 62 PetscFunctionBegin; 63 MatCheckProduct(C, 3); 64 atb = (Mat_MatTransMatMult *)C->product->data; 65 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 66 if (!BN) { 67 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 68 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 69 PetscFunctionReturn(0); 70 } 71 bt = atb->bt; 72 ct = atb->ct; 73 74 /* transpose local array of B, then copy it to vector bt */ 75 PetscCall(MatDenseGetArrayRead(B, &Barray)); 76 PetscCall(MatDenseGetLDA(B, &ldb)); 77 PetscCall(VecGetArray(bt, &btarray)); 78 for (j = 0; j < BN; j++) 79 for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i]; 80 PetscCall(VecRestoreArray(bt, &btarray)); 81 PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 82 83 /* compute ct = mA^T * cb */ 84 PetscCall(MatMultTranspose(atb->mA, bt, ct)); 85 86 /* transpose local array of ct to matrix C */ 87 PetscCall(MatDenseGetArray(C, &Carray)); 88 PetscCall(MatDenseGetLDA(C, &ldc)); 89 PetscCall(VecGetArrayRead(ct, &ctarray)); 90 for (j = 0; j < BN; j++) 91 for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j]; 92 PetscCall(VecRestoreArrayRead(ct, &ctarray)); 93 PetscCall(MatDenseRestoreArray(C, &Carray)); 94 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 95 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 96 PetscFunctionReturn(0); 97 } 98