1 /*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
3 C = A^T * B
4 The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
5 */
6 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <../src/mat/impls/dense/mpi/mpidense.h>
9
MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)10 static PetscErrorCode MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)
11 {
12 MatProductCtx_MatTransMatMult *atb = *(MatProductCtx_MatTransMatMult **)data;
13
14 PetscFunctionBegin;
15 PetscCall(MatDestroy(&atb->mA));
16 PetscCall(VecDestroy(&atb->bt));
17 PetscCall(VecDestroy(&atb->ct));
18 PetscCall(PetscFree(atb));
19 PetscFunctionReturn(PETSC_SUCCESS);
20 }
21
22 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
23
MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)24 PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
25 {
26 MatProductCtx_MatTransMatMult *atb;
27 PetscBool cisdense;
28
29 PetscFunctionBegin;
30 MatCheckProduct(C, 4);
31 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
32
33 /* create output dense matrix C = A^T*B */
34 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
35 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
36 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
37 PetscCall(MatSetUp(C));
38
39 /* create additional data structure for the product */
40 PetscCall(PetscNew(&atb));
41 if (B->cmap->N) {
42 PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
43 if (!atb->mA->assembled) {
44 PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
45 PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
46 }
47 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
48 }
49 C->product->data = atb;
50 C->product->destroy = MatProductCtxDestroy_MPIDense_MatTransMatMult;
51
52 C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
53 PetscFunctionReturn(PETSC_SUCCESS);
54 }
55
MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)56 static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
57 {
58 const PetscScalar *Barray, *ctarray;
59 PetscScalar *Carray, *btarray;
60 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
61 MatProductCtx_MatTransMatMult *atb;
62 Vec bt, ct;
63
64 PetscFunctionBegin;
65 MatCheckProduct(C, 3);
66 atb = (MatProductCtx_MatTransMatMult *)C->product->data;
67 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
68 if (!BN) {
69 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
70 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
71 PetscFunctionReturn(PETSC_SUCCESS);
72 }
73 bt = atb->bt;
74 ct = atb->ct;
75
76 /* transpose local array of B, then copy it to vector bt */
77 PetscCall(MatDenseGetArrayRead(B, &Barray));
78 PetscCall(MatDenseGetLDA(B, &ldb));
79 PetscCall(VecGetArray(bt, &btarray));
80 for (j = 0; j < BN; j++)
81 for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
82 PetscCall(VecRestoreArray(bt, &btarray));
83 PetscCall(MatDenseRestoreArrayRead(B, &Barray));
84
85 /* compute ct = mA^T * cb */
86 PetscCall(MatMultTranspose(atb->mA, bt, ct));
87
88 /* transpose local array of ct to matrix C */
89 PetscCall(MatDenseGetArray(C, &Carray));
90 PetscCall(MatDenseGetLDA(C, &ldc));
91 PetscCall(VecGetArrayRead(ct, &ctarray));
92 for (j = 0; j < BN; j++)
93 for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
94 PetscCall(VecRestoreArrayRead(ct, &ctarray));
95 PetscCall(MatDenseRestoreArray(C, &Carray));
96 PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
97 PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
98 PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
99 PetscFunctionReturn(PETSC_SUCCESS);
100 }
101