xref: /petsc/src/mat/impls/aij/mpi/mpimattransposematmult.c (revision a69119a591a03a9d906b29c0a4e9802e4d7c9795) !
1 
/*
  Defines the matrix-matrix product routines for C = A^T * B, where A is an MPIAIJ matrix
  and B (and therefore C) is a parallel dense matrix.
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
8 #include <../src/mat/impls/aij/mpi/mpiaij.h>
9 #include <../src/mat/impls/dense/mpi/mpidense.h>
10 
11 PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) {
12   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
13 
14   PetscFunctionBegin;
15   PetscCall(MatDestroy(&atb->mA));
16   PetscCall(VecDestroy(&atb->bt));
17   PetscCall(VecDestroy(&atb->ct));
18   PetscCall(PetscFree(atb));
19   PetscFunctionReturn(0);
20 }
21 
22 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
23 
24 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C) {
25   Mat_MatTransMatMult *atb;
26   PetscBool            cisdense;
27 
28   PetscFunctionBegin;
29   MatCheckProduct(C, 4);
30   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
31 
32   /* create output dense matrix C = A^T*B */
33   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
34   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
35   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
36   PetscCall(MatSetUp(C));
37 
38   /* create additional data structure for the product */
39   PetscCall(PetscNew(&atb));
40   if (B->cmap->N) {
41     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
42     if (!atb->mA->assembled) {
43       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
44       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
45     }
46     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
47   }
48   C->product->data    = atb;
49   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
50 
51   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
52   PetscFunctionReturn(0);
53 }
54 
55 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C) {
56   const PetscScalar   *Barray, *ctarray;
57   PetscScalar         *Carray, *btarray;
58   PetscInt             i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
59   Mat_MatTransMatMult *atb;
60   Vec                  bt, ct;
61 
62   PetscFunctionBegin;
63   MatCheckProduct(C, 3);
64   atb = (Mat_MatTransMatMult *)C->product->data;
65   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
66   if (!BN) {
67     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
68     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
69     PetscFunctionReturn(0);
70   }
71   bt = atb->bt;
72   ct = atb->ct;
73 
74   /* transpose local array of B, then copy it to vector bt */
75   PetscCall(MatDenseGetArrayRead(B, &Barray));
76   PetscCall(MatDenseGetLDA(B, &ldb));
77   PetscCall(VecGetArray(bt, &btarray));
78   for (j = 0; j < BN; j++)
79     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
80   PetscCall(VecRestoreArray(bt, &btarray));
81   PetscCall(MatDenseRestoreArrayRead(B, &Barray));
82 
83   /* compute ct = mA^T * cb */
84   PetscCall(MatMultTranspose(atb->mA, bt, ct));
85 
86   /* transpose local array of ct to matrix C */
87   PetscCall(MatDenseGetArray(C, &Carray));
88   PetscCall(MatDenseGetLDA(C, &ldc));
89   PetscCall(VecGetArrayRead(ct, &ctarray));
90   for (j = 0; j < BN; j++)
91     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
92   PetscCall(VecRestoreArrayRead(ct, &ctarray));
93   PetscCall(MatDenseRestoreArray(C, &Carray));
94   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
95   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
96   PetscFunctionReturn(0);
97 }
98