xref: /petsc/src/mat/impls/transpose/transm.c (revision 66af8762ec03dbef0e079729eb2a1734a35ed7ff)
1 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
2 
3 typedef struct {
4   Mat A;
5 } Mat_Transpose;
6 
7 static PetscErrorCode MatMult_Transpose(Mat N, Vec x, Vec y)
8 {
9   Mat_Transpose *Na = (Mat_Transpose *)N->data;
10 
11   PetscFunctionBegin;
12   PetscCall(MatMultTranspose(Na->A, x, y));
13   PetscFunctionReturn(PETSC_SUCCESS);
14 }
15 
16 static PetscErrorCode MatMultAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3)
17 {
18   Mat_Transpose *Na = (Mat_Transpose *)N->data;
19 
20   PetscFunctionBegin;
21   PetscCall(MatMultTransposeAdd(Na->A, v1, v2, v3));
22   PetscFunctionReturn(PETSC_SUCCESS);
23 }
24 
25 static PetscErrorCode MatMultTranspose_Transpose(Mat N, Vec x, Vec y)
26 {
27   Mat_Transpose *Na = (Mat_Transpose *)N->data;
28 
29   PetscFunctionBegin;
30   PetscCall(MatMult(Na->A, x, y));
31   PetscFunctionReturn(PETSC_SUCCESS);
32 }
33 
34 static PetscErrorCode MatMultTransposeAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3)
35 {
36   Mat_Transpose *Na = (Mat_Transpose *)N->data;
37 
38   PetscFunctionBegin;
39   PetscCall(MatMultAdd(Na->A, v1, v2, v3));
40   PetscFunctionReturn(PETSC_SUCCESS);
41 }
42 
43 static PetscErrorCode MatDestroy_Transpose(Mat N)
44 {
45   Mat_Transpose *Na = (Mat_Transpose *)N->data;
46 
47   PetscFunctionBegin;
48   PetscCall(MatDestroy(&Na->A));
49   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
50   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
51   PetscCall(PetscFree(N->data));
52   PetscFunctionReturn(PETSC_SUCCESS);
53 }
54 
55 static PetscErrorCode MatDuplicate_Transpose(Mat N, MatDuplicateOption op, Mat *m)
56 {
57   Mat_Transpose *Na = (Mat_Transpose *)N->data;
58 
59   PetscFunctionBegin;
60   if (op == MAT_COPY_VALUES) {
61     PetscCall(MatTranspose(Na->A, MAT_INITIAL_MATRIX, m));
62   } else if (op == MAT_DO_NOT_COPY_VALUES) {
63     PetscCall(MatDuplicate(Na->A, MAT_DO_NOT_COPY_VALUES, m));
64     PetscCall(MatTranspose(*m, MAT_INPLACE_MATRIX, m));
65   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "MAT_SHARE_NONZERO_PATTERN not supported for this matrix type");
66   PetscFunctionReturn(PETSC_SUCCESS);
67 }
68 
69 static PetscErrorCode MatCreateVecs_Transpose(Mat A, Vec *r, Vec *l)
70 {
71   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
72 
73   PetscFunctionBegin;
74   PetscCall(MatCreateVecs(Aa->A, l, r));
75   PetscFunctionReturn(PETSC_SUCCESS);
76 }
77 
78 static PetscErrorCode MatAXPY_Transpose(Mat Y, PetscScalar a, Mat X, MatStructure str)
79 {
80   Mat_Transpose *Ya = (Mat_Transpose *)Y->data;
81   Mat_Transpose *Xa = (Mat_Transpose *)X->data;
82   Mat            M  = Ya->A;
83   Mat            N  = Xa->A;
84 
85   PetscFunctionBegin;
86   PetscCall(MatAXPY(M, a, N, str));
87   PetscFunctionReturn(PETSC_SUCCESS);
88 }
89 
90 static PetscErrorCode MatHasOperation_Transpose(Mat mat, MatOperation op, PetscBool *has)
91 {
92   Mat_Transpose *X = (Mat_Transpose *)mat->data;
93   PetscFunctionBegin;
94 
95   *has = PETSC_FALSE;
96   if (op == MATOP_MULT) {
97     PetscCall(MatHasOperation(X->A, MATOP_MULT_TRANSPOSE, has));
98   } else if (op == MATOP_MULT_TRANSPOSE) {
99     PetscCall(MatHasOperation(X->A, MATOP_MULT, has));
100   } else if (op == MATOP_MULT_ADD) {
101     PetscCall(MatHasOperation(X->A, MATOP_MULT_TRANSPOSE_ADD, has));
102   } else if (op == MATOP_MULT_TRANSPOSE_ADD) {
103     PetscCall(MatHasOperation(X->A, MATOP_MULT_ADD, has));
104   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
105   PetscFunctionReturn(PETSC_SUCCESS);
106 }
107 
108 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat D)
109 {
110   Mat            A, B, C, Ain, Bin, Cin;
111   PetscBool      Aistrans, Bistrans, Cistrans;
112   PetscInt       Atrans, Btrans, Ctrans;
113   MatProductType ptype;
114 
115   PetscFunctionBegin;
116   MatCheckProduct(D, 1);
117   A = D->product->A;
118   B = D->product->B;
119   C = D->product->C;
120   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATTRANSPOSEVIRTUAL, &Aistrans));
121   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATTRANSPOSEVIRTUAL, &Bistrans));
122   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATTRANSPOSEVIRTUAL, &Cistrans));
123   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
124   Atrans = 0;
125   Ain    = A;
126   while (Aistrans) {
127     Atrans++;
128     PetscCall(MatTransposeGetMat(Ain, &Ain));
129     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATTRANSPOSEVIRTUAL, &Aistrans));
130   }
131   Btrans = 0;
132   Bin    = B;
133   while (Bistrans) {
134     Btrans++;
135     PetscCall(MatTransposeGetMat(Bin, &Bin));
136     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATTRANSPOSEVIRTUAL, &Bistrans));
137   }
138   Ctrans = 0;
139   Cin    = C;
140   while (Cistrans) {
141     Ctrans++;
142     PetscCall(MatTransposeGetMat(Cin, &Cin));
143     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATTRANSPOSEVIRTUAL, &Cistrans));
144   }
145   Atrans = Atrans % 2;
146   Btrans = Btrans % 2;
147   Ctrans = Ctrans % 2;
148   ptype  = D->product->type; /* same product type by default */
149   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
150   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
151   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
152 
153   if (Atrans || Btrans || Ctrans) {
154     ptype = MATPRODUCT_UNSPECIFIED;
155     switch (D->product->type) {
156     case MATPRODUCT_AB:
157       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
158         /* TODO custom implementation ? */
159       } else if (Atrans) { /* At * B */
160         ptype = MATPRODUCT_AtB;
161       } else { /* A * Bt */
162         ptype = MATPRODUCT_ABt;
163       }
164       break;
165     case MATPRODUCT_AtB:
166       if (Atrans && Btrans) { /* A * Bt */
167         ptype = MATPRODUCT_ABt;
168       } else if (Atrans) { /* A * B */
169         ptype = MATPRODUCT_AB;
170       } else { /* At * Bt we do not have support for this */
171         /* TODO custom implementation ? */
172       }
173       break;
174     case MATPRODUCT_ABt:
175       if (Atrans && Btrans) { /* At * B */
176         ptype = MATPRODUCT_AtB;
177       } else if (Atrans) { /* At * Bt we do not have support for this */
178         /* TODO custom implementation ? */
179       } else { /* A * B */
180         ptype = MATPRODUCT_AB;
181       }
182       break;
183     case MATPRODUCT_PtAP:
184       if (Atrans) { /* PtAtP */
185         /* TODO custom implementation ? */
186       } else { /* RARt */
187         ptype = MATPRODUCT_RARt;
188       }
189       break;
190     case MATPRODUCT_RARt:
191       if (Atrans) { /* RAtRt */
192         /* TODO custom implementation ? */
193       } else { /* PtAP */
194         ptype = MATPRODUCT_PtAP;
195       }
196       break;
197     case MATPRODUCT_ABC:
198       /* TODO custom implementation ? */
199       break;
200     default:
201       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
202     }
203   }
204   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
205   PetscCall(MatProductSetType(D, ptype));
206   PetscCall(MatProductSetFromOptions(D));
207   PetscFunctionReturn(PETSC_SUCCESS);
208 }
209 
210 static PetscErrorCode MatGetDiagonal_Transpose(Mat A, Vec v)
211 {
212   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
213 
214   PetscFunctionBegin;
215   PetscCall(MatGetDiagonal(Aa->A, v));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 static PetscErrorCode MatConvert_Transpose(Mat A, MatType newtype, MatReuse reuse, Mat *newmat)
220 {
221   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
222   PetscBool      flg;
223 
224   PetscFunctionBegin;
225   PetscCall(MatHasOperation(Aa->A, MATOP_TRANSPOSE, &flg));
226   if (flg) {
227     Mat B;
228 
229     PetscCall(MatTranspose(Aa->A, MAT_INITIAL_MATRIX, &B));
230     if (reuse != MAT_INPLACE_MATRIX) {
231       PetscCall(MatConvert(B, newtype, reuse, newmat));
232       PetscCall(MatDestroy(&B));
233     } else {
234       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
235       PetscCall(MatHeaderReplace(A, &B));
236     }
237   } else { /* use basic converter as fallback */
238     PetscCall(MatConvert_Basic(A, newtype, reuse, newmat));
239   }
240   PetscFunctionReturn(PETSC_SUCCESS);
241 }
242 
243 static PetscErrorCode MatTransposeGetMat_Transpose(Mat A, Mat *M)
244 {
245   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
246 
247   PetscFunctionBegin;
248   *M = Aa->A;
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251 
/*@
  MatTransposeGetMat - Gets the `Mat` object stored inside a `MATTRANSPOSEVIRTUAL`

  Logically Collective

  Input Parameter:
. A - the `MATTRANSPOSEVIRTUAL` matrix

  Output Parameter:
. M - the matrix object stored inside `A`

  Level: intermediate

  Note:
  For a `MATTRANSPOSEVIRTUAL` matrix the stored pointer is returned as is; no new reference is taken on `M`,
  so the caller should not destroy it. The call goes through the composed function "MatTransposeGetMat_C",
  so it fails for a matrix that does not provide that function.

.seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`
@*/
PetscErrorCode MatTransposeGetMat(Mat A, Mat *M)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscValidType(A, 1);
  PetscAssertPointer(M, 2);
  /* dispatch to the type-specific implementation composed under this name */
  PetscUseMethod(A, "MatTransposeGetMat_C", (Mat, Mat *), (A, M));
  PetscFunctionReturn(PETSC_SUCCESS);
}
276 
277 /*MC
278    MATTRANSPOSEVIRTUAL - "transpose" - A matrix type that represents a virtual transpose of a matrix
279 
280   Level: advanced
281 
.seealso: [](ch_matrices), `Mat`, `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`,
          `MATNORMALHERMITIAN`, `MATNORMAL`
284 M*/
285 
286 /*@
287   MatCreateTranspose - Creates a new matrix `MATTRANSPOSEVIRTUAL` object that behaves like A'
288 
289   Collective
290 
291   Input Parameter:
292 . A - the (possibly rectangular) matrix
293 
294   Output Parameter:
295 . N - the matrix that represents A'
296 
297   Level: intermediate
298 
299   Note:
300   The transpose A' is NOT actually formed! Rather the new matrix
301   object performs the matrix-vector product by using the `MatMultTranspose()` on
302   the original matrix
303 
304 .seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateNormal()`, `MatMult()`, `MatMultTranspose()`, `MatCreate()`,
305           `MATNORMALHERMITIAN`
306 @*/
307 PetscErrorCode MatCreateTranspose(Mat A, Mat *N)
308 {
309   Mat_Transpose *Na;
310   VecType        vtype;
311 
312   PetscFunctionBegin;
313   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
314   PetscCall(PetscLayoutReference(A->rmap, &((*N)->cmap)));
315   PetscCall(PetscLayoutReference(A->cmap, &((*N)->rmap)));
316   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATTRANSPOSEVIRTUAL));
317 
318   PetscCall(PetscNew(&Na));
319   (*N)->data = (void *)Na;
320   PetscCall(PetscObjectReference((PetscObject)A));
321   Na->A = A;
322 
323   (*N)->ops->destroy               = MatDestroy_Transpose;
324   (*N)->ops->mult                  = MatMult_Transpose;
325   (*N)->ops->multadd               = MatMultAdd_Transpose;
326   (*N)->ops->multtranspose         = MatMultTranspose_Transpose;
327   (*N)->ops->multtransposeadd      = MatMultTransposeAdd_Transpose;
328   (*N)->ops->duplicate             = MatDuplicate_Transpose;
329   (*N)->ops->getvecs               = MatCreateVecs_Transpose;
330   (*N)->ops->axpy                  = MatAXPY_Transpose;
331   (*N)->ops->hasoperation          = MatHasOperation_Transpose;
332   (*N)->ops->productsetfromoptions = MatProductSetFromOptions_Transpose;
333   (*N)->ops->getdiagonal           = MatGetDiagonal_Transpose;
334   (*N)->ops->convert               = MatConvert_Transpose;
335   (*N)->assembled                  = PETSC_TRUE;
336 
337   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatTransposeGetMat_C", MatTransposeGetMat_Transpose));
338   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_Transpose));
339   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
340   PetscCall(MatGetVecType(A, &vtype));
341   PetscCall(MatSetVecType(*N, vtype));
342 #if defined(PETSC_HAVE_DEVICE)
343   PetscCall(MatBindToCPU(*N, A->boundtocpu));
344 #endif
345   PetscCall(MatSetUp(*N));
346   PetscFunctionReturn(PETSC_SUCCESS);
347 }
348