xref: /petsc/src/mat/impls/aij/mpi/mpimattransposematmult.c (revision d5b43468fb8780a8feea140ccd6fa3e6a50411cc)
1 
2 /*
3   Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
4           C = A^T * B
5   The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
6 */
7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
8 #include <../src/mat/impls/aij/mpi/mpiaij.h>
9 #include <../src/mat/impls/dense/mpi/mpidense.h>
10 
11 PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
12 {
13   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
14 
15   PetscFunctionBegin;
16   PetscCall(MatDestroy(&atb->mA));
17   PetscCall(VecDestroy(&atb->bt));
18   PetscCall(VecDestroy(&atb->ct));
19   PetscCall(PetscFree(atb));
20   PetscFunctionReturn(0);
21 }
22 
23 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
24 
25 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
26 {
27   Mat_MatTransMatMult *atb;
28   PetscBool            cisdense;
29 
30   PetscFunctionBegin;
31   MatCheckProduct(C, 4);
32   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
33 
34   /* create output dense matrix C = A^T*B */
35   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
36   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
37   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
38   PetscCall(MatSetUp(C));
39 
40   /* create additional data structure for the product */
41   PetscCall(PetscNew(&atb));
42   if (B->cmap->N) {
43     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
44     if (!atb->mA->assembled) {
45       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
46       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
47     }
48     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
49   }
50   C->product->data    = atb;
51   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
52 
53   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
54   PetscFunctionReturn(0);
55 }
56 
57 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
58 {
59   const PetscScalar   *Barray, *ctarray;
60   PetscScalar         *Carray, *btarray;
61   PetscInt             i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
62   Mat_MatTransMatMult *atb;
63   Vec                  bt, ct;
64 
65   PetscFunctionBegin;
66   MatCheckProduct(C, 3);
67   atb = (Mat_MatTransMatMult *)C->product->data;
68   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
69   if (!BN) {
70     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
71     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
72     PetscFunctionReturn(0);
73   }
74   bt = atb->bt;
75   ct = atb->ct;
76 
77   /* transpose local array of B, then copy it to vector bt */
78   PetscCall(MatDenseGetArrayRead(B, &Barray));
79   PetscCall(MatDenseGetLDA(B, &ldb));
80   PetscCall(VecGetArray(bt, &btarray));
81   for (j = 0; j < BN; j++)
82     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
83   PetscCall(VecRestoreArray(bt, &btarray));
84   PetscCall(MatDenseRestoreArrayRead(B, &Barray));
85 
86   /* compute ct = mA^T * cb */
87   PetscCall(MatMultTranspose(atb->mA, bt, ct));
88 
89   /* transpose local array of ct to matrix C */
90   PetscCall(MatDenseGetArray(C, &Carray));
91   PetscCall(MatDenseGetLDA(C, &ldc));
92   PetscCall(VecGetArrayRead(ct, &ctarray));
93   for (j = 0; j < BN; j++)
94     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
95   PetscCall(VecRestoreArrayRead(ct, &ctarray));
96   PetscCall(MatDenseRestoreArray(C, &Carray));
97   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
98   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
99   PetscFunctionReturn(0);
100 }
101