xref: /petsc/src/mat/impls/transpose/transm.c (revision 750b007cd8d816cecd9de99077bb0a703b4cf61a)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
/* Private data (N->data) of a MATTRANSPOSEVIRTUAL matrix N, which represents A' without forming it */
typedef struct {
  Mat A; /* the wrapped matrix; MatCreateTranspose() takes a reference on it and MatDestroy_Transpose() releases it */
} Mat_Transpose;
7 
8 PetscErrorCode MatMult_Transpose(Mat N, Vec x, Vec y) {
9   Mat_Transpose *Na = (Mat_Transpose *)N->data;
10 
11   PetscFunctionBegin;
12   PetscCall(MatMultTranspose(Na->A, x, y));
13   PetscFunctionReturn(0);
14 }
15 
16 PetscErrorCode MatMultAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3) {
17   Mat_Transpose *Na = (Mat_Transpose *)N->data;
18 
19   PetscFunctionBegin;
20   PetscCall(MatMultTransposeAdd(Na->A, v1, v2, v3));
21   PetscFunctionReturn(0);
22 }
23 
24 PetscErrorCode MatMultTranspose_Transpose(Mat N, Vec x, Vec y) {
25   Mat_Transpose *Na = (Mat_Transpose *)N->data;
26 
27   PetscFunctionBegin;
28   PetscCall(MatMult(Na->A, x, y));
29   PetscFunctionReturn(0);
30 }
31 
32 PetscErrorCode MatMultTransposeAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3) {
33   Mat_Transpose *Na = (Mat_Transpose *)N->data;
34 
35   PetscFunctionBegin;
36   PetscCall(MatMultAdd(Na->A, v1, v2, v3));
37   PetscFunctionReturn(0);
38 }
39 
40 PetscErrorCode MatDestroy_Transpose(Mat N) {
41   Mat_Transpose *Na = (Mat_Transpose *)N->data;
42 
43   PetscFunctionBegin;
44   PetscCall(MatDestroy(&Na->A));
45   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
46   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
47   PetscCall(PetscFree(N->data));
48   PetscFunctionReturn(0);
49 }
50 
51 PetscErrorCode MatDuplicate_Transpose(Mat N, MatDuplicateOption op, Mat *m) {
52   Mat_Transpose *Na = (Mat_Transpose *)N->data;
53 
54   PetscFunctionBegin;
55   if (op == MAT_COPY_VALUES) {
56     PetscCall(MatTranspose(Na->A, MAT_INITIAL_MATRIX, m));
57   } else if (op == MAT_DO_NOT_COPY_VALUES) {
58     PetscCall(MatDuplicate(Na->A, MAT_DO_NOT_COPY_VALUES, m));
59     PetscCall(MatTranspose(*m, MAT_INPLACE_MATRIX, m));
60   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "MAT_SHARE_NONZERO_PATTERN not supported for this matrix type");
61   PetscFunctionReturn(0);
62 }
63 
64 PetscErrorCode MatCreateVecs_Transpose(Mat A, Vec *r, Vec *l) {
65   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
66 
67   PetscFunctionBegin;
68   PetscCall(MatCreateVecs(Aa->A, l, r));
69   PetscFunctionReturn(0);
70 }
71 
72 PetscErrorCode MatAXPY_Transpose(Mat Y, PetscScalar a, Mat X, MatStructure str) {
73   Mat_Transpose *Ya = (Mat_Transpose *)Y->data;
74   Mat_Transpose *Xa = (Mat_Transpose *)X->data;
75   Mat            M  = Ya->A;
76   Mat            N  = Xa->A;
77 
78   PetscFunctionBegin;
79   PetscCall(MatAXPY(M, a, N, str));
80   PetscFunctionReturn(0);
81 }
82 
83 PetscErrorCode MatHasOperation_Transpose(Mat mat, MatOperation op, PetscBool *has) {
84   Mat_Transpose *X = (Mat_Transpose *)mat->data;
85   PetscFunctionBegin;
86 
87   *has = PETSC_FALSE;
88   if (op == MATOP_MULT) {
89     PetscCall(MatHasOperation(X->A, MATOP_MULT_TRANSPOSE, has));
90   } else if (op == MATOP_MULT_TRANSPOSE) {
91     PetscCall(MatHasOperation(X->A, MATOP_MULT, has));
92   } else if (op == MATOP_MULT_ADD) {
93     PetscCall(MatHasOperation(X->A, MATOP_MULT_TRANSPOSE_ADD, has));
94   } else if (op == MATOP_MULT_TRANSPOSE_ADD) {
95     PetscCall(MatHasOperation(X->A, MATOP_MULT_ADD, has));
96   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
97   PetscFunctionReturn(0);
98 }
99 
100 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat D) {
101   Mat            A, B, C, Ain, Bin, Cin;
102   PetscBool      Aistrans, Bistrans, Cistrans;
103   PetscInt       Atrans, Btrans, Ctrans;
104   MatProductType ptype;
105 
106   PetscFunctionBegin;
107   MatCheckProduct(D, 1);
108   A = D->product->A;
109   B = D->product->B;
110   C = D->product->C;
111   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATTRANSPOSEVIRTUAL, &Aistrans));
112   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATTRANSPOSEVIRTUAL, &Bistrans));
113   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATTRANSPOSEVIRTUAL, &Cistrans));
114   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
115   Atrans = 0;
116   Ain    = A;
117   while (Aistrans) {
118     Atrans++;
119     PetscCall(MatTransposeGetMat(Ain, &Ain));
120     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATTRANSPOSEVIRTUAL, &Aistrans));
121   }
122   Btrans = 0;
123   Bin    = B;
124   while (Bistrans) {
125     Btrans++;
126     PetscCall(MatTransposeGetMat(Bin, &Bin));
127     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATTRANSPOSEVIRTUAL, &Bistrans));
128   }
129   Ctrans = 0;
130   Cin    = C;
131   while (Cistrans) {
132     Ctrans++;
133     PetscCall(MatTransposeGetMat(Cin, &Cin));
134     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATTRANSPOSEVIRTUAL, &Cistrans));
135   }
136   Atrans = Atrans % 2;
137   Btrans = Btrans % 2;
138   Ctrans = Ctrans % 2;
139   ptype  = D->product->type; /* same product type by default */
140   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
141   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
142   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
143 
144   if (Atrans || Btrans || Ctrans) {
145     ptype = MATPRODUCT_UNSPECIFIED;
146     switch (D->product->type) {
147     case MATPRODUCT_AB:
148       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
149         /* TODO custom implementation ? */
150       } else if (Atrans) { /* At * B */
151         ptype = MATPRODUCT_AtB;
152       } else { /* A * Bt */
153         ptype = MATPRODUCT_ABt;
154       }
155       break;
156     case MATPRODUCT_AtB:
157       if (Atrans && Btrans) { /* A * Bt */
158         ptype = MATPRODUCT_ABt;
159       } else if (Atrans) { /* A * B */
160         ptype = MATPRODUCT_AB;
161       } else { /* At * Bt we do not have support for this */
162         /* TODO custom implementation ? */
163       }
164       break;
165     case MATPRODUCT_ABt:
166       if (Atrans && Btrans) { /* At * B */
167         ptype = MATPRODUCT_AtB;
168       } else if (Atrans) { /* At * Bt we do not have support for this */
169         /* TODO custom implementation ? */
170       } else { /* A * B */
171         ptype = MATPRODUCT_AB;
172       }
173       break;
174     case MATPRODUCT_PtAP:
175       if (Atrans) { /* PtAtP */
176         /* TODO custom implementation ? */
177       } else { /* RARt */
178         ptype = MATPRODUCT_RARt;
179       }
180       break;
181     case MATPRODUCT_RARt:
182       if (Atrans) { /* RAtRt */
183         /* TODO custom implementation ? */
184       } else { /* PtAP */
185         ptype = MATPRODUCT_PtAP;
186       }
187       break;
188     case MATPRODUCT_ABC:
189       /* TODO custom implementation ? */
190       break;
191     default: SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
192     }
193   }
194   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
195   PetscCall(MatProductSetType(D, ptype));
196   PetscCall(MatProductSetFromOptions(D));
197   PetscFunctionReturn(0);
198 }
199 
200 PetscErrorCode MatGetDiagonal_Transpose(Mat A, Vec v) {
201   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
202 
203   PetscFunctionBegin;
204   PetscCall(MatGetDiagonal(Aa->A, v));
205   PetscFunctionReturn(0);
206 }
207 
208 PetscErrorCode MatConvert_Transpose(Mat A, MatType newtype, MatReuse reuse, Mat *newmat) {
209   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
210   PetscBool      flg;
211 
212   PetscFunctionBegin;
213   PetscCall(MatHasOperation(Aa->A, MATOP_TRANSPOSE, &flg));
214   if (flg) {
215     Mat B;
216 
217     PetscCall(MatTranspose(Aa->A, MAT_INITIAL_MATRIX, &B));
218     if (reuse != MAT_INPLACE_MATRIX) {
219       PetscCall(MatConvert(B, newtype, reuse, newmat));
220       PetscCall(MatDestroy(&B));
221     } else {
222       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
223       PetscCall(MatHeaderReplace(A, &B));
224     }
225   } else { /* use basic converter as fallback */
226     PetscCall(MatConvert_Basic(A, newtype, reuse, newmat));
227   }
228   PetscFunctionReturn(0);
229 }
230 
231 PetscErrorCode MatTransposeGetMat_Transpose(Mat A, Mat *M) {
232   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
233 
234   PetscFunctionBegin;
235   *M = Aa->A;
236   PetscFunctionReturn(0);
237 }
238 
/*@
      MatTransposeGetMat - Gets the `Mat` object stored inside a `MATTRANSPOSEVIRTUAL`

   Logically collective on A

   Input Parameter:
.   A  - the `MATTRANSPOSEVIRTUAL` matrix

   Output Parameter:
.   M - the matrix object stored inside A

   Level: intermediate

   Note:
   The `MATTRANSPOSEVIRTUAL` implementation returns the stored matrix pointer directly and
   does not take a new reference on it, so the caller should not call `MatDestroy()` on M.
   The request is dispatched through the composed method "MatTransposeGetMat_C" with
   `PetscUseMethod()`, which raises an error if A does not provide that method.

.seealso: `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`
@*/
PetscErrorCode MatTransposeGetMat(Mat A, Mat *M) {
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscValidType(A, 1);
  PetscValidPointer(M, 2);
  PetscUseMethod(A, "MatTransposeGetMat_C", (Mat, Mat *), (A, M));
  PetscFunctionReturn(0);
}
262 
263 /*MC
264    MATTRANSPOSEVIRTUAL - "transpose" - A matrix type that represents a virtual transpose of a matrix
265 
266   Level: advanced
267 
268 .seealso: `MATHERMITIANTRANSPOSEVIRTUAL`, `Mat`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`
269 M*/
270 
271 /*@
272       MatCreateTranspose - Creates a new matrix `MATTRANSPOSEVIRTUAL` object that behaves like A'
273 
274    Collective on A
275 
276    Input Parameter:
277 .   A  - the (possibly rectangular) matrix
278 
279    Output Parameter:
280 .   N - the matrix that represents A'
281 
282    Level: intermediate
283 
284    Note:
285     The transpose A' is NOT actually formed! Rather the new matrix
286           object performs the matrix-vector product by using the `MatMultTranspose()` on
287           the original matrix
288 
289 .seealso: `MATTRANSPOSEVIRTUAL`, `MatCreateNormal()`, `MatMult()`, `MatMultTranspose()`, `MatCreate()`
290 @*/
291 PetscErrorCode MatCreateTranspose(Mat A, Mat *N) {
292   PetscInt       m, n;
293   Mat_Transpose *Na;
294   VecType        vtype;
295 
296   PetscFunctionBegin;
297   PetscCall(MatGetLocalSize(A, &m, &n));
298   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
299   PetscCall(MatSetSizes(*N, n, m, PETSC_DECIDE, PETSC_DECIDE));
300   PetscCall(PetscLayoutSetUp((*N)->rmap));
301   PetscCall(PetscLayoutSetUp((*N)->cmap));
302   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATTRANSPOSEVIRTUAL));
303 
304   PetscCall(PetscNew(&Na));
305   (*N)->data = (void *)Na;
306   PetscCall(PetscObjectReference((PetscObject)A));
307   Na->A = A;
308 
309   (*N)->ops->destroy               = MatDestroy_Transpose;
310   (*N)->ops->mult                  = MatMult_Transpose;
311   (*N)->ops->multadd               = MatMultAdd_Transpose;
312   (*N)->ops->multtranspose         = MatMultTranspose_Transpose;
313   (*N)->ops->multtransposeadd      = MatMultTransposeAdd_Transpose;
314   (*N)->ops->duplicate             = MatDuplicate_Transpose;
315   (*N)->ops->getvecs               = MatCreateVecs_Transpose;
316   (*N)->ops->axpy                  = MatAXPY_Transpose;
317   (*N)->ops->hasoperation          = MatHasOperation_Transpose;
318   (*N)->ops->productsetfromoptions = MatProductSetFromOptions_Transpose;
319   (*N)->ops->getdiagonal           = MatGetDiagonal_Transpose;
320   (*N)->ops->convert               = MatConvert_Transpose;
321   (*N)->assembled                  = PETSC_TRUE;
322 
323   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatTransposeGetMat_C", MatTransposeGetMat_Transpose));
324   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_Transpose));
325   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
326   PetscCall(MatGetVecType(A, &vtype));
327   PetscCall(MatSetVecType(*N, vtype));
328 #if defined(PETSC_HAVE_DEVICE)
329   PetscCall(MatBindToCPU(*N, A->boundtocpu));
330 #endif
331   PetscCall(MatSetUp(*N));
332   PetscFunctionReturn(0);
333 }
334