xref: /petsc/src/mat/impls/transpose/htransm.c (revision 21e3ffae2f3b73c0bd738cf6d0a809700fc04bb0)
1 
2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
3 
4 typedef struct {
5   Mat A;
6 } Mat_HT;
7 
8 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_HermitianTranspose(Mat D)
9 {
10   Mat            A, B, C, Ain, Bin, Cin;
11   PetscBool      Aistrans, Bistrans, Cistrans;
12   PetscInt       Atrans, Btrans, Ctrans;
13   MatProductType ptype;
14 
15   PetscFunctionBegin;
16   MatCheckProduct(D, 1);
17   A = D->product->A;
18   B = D->product->B;
19   C = D->product->C;
20   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
21   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
22   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
23   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
24   Atrans = 0;
25   Ain    = A;
26   while (Aistrans) {
27     Atrans++;
28     PetscCall(MatHermitianTransposeGetMat(Ain, &Ain));
29     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
30   }
31   Btrans = 0;
32   Bin    = B;
33   while (Bistrans) {
34     Btrans++;
35     PetscCall(MatHermitianTransposeGetMat(Bin, &Bin));
36     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
37   }
38   Ctrans = 0;
39   Cin    = C;
40   while (Cistrans) {
41     Ctrans++;
42     PetscCall(MatHermitianTransposeGetMat(Cin, &Cin));
43     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
44   }
45   Atrans = Atrans % 2;
46   Btrans = Btrans % 2;
47   Ctrans = Ctrans % 2;
48   ptype  = D->product->type; /* same product type by default */
49   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
50   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
51   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
52 
53   if (Atrans || Btrans || Ctrans) {
54     PetscCheck(!PetscDefined(USE_COMPLEX), PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "No support for complex Hermitian transpose matrices");
55     ptype = MATPRODUCT_UNSPECIFIED;
56     switch (D->product->type) {
57     case MATPRODUCT_AB:
58       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
59         /* TODO custom implementation ? */
60       } else if (Atrans) { /* At * B */
61         ptype = MATPRODUCT_AtB;
62       } else { /* A * Bt */
63         ptype = MATPRODUCT_ABt;
64       }
65       break;
66     case MATPRODUCT_AtB:
67       if (Atrans && Btrans) { /* A * Bt */
68         ptype = MATPRODUCT_ABt;
69       } else if (Atrans) { /* A * B */
70         ptype = MATPRODUCT_AB;
71       } else { /* At * Bt we do not have support for this */
72         /* TODO custom implementation ? */
73       }
74       break;
75     case MATPRODUCT_ABt:
76       if (Atrans && Btrans) { /* At * B */
77         ptype = MATPRODUCT_AtB;
78       } else if (Atrans) { /* At * Bt we do not have support for this */
79         /* TODO custom implementation ? */
80       } else { /* A * B */
81         ptype = MATPRODUCT_AB;
82       }
83       break;
84     case MATPRODUCT_PtAP:
85       if (Atrans) { /* PtAtP */
86         /* TODO custom implementation ? */
87       } else { /* RARt */
88         ptype = MATPRODUCT_RARt;
89       }
90       break;
91     case MATPRODUCT_RARt:
92       if (Atrans) { /* RAtRt */
93         /* TODO custom implementation ? */
94       } else { /* PtAP */
95         ptype = MATPRODUCT_PtAP;
96       }
97       break;
98     case MATPRODUCT_ABC:
99       /* TODO custom implementation ? */
100       break;
101     default:
102       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
103     }
104   }
105   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
106   PetscCall(MatProductSetType(D, ptype));
107   PetscCall(MatProductSetFromOptions(D));
108   PetscFunctionReturn(PETSC_SUCCESS);
109 }
110 PetscErrorCode MatMult_HT(Mat N, Vec x, Vec y)
111 {
112   Mat_HT *Na = (Mat_HT *)N->data;
113 
114   PetscFunctionBegin;
115   PetscCall(MatMultHermitianTranspose(Na->A, x, y));
116   PetscFunctionReturn(PETSC_SUCCESS);
117 }
118 
119 PetscErrorCode MatMultAdd_HT(Mat N, Vec v1, Vec v2, Vec v3)
120 {
121   Mat_HT *Na = (Mat_HT *)N->data;
122 
123   PetscFunctionBegin;
124   PetscCall(MatMultHermitianTransposeAdd(Na->A, v1, v2, v3));
125   PetscFunctionReturn(PETSC_SUCCESS);
126 }
127 
128 PetscErrorCode MatMultHermitianTranspose_HT(Mat N, Vec x, Vec y)
129 {
130   Mat_HT *Na = (Mat_HT *)N->data;
131 
132   PetscFunctionBegin;
133   PetscCall(MatMult(Na->A, x, y));
134   PetscFunctionReturn(PETSC_SUCCESS);
135 }
136 
137 PetscErrorCode MatMultHermitianTransposeAdd_HT(Mat N, Vec v1, Vec v2, Vec v3)
138 {
139   Mat_HT *Na = (Mat_HT *)N->data;
140 
141   PetscFunctionBegin;
142   PetscCall(MatMultAdd(Na->A, v1, v2, v3));
143   PetscFunctionReturn(PETSC_SUCCESS);
144 }
145 
146 PetscErrorCode MatDestroy_HT(Mat N)
147 {
148   Mat_HT *Na = (Mat_HT *)N->data;
149 
150   PetscFunctionBegin;
151   PetscCall(MatDestroy(&Na->A));
152   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatHermitianTransposeGetMat_C", NULL));
153   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
154 #if !defined(PETSC_USE_COMPLEX)
155   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
156 #endif
157   PetscCall(PetscFree(N->data));
158   PetscFunctionReturn(PETSC_SUCCESS);
159 }
160 
161 PetscErrorCode MatDuplicate_HT(Mat N, MatDuplicateOption op, Mat *m)
162 {
163   Mat_HT *Na = (Mat_HT *)N->data;
164 
165   PetscFunctionBegin;
166   if (op == MAT_COPY_VALUES) {
167     PetscCall(MatHermitianTranspose(Na->A, MAT_INITIAL_MATRIX, m));
168   } else if (op == MAT_DO_NOT_COPY_VALUES) {
169     PetscCall(MatDuplicate(Na->A, MAT_DO_NOT_COPY_VALUES, m));
170     PetscCall(MatHermitianTranspose(*m, MAT_INPLACE_MATRIX, m));
171   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "MAT_SHARE_NONZERO_PATTERN not supported for this matrix type");
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174 
175 PetscErrorCode MatCreateVecs_HT(Mat N, Vec *r, Vec *l)
176 {
177   Mat_HT *Na = (Mat_HT *)N->data;
178 
179   PetscFunctionBegin;
180   PetscCall(MatCreateVecs(Na->A, l, r));
181   PetscFunctionReturn(PETSC_SUCCESS);
182 }
183 
184 PetscErrorCode MatAXPY_HT(Mat Y, PetscScalar a, Mat X, MatStructure str)
185 {
186   Mat_HT *Ya = (Mat_HT *)Y->data;
187   Mat_HT *Xa = (Mat_HT *)X->data;
188   Mat     M  = Ya->A;
189   Mat     N  = Xa->A;
190 
191   PetscFunctionBegin;
192   PetscCall(MatAXPY(M, a, N, str));
193   PetscFunctionReturn(PETSC_SUCCESS);
194 }
195 
196 PetscErrorCode MatHermitianTransposeGetMat_HT(Mat N, Mat *M)
197 {
198   Mat_HT *Na = (Mat_HT *)N->data;
199 
200   PetscFunctionBegin;
201   *M = Na->A;
202   PetscFunctionReturn(PETSC_SUCCESS);
203 }
204 
/*@
      MatHermitianTransposeGetMat - Gets the `Mat` object stored inside a `MATHERMITIANTRANSPOSEVIRTUAL`

   Logically collective

   Input Parameter:
.   A  - the `MATHERMITIANTRANSPOSEVIRTUAL` matrix

   Output Parameter:
.   M - the matrix object stored inside A

   Level: intermediate

   Note:
   For a `MATHERMITIANTRANSPOSEVIRTUAL` no new reference is taken on `M`; it stays owned by `A`,
   so the caller should not call `MatDestroy()` on it.

.seealso: `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`
@*/
PetscErrorCode MatHermitianTransposeGetMat(Mat A, Mat *M)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscValidType(A, 1);
  PetscValidPointer(M, 2);
  /* dispatch to the implementation composed on A; PetscUseMethod errors if A has none */
  PetscUseMethod(A, "MatHermitianTransposeGetMat_C", (Mat, Mat *), (A, M));
  PetscFunctionReturn(PETSC_SUCCESS);
}
228 }
229 
230 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat);
231 
232 PetscErrorCode MatGetDiagonal_HT(Mat A, Vec v)
233 {
234   Mat_HT *Na = (Mat_HT *)A->data;
235 
236   PetscFunctionBegin;
237   PetscCall(MatGetDiagonal(Na->A, v));
238   PetscCall(VecConjugate(v));
239   PetscFunctionReturn(PETSC_SUCCESS);
240 }
241 
242 PetscErrorCode MatConvert_HT(Mat A, MatType newtype, MatReuse reuse, Mat *newmat)
243 {
244   Mat_HT   *Na = (Mat_HT *)A->data;
245   PetscBool flg;
246 
247   PetscFunctionBegin;
248   PetscCall(MatHasOperation(Na->A, MATOP_HERMITIAN_TRANSPOSE, &flg));
249   if (flg) {
250     Mat B;
251 
252     PetscCall(MatHermitianTranspose(Na->A, MAT_INITIAL_MATRIX, &B));
253     if (reuse != MAT_INPLACE_MATRIX) {
254       PetscCall(MatConvert(B, newtype, reuse, newmat));
255       PetscCall(MatDestroy(&B));
256     } else {
257       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
258       PetscCall(MatHeaderReplace(A, &B));
259     }
260   } else { /* use basic converter as fallback */
261     PetscCall(MatConvert_Basic(A, newtype, reuse, newmat));
262   }
263   PetscFunctionReturn(PETSC_SUCCESS);
264 }
265 
266 /*MC
   MATHERMITIANTRANSPOSEVIRTUAL - "hermitiantranspose" - A matrix type that represents a virtual Hermitian transpose of a matrix
268 
269   Level: advanced
270 
271 .seealso: `MATTRANSPOSEVIRTUAL`, `Mat`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`
272 M*/
273 
274 /*@
275       MatCreateHermitianTranspose - Creates a new matrix object of `MatType` `MATHERMITIANTRANSPOSEVIRTUAL` that behaves like A'*
276 
277    Collective
278 
279    Input Parameter:
280 .   A  - the (possibly rectangular) matrix
281 
282    Output Parameter:
283 .   N - the matrix that represents A'*
284 
285    Level: intermediate
286 
287    Note:
288     The Hermitian transpose A' is NOT actually formed! Rather the new matrix
289           object performs the matrix-vector product, `MatMult()`, by using the `MatMultHermitianTranspose()` on
290           the original matrix
291 
292 .seealso: `MatCreateNormal()`, `MatMult()`, `MatMultHermitianTranspose()`, `MatCreate()`,
293           `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`, `MatHermitianTransposeGetMat()`
294 @*/
295 PetscErrorCode MatCreateHermitianTranspose(Mat A, Mat *N)
296 {
297   PetscInt m, n;
298   Mat_HT  *Na;
299   VecType  vtype;
300 
301   PetscFunctionBegin;
302   PetscCall(MatGetLocalSize(A, &m, &n));
303   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
304   PetscCall(MatSetSizes(*N, n, m, PETSC_DECIDE, PETSC_DECIDE));
305   PetscCall(PetscLayoutSetUp((*N)->rmap));
306   PetscCall(PetscLayoutSetUp((*N)->cmap));
307   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATHERMITIANTRANSPOSEVIRTUAL));
308 
309   PetscCall(PetscNew(&Na));
310   (*N)->data = (void *)Na;
311   PetscCall(PetscObjectReference((PetscObject)A));
312   Na->A = A;
313 
314   (*N)->ops->destroy                   = MatDestroy_HT;
315   (*N)->ops->mult                      = MatMult_HT;
316   (*N)->ops->multadd                   = MatMultAdd_HT;
317   (*N)->ops->multhermitiantranspose    = MatMultHermitianTranspose_HT;
318   (*N)->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_HT;
319 #if !defined(PETSC_USE_COMPLEX)
320   (*N)->ops->multtranspose    = MatMultHermitianTranspose_HT;
321   (*N)->ops->multtransposeadd = MatMultHermitianTransposeAdd_HT;
322 #endif
323   (*N)->ops->duplicate = MatDuplicate_HT;
324   (*N)->ops->getvecs   = MatCreateVecs_HT;
325   (*N)->ops->axpy      = MatAXPY_HT;
326 #if !defined(PETSC_USE_COMPLEX)
327   (*N)->ops->productsetfromoptions = MatProductSetFromOptions_Transpose;
328 #endif
329   (*N)->ops->getdiagonal = MatGetDiagonal_HT;
330   (*N)->ops->convert     = MatConvert_HT;
331   (*N)->assembled        = PETSC_TRUE;
332 
333   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatHermitianTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
334   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_HermitianTranspose));
335 #if !defined(PETSC_USE_COMPLEX)
336   PetscCall(PetscObjectComposeFunction((PetscObject)(*N), "MatTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
337 #endif
338   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
339   PetscCall(MatGetVecType(A, &vtype));
340   PetscCall(MatSetVecType(*N, vtype));
341 #if defined(PETSC_HAVE_DEVICE)
342   PetscCall(MatBindToCPU(*N, A->boundtocpu));
343 #endif
344   PetscCall(MatSetUp(*N));
345   PetscFunctionReturn(PETSC_SUCCESS);
346 }
347