/*
  Defines matrix-matrix product routines for
    C = A^T * B and C = A * B^T
  with A SeqAIJ and B SeqDense
*/
6
7 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
8 #include <../src/mat/impls/dense/seq/dense.h>
9
MatProductCtxDestroy_SeqDense_MatTransMatMult(PetscCtxRt data)10 static PetscErrorCode MatProductCtxDestroy_SeqDense_MatTransMatMult(PetscCtxRt data)
11 {
12 MatProductCtx_MatTransMatMult *atb = *(MatProductCtx_MatTransMatMult **)data;
13
14 PetscFunctionBegin;
15 PetscCall(MatDestroy(&atb->mA));
16 PetscCall(VecDestroy(&atb->bt));
17 PetscCall(VecDestroy(&atb->ct));
18 PetscCall(PetscFree(atb));
19 PetscFunctionReturn(PETSC_SUCCESS);
20 }
21
22 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat, Mat, Mat);
23
MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A,Mat B,PetscReal fill,Mat C)24 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A, Mat B, PetscReal fill, Mat C)
25 {
26 MatProductCtx_MatTransMatMult *atb;
27 PetscBool cisdense;
28 PetscInt dofm;
29
30 PetscFunctionBegin;
31 MatCheckProduct(C, 4);
32 PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
33 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]);
34
35 /* create output dense matrix C */
36 if (C->product->type == MATPRODUCT_AtB) {
37 PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->N, A->cmap->n, B->cmap->N));
38 dofm = B->cmap->n;
39 } else {
40 PetscCall(MatSetSizes(C, A->rmap->n, B->rmap->N, A->rmap->n, B->rmap->N));
41 dofm = B->rmap->n;
42 }
43 PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATSEQDENSE, MATSEQDENSECUDA, ""));
44 if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
45 PetscCall(MatSetUp(C));
46
47 /* create additional data structure for the product */
48 PetscCall(PetscNew(&atb));
49 PetscCall(MatCreateMAIJ(A, dofm, &atb->mA));
50 PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
51 C->product->data = atb;
52 C->product->destroy = MatProductCtxDestroy_SeqDense_MatTransMatMult;
53
54 if (C->product->type == MATPRODUCT_AtB) {
55 C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
56 } else {
57 C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
58 }
59 PetscFunctionReturn(PETSC_SUCCESS);
60 }
61
MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A,Mat B,Mat C)62 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A, Mat B, Mat C)
63 {
64 PetscInt i, j, m = A->rmap->n, n = A->cmap->n, blda, clda;
65 PetscInt mdof = C->cmap->N;
66 const PetscScalar *Barray;
67 PetscScalar *Carray;
68 MatProductCtx_MatTransMatMult *atb;
69 Vec bt, ct;
70
71 PetscFunctionBegin;
72 MatCheckProduct(C, 3);
73 PetscCheck(C->product->type == MATPRODUCT_ABt || C->product->type == MATPRODUCT_AtB, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[C->product->type]);
74 atb = (MatProductCtx_MatTransMatMult *)C->product->data;
75 PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
76 bt = atb->bt;
77 ct = atb->ct;
78
79 PetscCall(MatDenseGetArrayRead(B, &Barray));
80 PetscCall(MatDenseGetLDA(B, &blda));
81 PetscCall(MatDenseGetArrayWrite(C, &Carray));
82 PetscCall(MatDenseGetLDA(C, &clda));
83 if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */
84 const PetscScalar *ctarray;
85 PetscScalar *btarray;
86
87 PetscCall(VecGetArrayWrite(bt, &btarray));
88 for (j = 0; j < mdof; j++) {
89 for (i = 0; i < m; i++) btarray[i * mdof + j] = Barray[j * blda + i];
90 }
91 PetscCall(VecRestoreArrayWrite(bt, &btarray));
92
93 /* compute ct = mA^T * cb */
94 PetscCall(MatMultTranspose(atb->mA, bt, ct));
95
96 /* transpose local array of ct to matrix C */
97 PetscCall(VecGetArrayRead(ct, &ctarray));
98 for (j = 0; j < mdof; j++) {
99 for (i = 0; i < n; i++) Carray[j * clda + i] = ctarray[i * mdof + j];
100 }
101 PetscCall(VecRestoreArrayRead(ct, &ctarray));
102 } else {
103 const PetscScalar *btarray;
104 PetscScalar *ctarray;
105
106 if (blda == B->rmap->n) {
107 PetscCall(VecPlaceArray(ct, Barray));
108 } else {
109 PetscInt bn = B->cmap->n;
110 PetscInt bm = B->rmap->n;
111
112 PetscCall(VecGetArrayWrite(ct, &ctarray));
113 for (j = 0; j < bn; j++) {
114 for (i = 0; i < bm; i++) ctarray[j * bm + i] = Barray[j * blda + i];
115 }
116 PetscCall(VecRestoreArrayWrite(ct, &ctarray));
117 }
118
119 PetscCall(MatMult(atb->mA, ct, bt));
120 if (blda == B->rmap->n) PetscCall(VecResetArray(ct));
121 PetscCall(VecGetArrayRead(bt, &btarray));
122 for (j = 0; j < mdof; j++) {
123 for (i = 0; i < m; i++) Carray[j * clda + i] = btarray[i * mdof + j];
124 }
125 PetscCall(VecRestoreArrayRead(bt, &btarray));
126 }
127 PetscCall(MatDenseRestoreArrayRead(B, &Barray));
128 PetscCall(MatDenseRestoreArray(C, &Carray));
129 PetscFunctionReturn(PETSC_SUCCESS);
130 }
131