xref: /petsc/src/mat/impls/aij/seq/mattransposematmult.c (revision a4af0ceea8a251db97ee0dc5c0d52d4adf50264a)
1 
2 /*
3   Defines matrix-matrix product routines for
4           C = A^T * B and C = A * B^t
5   with A SeqAIJ and B SeqDense
6 */
7 
8 #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
9 #include <../src/mat/impls/dense/seq/dense.h>
10 
11 PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data)
12 {
13   PetscErrorCode      ierr;
14   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
15 
16   PetscFunctionBegin;
17   ierr = MatDestroy(&atb->mA);CHKERRQ(ierr);
18   ierr = VecDestroy(&atb->bt);CHKERRQ(ierr);
19   ierr = VecDestroy(&atb->ct);CHKERRQ(ierr);
20   ierr = PetscFree(atb);CHKERRQ(ierr);
21   PetscFunctionReturn(0);
22 }
23 
24 static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat,Mat,Mat);
25 
26 PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A,Mat B,PetscReal fill,Mat C)
27 {
28   PetscErrorCode      ierr;
29   Mat_MatTransMatMult *atb;
30   PetscBool           cisdense;
31   PetscInt            dofm;
32 
33   PetscFunctionBegin;
34   MatCheckProduct(C,4);
35   if (C->product->data) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty");
36   if (C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]);
37 
38   /* create output dense matrix C */
39   if (C->product->type == MATPRODUCT_AtB) {
40     ierr = MatSetSizes(C,A->cmap->n,B->cmap->N,A->cmap->n,B->cmap->N);CHKERRQ(ierr);
41     dofm = B->cmap->n;
42   } else {
43     ierr = MatSetSizes(C,A->rmap->n,B->rmap->N,A->rmap->n,B->rmap->N);CHKERRQ(ierr);
44     dofm = B->rmap->n;
45   }
46   ierr = PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATSEQDENSE,MATSEQDENSECUDA,"");CHKERRQ(ierr);
47   if (!cisdense) {
48     ierr = MatSetType(C,((PetscObject)B)->type_name);CHKERRQ(ierr);
49   }
50   ierr = MatSetUp(C);CHKERRQ(ierr);
51 
52   /* create additional data structure for the product */
53   ierr = PetscNew(&atb);CHKERRQ(ierr);
54   ierr = MatCreateMAIJ(A,dofm,&atb->mA);CHKERRQ(ierr);
55   ierr = MatCreateVecs(atb->mA,&atb->ct,&atb->bt);CHKERRQ(ierr);
56   C->product->data    = atb;
57   C->product->destroy = MatDestroy_SeqDense_MatTransMatMult;
58 
59   if (C->product->type == MATPRODUCT_AtB) {
60     C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
61   } else {
62     C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
63   }
64   PetscFunctionReturn(0);
65 }
66 
67 PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A,Mat B,Mat C)
68 {
69   PetscErrorCode      ierr;
70   PetscInt            i,j,m=A->rmap->n,n=A->cmap->n,blda,clda;
71   PetscInt            mdof = C->cmap->N;
72   const PetscScalar   *Barray;
73   PetscScalar         *Carray;
74   Mat_MatTransMatMult *atb;
75   Vec                 bt,ct;
76 
77   PetscFunctionBegin;
78   MatCheckProduct(C,3);
79   if (C->product->type != MATPRODUCT_ABt && C->product->type != MATPRODUCT_AtB) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Not for product type %s",MatProductTypes[C->product->type]);
80   atb = (Mat_MatTransMatMult *)C->product->data;
81   if (!atb) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct");
82   bt = atb->bt;
83   ct = atb->ct;
84 
85   ierr = MatDenseGetArrayRead(B,&Barray);CHKERRQ(ierr);
86   ierr = MatDenseGetLDA(B,&blda);CHKERRQ(ierr);
87   ierr = MatDenseGetArrayWrite(C,&Carray);CHKERRQ(ierr);
88   ierr = MatDenseGetLDA(C,&clda);CHKERRQ(ierr);
89   if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */
90     const PetscScalar *ctarray;
91     PetscScalar       *btarray;
92 
93     ierr = VecGetArrayWrite(bt,&btarray);CHKERRQ(ierr);
94     for (j=0; j<mdof; j++) {
95       for (i=0; i<m; i++) btarray[i*mdof + j] = Barray[j*blda + i];
96     }
97     ierr = VecRestoreArrayWrite(bt,&btarray);CHKERRQ(ierr);
98 
99     /* compute ct = mA^T * cb */
100     ierr = MatMultTranspose(atb->mA,bt,ct);CHKERRQ(ierr);
101 
102     /* transpose local array of ct to matrix C */
103     ierr = VecGetArrayRead(ct,&ctarray);CHKERRQ(ierr);
104     for (j=0; j<mdof; j++) {
105       for (i=0; i<n; i++) Carray[j*clda + i] = ctarray[i*mdof + j];
106     }
107     ierr = VecRestoreArrayRead(ct,&ctarray);CHKERRQ(ierr);
108   } else {
109     const PetscScalar *btarray;
110     PetscScalar       *ctarray;
111 
112     if (blda == B->rmap->n) {
113       ierr = VecPlaceArray(ct,Barray);CHKERRQ(ierr);
114     } else {
115       PetscInt bn = B->cmap->n;
116       PetscInt bm = B->rmap->n;
117 
118       ierr = VecGetArrayWrite(ct,&ctarray);CHKERRQ(ierr);
119       for (j=0; j<bn; j++) {
120         for (i=0; i<bm; i++) ctarray[j*bm + i] = Barray[j*blda + i];
121       }
122       ierr = VecRestoreArrayWrite(ct,&ctarray);CHKERRQ(ierr);
123     }
124 
125     ierr = MatMult(atb->mA,ct,bt);CHKERRQ(ierr);
126     if (blda == B->rmap->n) {
127       ierr = VecResetArray(ct);CHKERRQ(ierr);
128     }
129     ierr = VecGetArrayRead(bt,&btarray);CHKERRQ(ierr);
130     for (j=0; j<mdof; j++) {
131       for (i=0; i<m; i++) Carray[j*clda + i] = btarray[i*mdof + j];
132     }
133     ierr = VecRestoreArrayRead(bt,&btarray);CHKERRQ(ierr);
134   }
135   ierr = MatDenseRestoreArrayRead(B,&Barray);CHKERRQ(ierr);
136   ierr = MatDenseRestoreArray(C,&Carray);CHKERRQ(ierr);
137   PetscFunctionReturn(0);
138 }
139