xref: /petsc/src/mat/impls/aij/mpi/mpimatmatmatmult.c (revision d5c9c0c4eebc2f2a01a1bd0c86fca87e2acd2a03)
1 /*
2   Defines matrix-matrix-matrix product routines for MPIAIJ matrices
3           D = A * B * C
4 */
5 #include <../src/mat/impls/aij/mpi/mpiaij.h> /*I "petscmat.h" I*/
6 
7 #if defined(PETSC_HAVE_HYPRE)
8 PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
9 PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);
10 
11 PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
12 {
13   PetscErrorCode ierr;
14   Mat_Product    *product = RAP->product;
15   Mat            Rt,R=product->A,A=product->B,P=product->C;
16 
17   PetscFunctionBegin;
18   ierr = MatTransposeGetMat(R,&Rt);CHKERRQ(ierr);
19   ierr = MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);CHKERRQ(ierr);
20   PetscFunctionReturn(0);
21 }
22 
23 PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
24 {
25   PetscErrorCode ierr;
26   Mat_Product    *product = RAP->product;
27   Mat            Rt,R=product->A,A=product->B,P=product->C;
28   PetscBool      flg;
29 
30   PetscFunctionBegin;
31   /* local sizes of matrices will be checked by the calling subroutines */
32   ierr = MatTransposeGetMat(R,&Rt);CHKERRQ(ierr);
33   ierr = PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);CHKERRQ(ierr);
34   if (!flg) SETERRQ1(PetscObjectComm((PetscObject)Rt),PETSC_ERR_SUP,"Not for matrix type %s",((PetscObject)Rt)->type_name);
35   ierr = MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);CHKERRQ(ierr);
36   RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
37   PetscFunctionReturn(0);
38 }
39 
40 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
41 {
42   Mat_Product *product = C->product;
43 
44   PetscFunctionBegin;
45   if (product->type == MATPRODUCT_ABC) {
46     C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
47   } else SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
48   PetscFunctionReturn(0);
49 }
50 #endif
51 
52 PetscErrorCode MatFreeIntermediateDataStructures_MPIAIJ_BC(Mat ABC)
53 {
54   Mat_MPIAIJ        *a = (Mat_MPIAIJ*)ABC->data;
55   Mat_MatMatMatMult *matmatmatmult = a->matmatmatmult;
56   PetscErrorCode    ierr;
57 
58   PetscFunctionBegin;
59   if (!matmatmatmult) PetscFunctionReturn(0);
60 
61   ierr = MatDestroy(&matmatmatmult->BC);CHKERRQ(ierr);
62   ABC->ops->destroy = matmatmatmult->destroy;
63   ierr = PetscFree(a->matmatmatmult);CHKERRQ(ierr);
64   PetscFunctionReturn(0);
65 }
66 
67 PetscErrorCode MatDestroy_MPIAIJ_MatMatMatMult(Mat A)
68 {
69   PetscErrorCode    ierr;
70 
71   PetscFunctionBegin;
72   ierr = (*A->ops->freeintermediatedatastructures)(A);CHKERRQ(ierr);
73   ierr = (*A->ops->destroy)(A);CHKERRQ(ierr);
74   PetscFunctionReturn(0);
75 }
76 
77 PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
78 {
79   PetscErrorCode ierr;
80   Mat            BC;
81   PetscBool      scalable;
82   Mat_Product    *product = D->product;
83 
84   PetscFunctionBegin;
85   ierr = MatCreate(PetscObjectComm((PetscObject)A),&BC);CHKERRQ(ierr);
86   if (product) {
87     ierr = PetscStrcmp(product->alg,"scalable",&scalable);CHKERRQ(ierr);
88   } else SETERRQ(PetscObjectComm((PetscObject)D),PETSC_ERR_ARG_NULL,"Call MatProductCreate() first");
89 
90   if (scalable) {
91     ierr = MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);CHKERRQ(ierr);
92     ierr = MatZeroEntries(BC);CHKERRQ(ierr); /* initialize value entries of BC */
93     ierr = MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);CHKERRQ(ierr);
94   } else {
95     ierr = MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);CHKERRQ(ierr);
96     ierr = MatZeroEntries(BC);CHKERRQ(ierr); /* initialize value entries of BC */
97     ierr = MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);CHKERRQ(ierr);
98   }
99   product->Dwork = BC;
100 
101   D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
102   D->ops->freeintermediatedatastructures = MatFreeIntermediateDataStructures_MPIAIJ_BC;
103   PetscFunctionReturn(0);
104 }
105 
106 PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
107 {
108   PetscErrorCode ierr;
109   Mat_Product    *product = D->product;
110   Mat            BC = product->Dwork;
111 
112   PetscFunctionBegin;
113   ierr = (BC->ops->matmultnumeric)(B,C,BC);CHKERRQ(ierr);
114   ierr = (D->ops->matmultnumeric)(A,BC,D);CHKERRQ(ierr);
115   PetscFunctionReturn(0);
116 }
117 
118 /* ----------------------------------------------------- */
119 PetscErrorCode MatDestroy_MPIAIJ_RARt(Mat C)
120 {
121   PetscErrorCode ierr;
122   Mat_MPIAIJ     *c    = (Mat_MPIAIJ*)C->data;
123   Mat_RARt       *rart = c->rart;
124 
125   PetscFunctionBegin;
126   ierr = MatDestroy(&rart->Rt);CHKERRQ(ierr);
127 
128   C->ops->destroy = rart->destroy;
129   if (C->ops->destroy) {
130     ierr = (*C->ops->destroy)(C);CHKERRQ(ierr);
131   }
132   ierr = PetscFree(rart);CHKERRQ(ierr);
133   PetscFunctionReturn(0);
134 }
135 
136 PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
137 {
138   PetscErrorCode ierr;
139   Mat_MPIAIJ     *c = (Mat_MPIAIJ*)C->data;
140   Mat_RARt       *rart = c->rart;
141   Mat_Product    *product = C->product;
142   Mat            A=product->A,R=product->B,Rt=rart->Rt;
143 
144   PetscFunctionBegin;
145   ierr = MatTranspose(R,MAT_REUSE_MATRIX,&Rt);CHKERRQ(ierr);
146   ierr = (C->ops->matmatmultnumeric)(R,A,Rt,C);CHKERRQ(ierr);
147   PetscFunctionReturn(0);
148 }
149 
150 PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
151 {
152   PetscErrorCode      ierr;
153   Mat_Product         *product = C->product;
154   Mat                 A=product->A,R=product->B,Rt;
155   PetscReal           fill=product->fill;
156   Mat_RARt            *rart;
157   Mat_MPIAIJ          *c;
158 
159   PetscFunctionBegin;
160   ierr = MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);CHKERRQ(ierr);
161   /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
162   ierr = MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,fill,C);CHKERRQ(ierr);
163   C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;
164 
165   /* create a supporting struct */
166   ierr     = PetscNew(&rart);CHKERRQ(ierr);
167   c        = (Mat_MPIAIJ*)C->data;
168   c->rart  = rart;
169   rart->Rt = Rt;
170   rart->destroy   = C->ops->destroy;
171   C->ops->destroy = MatDestroy_MPIAIJ_RARt;
172   PetscFunctionReturn(0);
173 }
174