xref: /petsc/src/mat/impls/aij/mpi/mpimattransposematmult.c (revision 445ca0904fedcd9c19e8b5d649a2e70cfb1cd372)
18949adfdSHong Zhang 
/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
          C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
78949adfdSHong Zhang #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
88949adfdSHong Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h>
98949adfdSHong Zhang #include <../src/mat/impls/dense/mpi/mpidense.h>
108949adfdSHong Zhang 
116718818eSStefano Zampini PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
128949adfdSHong Zhang {
138949adfdSHong Zhang   PetscErrorCode      ierr;
146718818eSStefano Zampini   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult*)data;
158949adfdSHong Zhang 
168949adfdSHong Zhang   PetscFunctionBegin;
178949adfdSHong Zhang   ierr = MatDestroy(&atb->mA);CHKERRQ(ierr);
188949adfdSHong Zhang   ierr = VecDestroy(&atb->bt);CHKERRQ(ierr);
198949adfdSHong Zhang   ierr = VecDestroy(&atb->ct);CHKERRQ(ierr);
208949adfdSHong Zhang   ierr = PetscFree(atb);CHKERRQ(ierr);
218949adfdSHong Zhang   PetscFunctionReturn(0);
228949adfdSHong Zhang }
238949adfdSHong Zhang 
246718818eSStefano Zampini static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat,Mat,Mat);
256718818eSStefano Zampini 
266718818eSStefano Zampini PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)
278949adfdSHong Zhang {
288949adfdSHong Zhang   PetscErrorCode      ierr;
298949adfdSHong Zhang   Mat_MatTransMatMult *atb;
306718818eSStefano Zampini   PetscBool           cisdense;
318949adfdSHong Zhang 
328949adfdSHong Zhang   PetscFunctionBegin;
336718818eSStefano Zampini   MatCheckProduct(C,4);
346718818eSStefano Zampini   if (C->product->data) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty");
358949adfdSHong Zhang 
368949adfdSHong Zhang   /* create output dense matrix C = A^T*B */
376718818eSStefano Zampini   ierr = MatSetSizes(C,A->cmap->n,B->cmap->n,A->cmap->N,B->cmap->N);CHKERRQ(ierr);
386718818eSStefano Zampini   ierr = PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATMPIDENSE,MATMPIDENSECUDA,"");CHKERRQ(ierr);
396718818eSStefano Zampini   if (!cisdense) {
406718818eSStefano Zampini     ierr = MatSetType(C,((PetscObject)B)->type_name);CHKERRQ(ierr);
416718818eSStefano Zampini   }
424eeee6adSStefano Zampini   ierr = MatSetUp(C);CHKERRQ(ierr);
438949adfdSHong Zhang 
446718818eSStefano Zampini   /* create additional data structure for the product */
456718818eSStefano Zampini   ierr = PetscNew(&atb);CHKERRQ(ierr);
466718818eSStefano Zampini   if (B->cmap->N) {
476718818eSStefano Zampini     ierr = MatCreateMAIJ(A,B->cmap->N,&atb->mA);CHKERRQ(ierr);
48*445ca090SPierre Jolivet     if (!atb->mA->assembled) {
49*445ca090SPierre Jolivet       ierr = MatAssemblyBegin(atb->mA,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
50*445ca090SPierre Jolivet       ierr = MatAssemblyEnd(atb->mA,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
51*445ca090SPierre Jolivet     }
526718818eSStefano Zampini     ierr = MatCreateVecs(atb->mA,&atb->ct,&atb->bt);CHKERRQ(ierr);
536718818eSStefano Zampini   }
546718818eSStefano Zampini   C->product->data    = atb;
556718818eSStefano Zampini   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
568949adfdSHong Zhang 
574222ddf1SHong Zhang   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
588949adfdSHong Zhang   PetscFunctionReturn(0);
598949adfdSHong Zhang }
608949adfdSHong Zhang 
616718818eSStefano Zampini static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)
628949adfdSHong Zhang {
638949adfdSHong Zhang   PetscErrorCode      ierr;
641683a169SBarry Smith   const PetscScalar   *Barray,*ctarray;
651683a169SBarry Smith   PetscScalar         *Carray,*btarray;
66905b3b74SHong Zhang   Mat_MPIDense        *b=(Mat_MPIDense*)B->data,*c=(Mat_MPIDense*)C->data;
67905b3b74SHong Zhang   Mat_SeqDense        *bseq=(Mat_SeqDense*)(b->A)->data,*cseq=(Mat_SeqDense*)(c->A)->data;
68905b3b74SHong Zhang   PetscInt            i,j,m=A->rmap->n,n=A->cmap->n,ldb=bseq->lda,BN=B->cmap->N,ldc=cseq->lda;
696718818eSStefano Zampini   Mat_MatTransMatMult *atb;
706718818eSStefano Zampini   Vec                 bt,ct;
718949adfdSHong Zhang 
728949adfdSHong Zhang   PetscFunctionBegin;
736718818eSStefano Zampini   MatCheckProduct(C,3);
746718818eSStefano Zampini   atb=(Mat_MatTransMatMult *)C->product->data;
756718818eSStefano Zampini   if (!atb) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct");
766718818eSStefano Zampini   if (!BN) {
776718818eSStefano Zampini     ierr = MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
786718818eSStefano Zampini     ierr = MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
796718818eSStefano Zampini     PetscFunctionReturn(0);
806718818eSStefano Zampini   }
816718818eSStefano Zampini   bt = atb->bt;
826718818eSStefano Zampini   ct = atb->ct;
838949adfdSHong Zhang   /* transpose local arry of B, then copy it to vector bt */
841683a169SBarry Smith   ierr = MatDenseGetArrayRead(B,&Barray);CHKERRQ(ierr);
858949adfdSHong Zhang   ierr = VecGetArray(bt,&btarray);CHKERRQ(ierr);
868949adfdSHong Zhang 
878949adfdSHong Zhang   for (j=0; j<BN; j++) {
88905b3b74SHong Zhang     for (i=0; i<m; i++) btarray[i*BN + j] = Barray[j*ldb + i];
898949adfdSHong Zhang   }
908949adfdSHong Zhang   ierr = VecRestoreArray(bt,&btarray);CHKERRQ(ierr);
911683a169SBarry Smith   ierr = MatDenseRestoreArrayRead(B,&Barray);CHKERRQ(ierr);
928949adfdSHong Zhang 
938949adfdSHong Zhang   /* compute ct = mA^T * cb */
948949adfdSHong Zhang   ierr = MatMultTranspose(atb->mA,bt,ct);CHKERRQ(ierr);
958949adfdSHong Zhang 
96905b3b74SHong Zhang   /* transpose local array of ct to matrix C */
978949adfdSHong Zhang   ierr = MatDenseGetArray(C,&Carray);CHKERRQ(ierr);
981683a169SBarry Smith   ierr = VecGetArrayRead(ct,&ctarray);CHKERRQ(ierr);
99905b3b74SHong Zhang 
1008949adfdSHong Zhang   for (j=0; j<BN; j++) {
101905b3b74SHong Zhang     for (i=0; i<n; i++) Carray[j*ldc + i] = ctarray[i*BN + j];
1028949adfdSHong Zhang   }
1031683a169SBarry Smith   ierr = VecRestoreArrayRead(ct,&ctarray);CHKERRQ(ierr);
1048949adfdSHong Zhang   ierr = MatDenseRestoreArray(C,&Carray);CHKERRQ(ierr);
1058949adfdSHong Zhang   ierr = MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
1068949adfdSHong Zhang   ierr = MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
1078949adfdSHong Zhang   PetscFunctionReturn(0);
1088949adfdSHong Zhang }
109