xref: /petsc/src/mat/impls/aij/mpi/mpimattransposematmult.c (revision 4e8208cbcbc709572b8abe32f33c78b69c819375)
/*
  Defines matrix-matrix product routines for
          C = A^T * B
  where A is an MPIAIJ matrix and B is an MPIDense matrix (the result C is dense).
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
6 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <../src/mat/impls/dense/mpi/mpidense.h>
9 
MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)10 static PetscErrorCode MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)
11 {
12   MatProductCtx_MatTransMatMult *atb = *(MatProductCtx_MatTransMatMult **)data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatDestroy(&atb->mA));
16   PetscCall(VecDestroy(&atb->bt));
17   PetscCall(VecDestroy(&atb->ct));
18   PetscCall(PetscFree(atb));
19   PetscFunctionReturn(PETSC_SUCCESS);
20 }
21 
22 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
23 
MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)24 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
25 {
26   MatProductCtx_MatTransMatMult *atb;
27   PetscBool                      cisdense;
28 
29   PetscFunctionBegin;
30   MatCheckProduct(C, 4);
31   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
32 
33   /* create output dense matrix C = A^T*B */
34   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
35   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
36   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
37   PetscCall(MatSetUp(C));
38 
39   /* create additional data structure for the product */
40   PetscCall(PetscNew(&atb));
41   if (B->cmap->N) {
42     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
43     if (!atb->mA->assembled) {
44       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
45       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
46     }
47     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
48   }
49   C->product->data    = atb;
50   C->product->destroy = MatProductCtxDestroy_MPIDense_MatTransMatMult;
51 
52   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
53   PetscFunctionReturn(PETSC_SUCCESS);
54 }
55 
MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)56 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
57 {
58   const PetscScalar             *Barray, *ctarray;
59   PetscScalar                   *Carray, *btarray;
60   PetscInt                       i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
61   MatProductCtx_MatTransMatMult *atb;
62   Vec                            bt, ct;
63 
64   PetscFunctionBegin;
65   MatCheckProduct(C, 3);
66   atb = (MatProductCtx_MatTransMatMult *)C->product->data;
67   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
68   if (!BN) {
69     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
70     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
71     PetscFunctionReturn(PETSC_SUCCESS);
72   }
73   bt = atb->bt;
74   ct = atb->ct;
75 
76   /* transpose local array of B, then copy it to vector bt */
77   PetscCall(MatDenseGetArrayRead(B, &Barray));
78   PetscCall(MatDenseGetLDA(B, &ldb));
79   PetscCall(VecGetArray(bt, &btarray));
80   for (j = 0; j < BN; j++)
81     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
82   PetscCall(VecRestoreArray(bt, &btarray));
83   PetscCall(MatDenseRestoreArrayRead(B, &Barray));
84 
85   /* compute ct = mA^T * cb */
86   PetscCall(MatMultTranspose(atb->mA, bt, ct));
87 
88   /* transpose local array of ct to matrix C */
89   PetscCall(MatDenseGetArray(C, &Carray));
90   PetscCall(MatDenseGetLDA(C, &ldc));
91   PetscCall(VecGetArrayRead(ct, &ctarray));
92   for (j = 0; j < BN; j++)
93     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
94   PetscCall(VecRestoreArrayRead(ct, &ctarray));
95   PetscCall(MatDenseRestoreArray(C, &Carray));
96   PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
97   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
98   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
99   PetscFunctionReturn(PETSC_SUCCESS);
100 }
101