/*
  Defines matrix-matrix product routines C = A^T * B for an MPIAIJ matrix A and an MPI dense matrix B.
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
#include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
#include <../src/mat/impls/aij/mpi/mpiaij.h>
#include <../src/mat/impls/dense/mpi/mpidense.h>

/*
  MatProductCtxDestroy_MPIDense_MatTransMatMult - Releases the product context installed by
  MatTransposeMatMultSymbolic_MPIAIJ_MPIDense().

  Input Parameter:
. data - address of the pointer to the MatProductCtx_MatTransMatMult context

  Notes:
  Destroys the MAIJ matrix and the two work vectors held by the context, then frees the context.
  The pointer stored in *data is read but never written by this routine.
*/
static PetscErrorCode MatProductCtxDestroy_MPIDense_MatTransMatMult(void **data)
{
  MatProductCtx_MatTransMatMult *ctx = (MatProductCtx_MatTransMatMult *)*data;

  PetscFunctionBegin;
  PetscCall(MatDestroy(&ctx->mA));
  PetscCall(VecDestroy(&ctx->bt));
  PetscCall(VecDestroy(&ctx->ct));
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);

/*
  MatTransposeMatMultSymbolic_MPIAIJ_MPIDense - Sets up C and the product context for C = A^T * B.

  Input Parameters:
+ A    - the left operand (transposed in the product)
. B    - the right operand
- fill - not used by this routine

  Output Parameter:
. C - the product matrix; it is sized, typed (when not already MATMPIDENSE or MATMPIDENSECUDA),
      set up, and given the product context, its destructor, and the numeric routine

  Notes:
  C gets A's column layout sizes for its rows and B's column layout sizes for its columns.
  The MAIJ matrix and work vectors are created only when B has a nonzero global column count;
  otherwise the context is left with whatever PetscNew() initialized its fields to.
*/
PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
{
  MatProductCtx_MatTransMatMult *ctx;
  PetscBool                      isdense;
  PetscInt                       nrhs;

  PetscFunctionBegin;
  MatCheckProduct(C, 4);
  PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");

  /* shape and type the dense result C = A^T*B */
  PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
  PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &isdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
  if (!isdense) {
    /* C is not one of the recognized dense types: give it the type of B */
    PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
  }
  PetscCall(MatSetUp(C));

  /* build the context reused by every numeric product */
  PetscCall(PetscNew(&ctx));
  nrhs = B->cmap->N;
  if (nrhs != 0) {
    Mat maij;

    PetscCall(MatCreateMAIJ(A, nrhs, &ctx->mA));
    maij = ctx->mA;
    if (!maij->assembled) {
      PetscCall(MatAssemblyBegin(maij, MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(maij, MAT_FINAL_ASSEMBLY));
    }
    /* work vectors: ct receives the first (right) vector, bt the second (left) one */
    PetscCall(MatCreateVecs(maij, &ctx->ct, &ctx->bt));
  }

  C->product->data                = ctx;
  C->product->destroy             = MatProductCtxDestroy_MPIDense_MatTransMatMult;
  C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatTransposeMatMultNumeric_MPIAIJ_MPIDense - Computes the entries of C = A^T * B using the context
  built by MatTransposeMatMultSymbolic_MPIAIJ_MPIDense().

  Input Parameters:
+ A - the left operand; only its local row and column counts are read here, the product itself
      goes through the MAIJ matrix stored in the context
- B - the right operand, read through its dense array and leading dimension

  Output Parameter:
. C - the product matrix, written through its dense array and then assembled

  Notes:
  With BN the global column count of B, the local block of B is copied into the work vector bt
  with entry (i,j) stored at i*BN + j. Then ct = mA^T * bt is formed, and ct is copied into C
  with the same indexing.
  When BN is zero, C is only assembled.
*/
static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
{
  MatProductCtx_MatTransMatMult *ctx;
  const PetscInt                 nlocrowsA = A->rmap->n, nloccolsA = A->cmap->n, nrhs = B->cmap->N;

  PetscFunctionBegin;
  MatCheckProduct(C, 3);
  ctx = (MatProductCtx_MatTransMatMult *)C->product->data;
  PetscCheck(ctx, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");

  if (nrhs != 0) {
    const PetscScalar *bvals, *ctvals;
    PetscScalar       *cvals, *btvals;
    PetscInt           ldb, ldc;

    /* pack the local block of B into bt, transposing from column-by-column storage to i*nrhs + j */
    PetscCall(MatDenseGetArrayRead(B, &bvals));
    PetscCall(MatDenseGetLDA(B, &ldb));
    PetscCall(VecGetArray(ctx->bt, &btvals));
    for (PetscInt col = 0; col < nrhs; col++) {
      for (PetscInt row = 0; row < nlocrowsA; row++) btvals[row * nrhs + col] = bvals[col * ldb + row];
    }
    PetscCall(VecRestoreArray(ctx->bt, &btvals));
    PetscCall(MatDenseRestoreArrayRead(B, &bvals));

    /* ct = mA^T * bt */
    PetscCall(MatMultTranspose(ctx->mA, ctx->bt, ctx->ct));

    /* unpack ct into the local array of C, undoing the i*nrhs + j layout */
    PetscCall(MatDenseGetArray(C, &cvals));
    PetscCall(MatDenseGetLDA(C, &ldc));
    PetscCall(VecGetArrayRead(ctx->ct, &ctvals));
    for (PetscInt col = 0; col < nrhs; col++) {
      for (PetscInt row = 0; row < nloccolsA; row++) cvals[col * ldc + row] = ctvals[row * nrhs + col];
    }
    PetscCall(VecRestoreArrayRead(ctx->ct, &ctvals));
    PetscCall(MatDenseRestoreArray(C, &cvals));
    PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
  }

  /* both the empty (nrhs == 0) and the computed case finish by assembling C */
  PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}