xref: /petsc/src/mat/impls/transpose/htransm.c (revision 2d30e087755efd99e28fdfe792ffbeb2ee1ea928)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 typedef struct {
5   Mat A;
6 } Mat_HT;
7 
8 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_HermitianTranspose(Mat D) {
9   Mat            A, B, C, Ain, Bin, Cin;
10   PetscBool      Aistrans, Bistrans, Cistrans;
11   PetscInt       Atrans, Btrans, Ctrans;
12   MatProductType ptype;
13 
14   PetscFunctionBegin;
15   MatCheckProduct(D, 1);
16   A = D->product->A;
17   B = D->product->B;
18   C = D->product->C;
19   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
20   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
21   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
22   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
23   Atrans = 0;
24   Ain    = A;
25   while (Aistrans) {
26     Atrans++;
27     PetscCall(MatHermitianTransposeGetMat(Ain, &Ain));
28     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
29   }
30   Btrans = 0;
31   Bin    = B;
32   while (Bistrans) {
33     Btrans++;
34     PetscCall(MatHermitianTransposeGetMat(Bin, &Bin));
35     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
36   }
37   Ctrans = 0;
38   Cin    = C;
39   while (Cistrans) {
40     Ctrans++;
41     PetscCall(MatHermitianTransposeGetMat(Cin, &Cin));
42     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
43   }
44   Atrans = Atrans % 2;
45   Btrans = Btrans % 2;
46   Ctrans = Ctrans % 2;
47   ptype  = D->product->type; /* same product type by default */
48   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
49   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
50   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
51 
52   if (Atrans || Btrans || Ctrans) {
53     if (PetscDefined(USE_COMPLEX)) SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "No support for complex Hermitian transpose matrices");
54     ptype = MATPRODUCT_UNSPECIFIED;
55     switch (D->product->type) {
56     case MATPRODUCT_AB:
57       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
58         /* TODO custom implementation ? */
59       } else if (Atrans) { /* At * B */
60         ptype = MATPRODUCT_AtB;
61       } else { /* A * Bt */
62         ptype = MATPRODUCT_ABt;
63       }
64       break;
65     case MATPRODUCT_AtB:
66       if (Atrans && Btrans) { /* A * Bt */
67         ptype = MATPRODUCT_ABt;
68       } else if (Atrans) { /* A * B */
69         ptype = MATPRODUCT_AB;
70       } else { /* At * Bt we do not have support for this */
71         /* TODO custom implementation ? */
72       }
73       break;
74     case MATPRODUCT_ABt:
75       if (Atrans && Btrans) { /* At * B */
76         ptype = MATPRODUCT_AtB;
77       } else if (Atrans) { /* At * Bt we do not have support for this */
78         /* TODO custom implementation ? */
79       } else { /* A * B */
80         ptype = MATPRODUCT_AB;
81       }
82       break;
83     case MATPRODUCT_PtAP:
84       if (Atrans) { /* PtAtP */
85         /* TODO custom implementation ? */
86       } else { /* RARt */
87         ptype = MATPRODUCT_RARt;
88       }
89       break;
90     case MATPRODUCT_RARt:
91       if (Atrans) { /* RAtRt */
92         /* TODO custom implementation ? */
93       } else { /* PtAP */
94         ptype = MATPRODUCT_PtAP;
95       }
96       break;
97     case MATPRODUCT_ABC:
98       /* TODO custom implementation ? */
99       break;
100     default: SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
101     }
102   }
103   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
104   PetscCall(MatProductSetType(D, ptype));
105   PetscCall(MatProductSetFromOptions(D));
106   PetscFunctionReturn(0);
107 }
108 PetscErrorCode MatMult_HT(Mat N, Vec x, Vec y) {
109   Mat_HT *Na = (Mat_HT *)N->data;
110 
111   PetscFunctionBegin;
112   PetscCall(MatMultHermitianTranspose(Na->A, x, y));
113   PetscFunctionReturn(0);
114 }
115 
116 PetscErrorCode MatMultAdd_HT(Mat N, Vec v1, Vec v2, Vec v3) {
117   Mat_HT *Na = (Mat_HT *)N->data;
118 
119   PetscFunctionBegin;
120   PetscCall(MatMultHermitianTransposeAdd(Na->A, v1, v2, v3));
121   PetscFunctionReturn(0);
122 }
123 
124 PetscErrorCode MatMultHermitianTranspose_HT(Mat N, Vec x, Vec y) {
125   Mat_HT *Na = (Mat_HT *)N->data;
126 
127   PetscFunctionBegin;
128   PetscCall(MatMult(Na->A, x, y));
129   PetscFunctionReturn(0);
130 }
131 
132 PetscErrorCode MatMultHermitianTransposeAdd_HT(Mat N, Vec v1, Vec v2, Vec v3) {
133   Mat_HT *Na = (Mat_HT *)N->data;
134 
135   PetscFunctionBegin;
136   PetscCall(MatMultAdd(Na->A, v1, v2, v3));
137   PetscFunctionReturn(0);
138 }
139 
140 PetscErrorCode MatDestroy_HT(Mat N) {
141   Mat_HT *Na = (Mat_HT *)N->data;
142 
143   PetscFunctionBegin;
144   PetscCall(MatDestroy(&Na->A));
145   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatHermitianTransposeGetMat_C", NULL));
146   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
147 #if !defined(PETSC_USE_COMPLEX)
148   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
149 #endif
150   PetscCall(PetscFree(N->data));
151   PetscFunctionReturn(0);
152 }
153 
154 PetscErrorCode MatDuplicate_HT(Mat N, MatDuplicateOption op, Mat *m) {
155   Mat_HT *Na = (Mat_HT *)N->data;
156 
157   PetscFunctionBegin;
158   if (op == MAT_COPY_VALUES) {
159     PetscCall(MatHermitianTranspose(Na->A, MAT_INITIAL_MATRIX, m));
160   } else if (op == MAT_DO_NOT_COPY_VALUES) {
161     PetscCall(MatDuplicate(Na->A, MAT_DO_NOT_COPY_VALUES, m));
162     PetscCall(MatHermitianTranspose(*m, MAT_INPLACE_MATRIX, m));
163   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "MAT_SHARE_NONZERO_PATTERN not supported for this matrix type");
164   PetscFunctionReturn(0);
165 }
166 
167 PetscErrorCode MatCreateVecs_HT(Mat N, Vec *r, Vec *l) {
168   Mat_HT *Na = (Mat_HT *)N->data;
169 
170   PetscFunctionBegin;
171   PetscCall(MatCreateVecs(Na->A, l, r));
172   PetscFunctionReturn(0);
173 }
174 
175 PetscErrorCode MatAXPY_HT(Mat Y, PetscScalar a, Mat X, MatStructure str) {
176   Mat_HT *Ya = (Mat_HT *)Y->data;
177   Mat_HT *Xa = (Mat_HT *)X->data;
178   Mat     M  = Ya->A;
179   Mat     N  = Xa->A;
180 
181   PetscFunctionBegin;
182   PetscCall(MatAXPY(M, a, N, str));
183   PetscFunctionReturn(0);
184 }
185 
186 PetscErrorCode MatHermitianTransposeGetMat_HT(Mat N, Mat *M) {
187   Mat_HT *Na = (Mat_HT *)N->data;
188 
189   PetscFunctionBegin;
190   *M = Na->A;
191   PetscFunctionReturn(0);
192 }
193 
194 /*@
195       MatHermitianTransposeGetMat - Gets the `Mat` object stored inside a `MATHERMITIANTRANSPOSEVIRTUAL`
196 
197    Logically collective on Mat
198 
199    Input Parameter:
200 .   A  - the `MATHERMITIANTRANSPOSEVIRTUAL` matrix
201 
202    Output Parameter:
203 .   M - the matrix object stored inside A
204 
205    Level: intermediate
206 
207 .seealso: `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`
208 @*/
209 PetscErrorCode MatHermitianTransposeGetMat(Mat A, Mat *M) {
210   PetscFunctionBegin;
211   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
212   PetscValidType(A, 1);
213   PetscValidPointer(M, 2);
214   PetscUseMethod(A, "MatHermitianTransposeGetMat_C", (Mat, Mat *), (A, M));
215   PetscFunctionReturn(0);
216 }
217 
218 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat);
219 
220 PetscErrorCode MatGetDiagonal_HT(Mat A, Vec v) {
221   Mat_HT *Na = (Mat_HT *)A->data;
222 
223   PetscFunctionBegin;
224   PetscCall(MatGetDiagonal(Na->A, v));
225   PetscCall(VecConjugate(v));
226   PetscFunctionReturn(0);
227 }
228 
229 PetscErrorCode MatConvert_HT(Mat A, MatType newtype, MatReuse reuse, Mat *newmat) {
230   Mat_HT   *Na = (Mat_HT *)A->data;
231   PetscBool flg;
232 
233   PetscFunctionBegin;
234   PetscCall(MatHasOperation(Na->A, MATOP_HERMITIAN_TRANSPOSE, &flg));
235   if (flg) {
236     Mat B;
237 
238     PetscCall(MatHermitianTranspose(Na->A, MAT_INITIAL_MATRIX, &B));
239     if (reuse != MAT_INPLACE_MATRIX) {
240       PetscCall(MatConvert(B, newtype, reuse, newmat));
241       PetscCall(MatDestroy(&B));
242     } else {
243       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
244       PetscCall(MatHeaderReplace(A, &B));
245     }
246   } else { /* use basic converter as fallback */
247     PetscCall(MatConvert_Basic(A, newtype, reuse, newmat));
248   }
249   PetscFunctionReturn(0);
250 }
251 
/*MC
   MATHERMITIANTRANSPOSEVIRTUAL - "hermitiantranspose" - A matrix type that represents a virtual Hermitian (conjugate) transpose of a matrix

  Level: advanced

.seealso: `MATTRANSPOSEVIRTUAL`, `Mat`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`, `MatHermitianTransposeGetMat()`
M*/
259 
260 /*@
261       MatCreateHermitianTranspose - Creates a new matrix object of `MatType` `MATHERMITIANTRANSPOSEVIRTUAL` that behaves like A'*
262 
263    Collective on A
264 
265    Input Parameter:
266 .   A  - the (possibly rectangular) matrix
267 
268    Output Parameter:
269 .   N - the matrix that represents A'*
270 
271    Level: intermediate
272 
273    Note:
274     The Hermitian transpose A' is NOT actually formed! Rather the new matrix
275           object performs the matrix-vector product, `MatMult()`, by using the `MatMultHermitianTranspose()` on
276           the original matrix
277 
278 .seealso: `MatCreateNormal()`, `MatMult()`, `MatMultHermitianTranspose()`, `MatCreate()`,
279           `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`, `MatHermitianTransposeGetMat()`
280 @*/
281 PetscErrorCode MatCreateHermitianTranspose(Mat A, Mat *N) {
282   PetscInt m, n;
283   Mat_HT  *Na;
284   VecType  vtype;
285 
286   PetscFunctionBegin;
287   PetscCall(MatGetLocalSize(A, &m, &n));
288   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
289   PetscCall(MatSetSizes(*N, n, m, PETSC_DECIDE, PETSC_DECIDE));
290   PetscCall(PetscLayoutSetUp((*N)->rmap));
291   PetscCall(PetscLayoutSetUp((*N)->cmap));
292   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATHERMITIANTRANSPOSEVIRTUAL));
293 
294   PetscCall(PetscNewLog(*N, &Na));
295   (*N)->data = (void *)Na;
296   PetscCall(PetscObjectReference((PetscObject)A));
297   Na->A = A;
298 
299   (*N)->ops->destroy                   = MatDestroy_HT;
300   (*N)->ops->mult                      = MatMult_HT;
301   (*N)->ops->multadd                   = MatMultAdd_HT;
302   (*N)->ops->multhermitiantranspose    = MatMultHermitianTranspose_HT;
303   (*N)->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_HT;
304 #if !defined(PETSC_USE_COMPLEX)
305   (*N)->ops->multtranspose    = MatMultHermitianTranspose_HT;
306   (*N)->ops->multtransposeadd = MatMultHermitianTransposeAdd_HT;
307 #endif
308   (*N)->ops->duplicate = MatDuplicate_HT;
309   (*N)->ops->getvecs   = MatCreateVecs_HT;
310   (*N)->ops->axpy      = MatAXPY_HT;
311 #if !defined(PETSC_USE_COMPLEX)
312   (*N)->ops->productsetfromoptions = MatProductSetFromOptions_Transpose;
313 #endif
314   (*N)->ops->getdiagonal = MatGetDiagonal_HT;
315   (*N)->ops->convert     = MatConvert_HT;
316   (*N)->assembled        = PETSC_TRUE;
317 
318   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatHermitianTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
319   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_HermitianTranspose));
320 #if !defined(PETSC_USE_COMPLEX)
321   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
322 #endif
323   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
324   PetscCall(MatGetVecType(A, &vtype));
325   PetscCall(MatSetVecType(*N, vtype));
326 #if defined(PETSC_HAVE_DEVICE)
327   PetscCall(MatBindToCPU(*N, A->boundtocpu));
328 #endif
329   PetscCall(MatSetUp(*N));
330   PetscFunctionReturn(0);
331 }
332