1 2 /* 3 Defines matrix-matrix product routines for pairs of MPIAIJ matrices 4 C = A^T * B 5 The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense(). 6 */ 7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 8 #include <../src/mat/impls/aij/mpi/mpiaij.h> 9 #include <../src/mat/impls/dense/mpi/mpidense.h> 10 11 PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) 12 { 13 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult*)data; 14 15 PetscFunctionBegin; 16 PetscCall(MatDestroy(&atb->mA)); 17 PetscCall(VecDestroy(&atb->bt)); 18 PetscCall(VecDestroy(&atb->ct)); 19 PetscCall(PetscFree(atb)); 20 PetscFunctionReturn(0); 21 } 22 23 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat,Mat,Mat); 24 25 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C) 26 { 27 Mat_MatTransMatMult *atb; 28 PetscBool cisdense; 29 30 PetscFunctionBegin; 31 MatCheckProduct(C,4); 32 PetscCheck(!C->product->data,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty"); 33 34 /* create output dense matrix C = A^T*B */ 35 PetscCall(MatSetSizes(C,A->cmap->n,B->cmap->n,A->cmap->N,B->cmap->N)); 36 PetscCall(PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATMPIDENSE,MATMPIDENSECUDA,"")); 37 if (!cisdense) { 38 PetscCall(MatSetType(C,((PetscObject)B)->type_name)); 39 } 40 PetscCall(MatSetUp(C)); 41 42 /* create additional data structure for the product */ 43 PetscCall(PetscNew(&atb)); 44 if (B->cmap->N) { 45 PetscCall(MatCreateMAIJ(A,B->cmap->N,&atb->mA)); 46 if (!atb->mA->assembled) { 47 PetscCall(MatAssemblyBegin(atb->mA,MAT_FINAL_ASSEMBLY)); 48 PetscCall(MatAssemblyEnd(atb->mA,MAT_FINAL_ASSEMBLY)); 49 } 50 PetscCall(MatCreateVecs(atb->mA,&atb->ct,&atb->bt)); 51 } 52 C->product->data = atb; 53 C->product->destroy = MatDestroy_MPIDense_MatTransMatMult; 54 55 C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense; 56 PetscFunctionReturn(0); 57 } 58 59 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C) 60 { 61 const PetscScalar *Barray,*ctarray; 62 PetscScalar *Carray,*btarray; 63 PetscInt i,j,m=A->rmap->n,n=A->cmap->n,ldb,BN=B->cmap->N,ldc; 64 Mat_MatTransMatMult *atb; 65 Vec bt,ct; 66 67 PetscFunctionBegin; 68 MatCheckProduct(C,3); 69 atb = (Mat_MatTransMatMult *)C->product->data; 70 PetscCheck(atb,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct"); 71 if (!BN) { 72 PetscCall(MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY)); 73 PetscCall(MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY)); 74 PetscFunctionReturn(0); 75 } 76 bt = atb->bt; 77 ct = atb->ct; 78 79 /* transpose local array of B, then copy it to vector bt */ 80 PetscCall(MatDenseGetArrayRead(B,&Barray)); 81 PetscCall(MatDenseGetLDA(B,&ldb)); 82 PetscCall(VecGetArray(bt,&btarray)); 83 for (j=0; j<BN; j++) 84 for (i=0; i<m; i++) 85 btarray[i*BN + j] = Barray[j*ldb + i]; 86 PetscCall(VecRestoreArray(bt,&btarray)); 87 PetscCall(MatDenseRestoreArrayRead(B,&Barray)); 88 89 /* compute ct = mA^T * cb */ 90 PetscCall(MatMultTranspose(atb->mA,bt,ct)); 91 92 /* transpose local array of ct to matrix C */ 93 PetscCall(MatDenseGetArray(C,&Carray)); 94 PetscCall(MatDenseGetLDA(C,&ldc)); 95 PetscCall(VecGetArrayRead(ct,&ctarray)); 96 for (j=0; j<BN; j++) 97 for (i=0; i<n; i++) 98 Carray[j*ldc + i] = ctarray[i*BN + j]; 99 PetscCall(VecRestoreArrayRead(ct,&ctarray)); 100 PetscCall(MatDenseRestoreArray(C,&Carray)); 101 PetscCall(MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY)); 102 PetscCall(MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY)); 103 PetscFunctionReturn(0); 104 } 105