1 /* 2 Defines matrix-matrix product routines for pairs of MPIAIJ matrices 3 C = A^T * B 4 The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense(). 5 */ 6 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 7 #include <../src/mat/impls/aij/mpi/mpiaij.h> 8 #include <../src/mat/impls/dense/mpi/mpidense.h> 9 10 static PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) 11 { 12 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 13 14 PetscFunctionBegin; 15 PetscCall(MatDestroy(&atb->mA)); 16 PetscCall(VecDestroy(&atb->bt)); 17 PetscCall(VecDestroy(&atb->ct)); 18 PetscCall(PetscFree(atb)); 19 PetscFunctionReturn(PETSC_SUCCESS); 20 } 21 22 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat); 23 24 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C) 25 { 26 Mat_MatTransMatMult *atb; 27 PetscBool cisdense; 28 29 PetscFunctionBegin; 30 MatCheckProduct(C, 4); 31 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 32 33 /* create output dense matrix C = A^T*B */ 34 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N)); 35 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "")); 36 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 37 PetscCall(MatSetUp(C)); 38 39 /* create additional data structure for the product */ 40 PetscCall(PetscNew(&atb)); 41 if (B->cmap->N) { 42 PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA)); 43 if (!atb->mA->assembled) { 44 PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY)); 45 PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY)); 46 } 47 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 48 } 49 C->product->data = atb; 50 C->product->destroy = MatDestroy_MPIDense_MatTransMatMult; 51 52 C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense; 53 PetscFunctionReturn(PETSC_SUCCESS); 54 } 55 56 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C) 57 { 58 const PetscScalar *Barray, *ctarray; 59 PetscScalar *Carray, *btarray; 60 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc; 61 Mat_MatTransMatMult *atb; 62 Vec bt, ct; 63 64 PetscFunctionBegin; 65 MatCheckProduct(C, 3); 66 atb = (Mat_MatTransMatMult *)C->product->data; 67 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 68 if (!BN) { 69 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 70 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 71 PetscFunctionReturn(PETSC_SUCCESS); 72 } 73 bt = atb->bt; 74 ct = atb->ct; 75 76 /* transpose local array of B, then copy it to vector bt */ 77 PetscCall(MatDenseGetArrayRead(B, &Barray)); 78 PetscCall(MatDenseGetLDA(B, &ldb)); 79 PetscCall(VecGetArray(bt, &btarray)); 80 for (j = 0; j < BN; j++) 81 for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i]; 82 PetscCall(VecRestoreArray(bt, &btarray)); 83 PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 84 85 /* compute ct = mA^T * cb */ 86 PetscCall(MatMultTranspose(atb->mA, bt, ct)); 87 88 /* transpose local array of ct to matrix C */ 89 PetscCall(MatDenseGetArray(C, &Carray)); 90 PetscCall(MatDenseGetLDA(C, &ldc)); 91 PetscCall(VecGetArrayRead(ct, &ctarray)); 92 for (j = 0; j < BN; j++) 93 for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j]; 94 PetscCall(VecRestoreArrayRead(ct, &ctarray)); 95 PetscCall(MatDenseRestoreArray(C, &Carray)); 96 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 97 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 98 PetscFunctionReturn(PETSC_SUCCESS); 99 } 100