1 2 /* 3 Defines matrix-matrix product routines for pairs of MPIAIJ matrices 4 C = A^T * B 5 The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense(). 6 */ 7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 8 #include <../src/mat/impls/aij/mpi/mpiaij.h> 9 #include <../src/mat/impls/dense/mpi/mpidense.h> 10 11 PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) 12 { 13 Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 14 15 PetscFunctionBegin; 16 PetscCall(MatDestroy(&atb->mA)); 17 PetscCall(VecDestroy(&atb->bt)); 18 PetscCall(VecDestroy(&atb->ct)); 19 PetscCall(PetscFree(atb)); 20 PetscFunctionReturn(0); 21 } 22 23 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat); 24 25 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C) 26 { 27 Mat_MatTransMatMult *atb; 28 PetscBool cisdense; 29 30 PetscFunctionBegin; 31 MatCheckProduct(C, 4); 32 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 33 34 /* create output dense matrix C = A^T*B */ 35 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N)); 36 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "")); 37 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 38 PetscCall(MatSetUp(C)); 39 40 /* create additional data structure for the product */ 41 PetscCall(PetscNew(&atb)); 42 if (B->cmap->N) { 43 PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA)); 44 if (!atb->mA->assembled) { 45 PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY)); 46 PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY)); 47 } 48 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 49 } 50 C->product->data = atb; 51 C->product->destroy = MatDestroy_MPIDense_MatTransMatMult; 52 53 C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense; 54 PetscFunctionReturn(0); 55 } 56 57 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C) 58 { 59 const PetscScalar *Barray, *ctarray; 60 PetscScalar *Carray, *btarray; 61 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc; 62 Mat_MatTransMatMult *atb; 63 Vec bt, ct; 64 65 PetscFunctionBegin; 66 MatCheckProduct(C, 3); 67 atb = (Mat_MatTransMatMult *)C->product->data; 68 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 69 if (!BN) { 70 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 71 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 72 PetscFunctionReturn(0); 73 } 74 bt = atb->bt; 75 ct = atb->ct; 76 77 /* transpose local array of B, then copy it to vector bt */ 78 PetscCall(MatDenseGetArrayRead(B, &Barray)); 79 PetscCall(MatDenseGetLDA(B, &ldb)); 80 PetscCall(VecGetArray(bt, &btarray)); 81 for (j = 0; j < BN; j++) 82 for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i]; 83 PetscCall(VecRestoreArray(bt, &btarray)); 84 PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 85 86 /* compute ct = mA^T * cb */ 87 PetscCall(MatMultTranspose(atb->mA, bt, ct)); 88 89 /* transpose local array of ct to matrix C */ 90 PetscCall(MatDenseGetArray(C, &Carray)); 91 PetscCall(MatDenseGetLDA(C, &ldc)); 92 PetscCall(VecGetArrayRead(ct, &ctarray)); 93 for (j = 0; j < BN; j++) 94 for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j]; 95 PetscCall(VecRestoreArrayRead(ct, &ctarray)); 96 PetscCall(MatDenseRestoreArray(C, &Carray)); 97 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 98 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 99 PetscFunctionReturn(0); 100 } 101