xref: /petsc/src/mat/impls/aij/mpi/mpimattransposematmult.c (revision 4e8208cbcbc709572b8abe32f33c78b69c819375)
18949adfdSHong Zhang /*
28949adfdSHong Zhang   Defines matrix-matrix product routines for pairs of MPIAIJ matrices
38949adfdSHong Zhang           C = A^T * B
48949adfdSHong Zhang   The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
58949adfdSHong Zhang */
68949adfdSHong Zhang #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
78949adfdSHong Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
88949adfdSHong Zhang #include <../src/mat/impls/dense/mpi/mpidense.h>
98949adfdSHong Zhang 
MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)10*2a8381b2SBarry Smith static PetscErrorCode MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)
11d71ae5a4SJacob Faibussowitsch {
12cc1eb50dSBarry Smith   MatProductCtx_MatTransMatMult *atb = *(MatProductCtx_MatTransMatMult **)data;
138949adfdSHong Zhang 
148949adfdSHong Zhang   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&atb->mA));
169566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&atb->bt));
179566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&atb->ct));
189566063dSJacob Faibussowitsch   PetscCall(PetscFree(atb));
193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
208949adfdSHong Zhang }
218949adfdSHong Zhang 
226718818eSStefano Zampini static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
236718818eSStefano Zampini 
MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)24d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
25d71ae5a4SJacob Faibussowitsch {
26cc1eb50dSBarry Smith   MatProductCtx_MatTransMatMult *atb;
276718818eSStefano Zampini   PetscBool                      cisdense;
288949adfdSHong Zhang 
298949adfdSHong Zhang   PetscFunctionBegin;
306718818eSStefano Zampini   MatCheckProduct(C, 4);
3128b400f6SJacob Faibussowitsch   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
328949adfdSHong Zhang 
338949adfdSHong Zhang   /* create output dense matrix C = A^T*B */
349566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
359566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
3648a46eb9SPierre Jolivet   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
379566063dSJacob Faibussowitsch   PetscCall(MatSetUp(C));
388949adfdSHong Zhang 
396718818eSStefano Zampini   /* create additional data structure for the product */
409566063dSJacob Faibussowitsch   PetscCall(PetscNew(&atb));
416718818eSStefano Zampini   if (B->cmap->N) {
429566063dSJacob Faibussowitsch     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
43445ca090SPierre Jolivet     if (!atb->mA->assembled) {
449566063dSJacob Faibussowitsch       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
459566063dSJacob Faibussowitsch       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
46445ca090SPierre Jolivet     }
479566063dSJacob Faibussowitsch     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
486718818eSStefano Zampini   }
496718818eSStefano Zampini   C->product->data    = atb;
50cc1eb50dSBarry Smith   C->product->destroy = MatProductCtxDestroy_MPIDense_MatTransMatMult;
518949adfdSHong Zhang 
524222ddf1SHong Zhang   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
548949adfdSHong Zhang }
558949adfdSHong Zhang 
MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)56d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
57d71ae5a4SJacob Faibussowitsch {
581683a169SBarry Smith   const PetscScalar             *Barray, *ctarray;
591683a169SBarry Smith   PetscScalar                   *Carray, *btarray;
60b45e3bf4SStefano Zampini   PetscInt                       i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
61cc1eb50dSBarry Smith   MatProductCtx_MatTransMatMult *atb;
626718818eSStefano Zampini   Vec                            bt, ct;
638949adfdSHong Zhang 
648949adfdSHong Zhang   PetscFunctionBegin;
656718818eSStefano Zampini   MatCheckProduct(C, 3);
66cc1eb50dSBarry Smith   atb = (MatProductCtx_MatTransMatMult *)C->product->data;
6708401ef6SPierre Jolivet   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
686718818eSStefano Zampini   if (!BN) {
699566063dSJacob Faibussowitsch     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
709566063dSJacob Faibussowitsch     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
713ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
726718818eSStefano Zampini   }
736718818eSStefano Zampini   bt = atb->bt;
746718818eSStefano Zampini   ct = atb->ct;
758949adfdSHong Zhang 
76b45e3bf4SStefano Zampini   /* transpose local array of B, then copy it to vector bt */
779566063dSJacob Faibussowitsch   PetscCall(MatDenseGetArrayRead(B, &Barray));
789566063dSJacob Faibussowitsch   PetscCall(MatDenseGetLDA(B, &ldb));
799566063dSJacob Faibussowitsch   PetscCall(VecGetArray(bt, &btarray));
80b45e3bf4SStefano Zampini   for (j = 0; j < BN; j++)
819371c9d4SSatish Balay     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
829566063dSJacob Faibussowitsch   PetscCall(VecRestoreArray(bt, &btarray));
839566063dSJacob Faibussowitsch   PetscCall(MatDenseRestoreArrayRead(B, &Barray));
848949adfdSHong Zhang 
858949adfdSHong Zhang   /* compute ct = mA^T * cb */
869566063dSJacob Faibussowitsch   PetscCall(MatMultTranspose(atb->mA, bt, ct));
878949adfdSHong Zhang 
88905b3b74SHong Zhang   /* transpose local array of ct to matrix C */
899566063dSJacob Faibussowitsch   PetscCall(MatDenseGetArray(C, &Carray));
909566063dSJacob Faibussowitsch   PetscCall(MatDenseGetLDA(C, &ldc));
919566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(ct, &ctarray));
92b45e3bf4SStefano Zampini   for (j = 0; j < BN; j++)
939371c9d4SSatish Balay     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
949566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayRead(ct, &ctarray));
959566063dSJacob Faibussowitsch   PetscCall(MatDenseRestoreArray(C, &Carray));
9667af85e8SPierre Jolivet   PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
979566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
989566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1008949adfdSHong Zhang }
101