xref: /petsc/src/mat/impls/transpose/htransm.c (revision 2daea058f1fa0b4b5a893c784be52f4813ced5f1)
1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2 
3 typedef struct {
4   PetscErrorCode (*numeric)(Mat);
5   PetscCtxDestroyFn *destroy;
6   Mat                B, D;
7   PetscScalar        scale;
8   PetscBool          conjugate;
9   void              *data;
10 } MatProductCtx_HT;
11 
MatProductCtxDestroy_HT(PetscCtxRt ptr)12 static PetscErrorCode MatProductCtxDestroy_HT(PetscCtxRt ptr)
13 {
14   MatProductCtx_HT *data = *(MatProductCtx_HT **)ptr;
15   PetscContainer    container;
16 
17   PetscFunctionBegin;
18   if (data->data) PetscCall((*data->destroy)(&data->data));
19   if (data->conjugate) PetscCall(MatDestroy(&data->B));
20   PetscCall(PetscObjectQuery((PetscObject)data->D, "MatProductCtx_HT", (PetscObject *)&container));
21   PetscCall(PetscContainerDestroy(&container));
22   PetscCall(PetscObjectCompose((PetscObject)data->D, "MatProductCtx_HT", NULL));
23   PetscCall(PetscFree(data));
24   PetscFunctionReturn(PETSC_SUCCESS);
25 }
26 
MatProductNumeric_HT(Mat D)27 static PetscErrorCode MatProductNumeric_HT(Mat D)
28 {
29   Mat_Product      *product;
30   Mat               B;
31   MatProductCtx_HT *data;
32   PetscContainer    container;
33 
34   PetscFunctionBegin;
35   MatCheckProduct(D, 1);
36   PetscCheck(D->product->data, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "Product data empty");
37   product = D->product;
38   PetscCall(PetscObjectQuery((PetscObject)D, "MatProductCtx_HT", (PetscObject *)&container));
39   PetscCheck(container, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "MatProductCtx_HT missing");
40   PetscCall(PetscContainerGetPointer(container, &data));
41   B    = product->B;
42   data = (MatProductCtx_HT *)product->data;
43   if (data->conjugate) {
44     PetscCall(MatCopy(product->B, data->B, SAME_NONZERO_PATTERN));
45     PetscCall(MatConjugate(data->B));
46     product->B = data->B;
47   }
48   product->data = data->data;
49   PetscCall((*data->numeric)(D));
50   if (data->conjugate) {
51     PetscCall(MatConjugate(D));
52     product->B = B;
53   }
54   PetscCall(MatScale(D, data->scale));
55   product->data = data;
56   PetscFunctionReturn(PETSC_SUCCESS);
57 }
58 
MatProductSymbolic_HT(Mat D)59 static PetscErrorCode MatProductSymbolic_HT(Mat D)
60 {
61   Mat_Product      *product;
62   Mat               B;
63   MatProductCtx_HT *data;
64   PetscContainer    container;
65 
66   PetscFunctionBegin;
67   MatCheckProduct(D, 1);
68   product = D->product;
69   B       = product->B;
70   if (D->ops->productsymbolic == MatProductSymbolic_HT) {
71     PetscCheck(!product->data, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "Product data not empty");
72     PetscCall(PetscObjectQuery((PetscObject)D, "MatProductCtx_HT", (PetscObject *)&container));
73     PetscCheck(container, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "MatProductCtx_HT missing");
74     PetscCall(PetscContainerGetPointer(container, &data));
75     PetscCall(MatProductSetFromOptions(D));
76     if (data->conjugate) {
77       PetscCall(MatDuplicate(B, MAT_DO_NOT_COPY_VALUES, &data->B));
78       product->B = data->B;
79     }
80     PetscCall(MatProductSymbolic(D));
81     data->numeric          = D->ops->productnumeric;
82     data->destroy          = product->destroy;
83     data->data             = product->data;
84     D->ops->productnumeric = MatProductNumeric_HT;
85     product->destroy       = MatProductCtxDestroy_HT;
86     if (data->conjugate) product->B = B;
87     product->data = data;
88   }
89   PetscFunctionReturn(PETSC_SUCCESS);
90 }
91 
MatProductSetFromOptions_HT(Mat D)92 static PetscErrorCode MatProductSetFromOptions_HT(Mat D)
93 {
94   Mat               A, B, C, Ain, Bin, Cin;
95   PetscScalar       scale = 1.0, vscale;
96   PetscBool         Aistrans, Bistrans, Cistrans, conjugate = PETSC_FALSE;
97   PetscInt          Atrans, Btrans, Ctrans;
98   PetscContainer    container = NULL;
99   MatProductCtx_HT *data;
100   MatProductType    ptype;
101 
102   PetscFunctionBegin;
103   MatCheckProduct(D, 1);
104   A = D->product->A;
105   B = D->product->B;
106   C = D->product->C;
107   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
108   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
109   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
110   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
111   Atrans = 0;
112   Ain    = A;
113   while (Aistrans) {
114     Atrans++;
115     PetscCall(MatShellGetScalingShifts(Ain, (PetscScalar *)MAT_SHELL_NOT_ALLOWED, &vscale, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
116     conjugate = (PetscBool)!conjugate;
117     scale *= vscale;
118     PetscCall(MatHermitianTransposeGetMat(Ain, &Ain));
119     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
120   }
121   Btrans = 0;
122   Bin    = B;
123   while (Bistrans) {
124     Btrans++;
125     PetscCall(MatShellGetScalingShifts(Bin, (PetscScalar *)MAT_SHELL_NOT_ALLOWED, &vscale, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
126     scale *= vscale;
127     PetscCall(MatHermitianTransposeGetMat(Bin, &Bin));
128     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
129   }
130   Ctrans = 0;
131   Cin    = C;
132   while (Cistrans) {
133     Ctrans++;
134     PetscCall(MatShellGetScalingShifts(Cin, (PetscScalar *)MAT_SHELL_NOT_ALLOWED, &vscale, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
135     scale *= vscale;
136     PetscCall(MatHermitianTransposeGetMat(Cin, &Cin));
137     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
138   }
139   Atrans = Atrans % 2;
140   Btrans = Btrans % 2;
141   Ctrans = Ctrans % 2;
142   ptype  = D->product->type; /* same product type by default */
143   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
144   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
145   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
146 
147   if (Atrans || Btrans || Ctrans) {
148     PetscCheck(!PetscDefined(USE_COMPLEX) || (!Btrans && !Ctrans), PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "No support for complex Hermitian transpose matrices");
149     if ((PetscDefined(USE_COMPLEX) && Atrans) || scale != 1.0) {
150       PetscCall(PetscObjectQuery((PetscObject)D, "MatProductCtx_HT", (PetscObject *)&container));
151       if (!container) {
152         PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)D), &container));
153         PetscCall(PetscNew(&data));
154         PetscCall(PetscContainerSetPointer(container, data));
155         PetscCall(PetscObjectCompose((PetscObject)D, "MatProductCtx_HT", (PetscObject)container));
156       } else PetscCall(PetscContainerGetPointer(container, &data));
157       data->scale     = scale;
158       data->conjugate = (PetscBool)Atrans;
159       data->D         = D;
160     }
161     ptype = MATPRODUCT_UNSPECIFIED;
162     switch (D->product->type) {
163     case MATPRODUCT_AB:
164       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
165         /* TODO custom implementation ? */
166       } else if (Atrans) { /* At * B */
167         ptype = MATPRODUCT_AtB;
168       } else { /* A * Bt */
169         ptype = MATPRODUCT_ABt;
170       }
171       break;
172     case MATPRODUCT_AtB:
173       if (Atrans && Btrans) { /* A * Bt */
174         ptype = MATPRODUCT_ABt;
175       } else if (Atrans) { /* A * B */
176         ptype = MATPRODUCT_AB;
177       } else { /* At * Bt we do not have support for this */
178         /* TODO custom implementation ? */
179       }
180       break;
181     case MATPRODUCT_ABt:
182       if (Atrans && Btrans) { /* At * B */
183         ptype = MATPRODUCT_AtB;
184       } else if (Atrans) { /* At * Bt we do not have support for this */
185         /* TODO custom implementation ? */
186       } else { /* A * B */
187         ptype = MATPRODUCT_AB;
188       }
189       break;
190     case MATPRODUCT_PtAP:
191       if (Atrans) { /* PtAtP */
192         /* TODO custom implementation ? */
193       } else { /* RARt */
194         ptype = MATPRODUCT_RARt;
195       }
196       break;
197     case MATPRODUCT_RARt:
198       if (Atrans) { /* RAtRt */
199         /* TODO custom implementation ? */
200       } else { /* PtAP */
201         ptype = MATPRODUCT_PtAP;
202       }
203       break;
204     case MATPRODUCT_ABC:
205       /* TODO custom implementation ? */
206       break;
207     default:
208       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
209     }
210   }
211   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
212   PetscCall(MatProductSetType(D, ptype));
213   if (container) D->ops->productsymbolic = MatProductSymbolic_HT;
214   else PetscCall(MatProductSetFromOptions(D));
215   PetscFunctionReturn(PETSC_SUCCESS);
216 }
217 
MatMult_HT(Mat N,Vec x,Vec y)218 static PetscErrorCode MatMult_HT(Mat N, Vec x, Vec y)
219 {
220   Mat A;
221 
222   PetscFunctionBegin;
223   PetscCall(MatShellGetContext(N, &A));
224   PetscCall(MatMultHermitianTranspose(A, x, y));
225   PetscFunctionReturn(PETSC_SUCCESS);
226 }
227 
MatMultHermitianTranspose_HT(Mat N,Vec x,Vec y)228 static PetscErrorCode MatMultHermitianTranspose_HT(Mat N, Vec x, Vec y)
229 {
230   Mat A;
231 
232   PetscFunctionBegin;
233   PetscCall(MatShellGetContext(N, &A));
234   PetscCall(MatMult(A, x, y));
235   PetscFunctionReturn(PETSC_SUCCESS);
236 }
237 
MatSolve_HT_LU(Mat N,Vec b,Vec x)238 static PetscErrorCode MatSolve_HT_LU(Mat N, Vec b, Vec x)
239 {
240   Mat A;
241   Vec w;
242 
243   PetscFunctionBegin;
244   PetscCall(MatShellGetContext(N, &A));
245   PetscCall(VecDuplicate(b, &w));
246   PetscCall(VecCopy(b, w));
247   PetscCall(VecConjugate(w));
248   PetscCall(MatSolveTranspose(A, w, x));
249   PetscCall(VecConjugate(x));
250   PetscCall(VecDestroy(&w));
251   PetscFunctionReturn(PETSC_SUCCESS);
252 }
253 
MatSolveAdd_HT_LU(Mat N,Vec b,Vec y,Vec x)254 static PetscErrorCode MatSolveAdd_HT_LU(Mat N, Vec b, Vec y, Vec x)
255 {
256   Mat A;
257   Vec v, w;
258 
259   PetscFunctionBegin;
260   PetscCall(MatShellGetContext(N, &A));
261   PetscCall(VecDuplicate(b, &v));
262   PetscCall(VecDuplicate(b, &w));
263   PetscCall(VecCopy(y, v));
264   PetscCall(VecCopy(b, w));
265   PetscCall(VecConjugate(v));
266   PetscCall(VecConjugate(w));
267   PetscCall(MatSolveTransposeAdd(A, w, v, x));
268   PetscCall(VecConjugate(x));
269   PetscCall(VecDestroy(&v));
270   PetscCall(VecDestroy(&w));
271   PetscFunctionReturn(PETSC_SUCCESS);
272 }
273 
MatMatSolve_HT_LU(Mat N,Mat B,Mat X)274 static PetscErrorCode MatMatSolve_HT_LU(Mat N, Mat B, Mat X)
275 {
276   Mat A, W;
277 
278   PetscFunctionBegin;
279   PetscCall(MatShellGetContext(N, &A));
280   PetscCall(MatDuplicate(B, MAT_COPY_VALUES, &W));
281   PetscCall(MatConjugate(W));
282   PetscCall(MatMatSolveTranspose(A, W, X));
283   PetscCall(MatConjugate(X));
284   PetscCall(MatDestroy(&W));
285   PetscFunctionReturn(PETSC_SUCCESS);
286 }
287 
MatLUFactor_HT(Mat N,IS row,IS col,const MatFactorInfo * minfo)288 static PetscErrorCode MatLUFactor_HT(Mat N, IS row, IS col, const MatFactorInfo *minfo)
289 {
290   Mat A;
291 
292   PetscFunctionBegin;
293   PetscCall(MatShellGetContext(N, &A));
294   PetscCall(MatLUFactor(A, col, row, minfo));
295   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (PetscErrorCodeFn *)MatSolve_HT_LU));
296   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (PetscErrorCodeFn *)MatSolveAdd_HT_LU));
297   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (PetscErrorCodeFn *)MatMatSolve_HT_LU));
298   PetscFunctionReturn(PETSC_SUCCESS);
299 }
300 
MatSolve_HT_Cholesky(Mat N,Vec b,Vec x)301 static PetscErrorCode MatSolve_HT_Cholesky(Mat N, Vec b, Vec x)
302 {
303   Mat A;
304 
305   PetscFunctionBegin;
306   PetscCall(MatShellGetContext(N, &A));
307   PetscCall(MatSolve(A, b, x));
308   PetscFunctionReturn(PETSC_SUCCESS);
309 }
310 
MatSolveAdd_HT_Cholesky(Mat N,Vec b,Vec y,Vec x)311 static PetscErrorCode MatSolveAdd_HT_Cholesky(Mat N, Vec b, Vec y, Vec x)
312 {
313   Mat A;
314   Vec v, w;
315 
316   PetscFunctionBegin;
317   PetscCall(MatShellGetContext(N, &A));
318   PetscCall(VecDuplicate(b, &v));
319   PetscCall(VecDuplicate(b, &w));
320   PetscCall(VecCopy(y, v));
321   PetscCall(VecCopy(b, w));
322   PetscCall(VecConjugate(v));
323   PetscCall(VecConjugate(w));
324   PetscCall(MatSolveTransposeAdd(A, w, v, x));
325   PetscCall(VecConjugate(x));
326   PetscCall(VecDestroy(&v));
327   PetscCall(VecDestroy(&w));
328   PetscFunctionReturn(PETSC_SUCCESS);
329 }
330 
MatMatSolve_HT_Cholesky(Mat N,Mat B,Mat X)331 static PetscErrorCode MatMatSolve_HT_Cholesky(Mat N, Mat B, Mat X)
332 {
333   Mat A, W;
334 
335   PetscFunctionBegin;
336   PetscCall(MatShellGetContext(N, &A));
337   PetscCall(MatDuplicate(B, MAT_COPY_VALUES, &W));
338   PetscCall(MatConjugate(W));
339   PetscCall(MatMatSolveTranspose(A, W, X));
340   PetscCall(MatConjugate(X));
341   PetscCall(MatDestroy(&W));
342   PetscFunctionReturn(PETSC_SUCCESS);
343 }
344 
MatCholeskyFactor_HT(Mat N,IS perm,const MatFactorInfo * minfo)345 static PetscErrorCode MatCholeskyFactor_HT(Mat N, IS perm, const MatFactorInfo *minfo)
346 {
347   Mat A;
348 
349   PetscFunctionBegin;
350   PetscCall(MatShellGetContext(N, &A));
351   PetscCheck(!PetscDefined(USE_COMPLEX) || A->hermitian == PETSC_BOOL3_TRUE, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Cholesky supported only if original matrix is Hermitian");
352   PetscCall(MatCholeskyFactor(A, perm, minfo));
353   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (PetscErrorCodeFn *)MatSolve_HT_Cholesky));
354   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (PetscErrorCodeFn *)MatSolveAdd_HT_Cholesky));
355   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (PetscErrorCodeFn *)MatMatSolve_HT_Cholesky));
356   PetscFunctionReturn(PETSC_SUCCESS);
357 }
358 
MatLUFactorNumeric_HT(Mat F,Mat N,const MatFactorInfo * info)359 static PetscErrorCode MatLUFactorNumeric_HT(Mat F, Mat N, const MatFactorInfo *info)
360 {
361   Mat A, FA;
362 
363   PetscFunctionBegin;
364   PetscCall(MatShellGetContext(N, &A));
365   PetscCall(MatShellGetContext(F, &FA));
366   PetscCall(MatLUFactorNumeric(FA, A, info));
367   PetscCall(MatShellSetOperation(F, MATOP_SOLVE, (PetscErrorCodeFn *)MatSolve_HT_LU));
368   PetscCall(MatShellSetOperation(F, MATOP_SOLVE_ADD, (PetscErrorCodeFn *)MatSolveAdd_HT_LU));
369   PetscCall(MatShellSetOperation(F, MATOP_MAT_SOLVE, (PetscErrorCodeFn *)MatMatSolve_HT_LU));
370   PetscFunctionReturn(PETSC_SUCCESS);
371 }
372 
MatLUFactorSymbolic_HT(Mat F,Mat N,IS row,IS col,const MatFactorInfo * info)373 static PetscErrorCode MatLUFactorSymbolic_HT(Mat F, Mat N, IS row, IS col, const MatFactorInfo *info)
374 {
375   Mat A, FA;
376 
377   PetscFunctionBegin;
378   PetscCall(MatShellGetContext(N, &A));
379   PetscCall(MatShellGetContext(F, &FA));
380   PetscCall(MatLUFactorSymbolic(FA, A, row, col, info));
381   PetscCall(MatShellSetOperation(F, MATOP_LUFACTOR_NUMERIC, (PetscErrorCodeFn *)MatLUFactorNumeric_HT));
382   PetscFunctionReturn(PETSC_SUCCESS);
383 }
384 
MatCholeskyFactorNumeric_HT(Mat F,Mat N,const MatFactorInfo * info)385 static PetscErrorCode MatCholeskyFactorNumeric_HT(Mat F, Mat N, const MatFactorInfo *info)
386 {
387   Mat A, FA;
388 
389   PetscFunctionBegin;
390   PetscCall(MatShellGetContext(N, &A));
391   PetscCall(MatShellGetContext(F, &FA));
392   PetscCall(MatCholeskyFactorNumeric(FA, A, info));
393   PetscCall(MatShellSetOperation(F, MATOP_SOLVE, (PetscErrorCodeFn *)MatSolve_HT_Cholesky));
394   PetscCall(MatShellSetOperation(F, MATOP_SOLVE_ADD, (PetscErrorCodeFn *)MatSolveAdd_HT_Cholesky));
395   PetscCall(MatShellSetOperation(F, MATOP_MAT_SOLVE, (PetscErrorCodeFn *)MatMatSolve_HT_Cholesky));
396   PetscFunctionReturn(PETSC_SUCCESS);
397 }
398 
MatCholeskyFactorSymbolic_HT(Mat F,Mat N,IS perm,const MatFactorInfo * info)399 static PetscErrorCode MatCholeskyFactorSymbolic_HT(Mat F, Mat N, IS perm, const MatFactorInfo *info)
400 {
401   Mat A, FA;
402 
403   PetscFunctionBegin;
404   PetscCall(MatShellGetContext(N, &A));
405   PetscCall(MatShellGetContext(F, &FA));
406   PetscCall(MatCholeskyFactorSymbolic(FA, A, perm, info));
407   PetscCall(MatShellSetOperation(F, MATOP_CHOLESKY_FACTOR_NUMERIC, (PetscErrorCodeFn *)MatCholeskyFactorNumeric_HT));
408   PetscFunctionReturn(PETSC_SUCCESS);
409 }
410 
MatGetFactor_HT(Mat N,MatSolverType type,MatFactorType ftype,Mat * F)411 static PetscErrorCode MatGetFactor_HT(Mat N, MatSolverType type, MatFactorType ftype, Mat *F)
412 {
413   Mat A, FA;
414 
415   PetscFunctionBegin;
416   PetscCall(MatShellGetContext(N, &A));
417   PetscCall(MatGetFactor(A, type, ftype, &FA));
418   PetscCall(MatCreateTranspose(FA, F));
419   if (ftype == MAT_FACTOR_LU) PetscCall(MatShellSetOperation(*F, MATOP_LUFACTOR_SYMBOLIC, (PetscErrorCodeFn *)MatLUFactorSymbolic_HT));
420   else if (ftype == MAT_FACTOR_CHOLESKY) {
421     PetscCheck(!PetscDefined(USE_COMPLEX) || A->hermitian == PETSC_BOOL3_TRUE, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Cholesky supported only if original matrix is Hermitian");
422     PetscCall(MatPropagateSymmetryOptions(A, FA));
423     PetscCall(MatShellSetOperation(*F, MATOP_CHOLESKY_FACTOR_SYMBOLIC, (PetscErrorCodeFn *)MatCholeskyFactorSymbolic_HT));
424   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Support for factor type %s not implemented in MATTRANSPOSEVIRTUAL", MatFactorTypes[ftype]);
425   (*F)->factortype = ftype;
426   PetscCall(MatDestroy(&FA));
427   PetscFunctionReturn(PETSC_SUCCESS);
428 }
429 
MatDestroy_HT(Mat N)430 static PetscErrorCode MatDestroy_HT(Mat N)
431 {
432   Mat A;
433 
434   PetscFunctionBegin;
435   PetscCall(MatShellGetContext(N, &A));
436   PetscCall(MatDestroy(&A));
437   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatHermitianTransposeGetMat_C", NULL));
438 #if !defined(PETSC_USE_COMPLEX)
439   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
440 #endif
441   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
442   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatShellSetContext_C", NULL));
443   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatFactorGetSolverType_C", NULL));
444   PetscFunctionReturn(PETSC_SUCCESS);
445 }
446 
MatGetInfo_HT(Mat N,MatInfoType flag,MatInfo * info)447 static PetscErrorCode MatGetInfo_HT(Mat N, MatInfoType flag, MatInfo *info)
448 {
449   Mat A;
450 
451   PetscFunctionBegin;
452   PetscCall(MatShellGetContext(N, &A));
453   PetscCall(MatGetInfo(A, flag, info));
454   PetscFunctionReturn(PETSC_SUCCESS);
455 }
456 
MatFactorGetSolverType_HT(Mat N,MatSolverType * type)457 static PetscErrorCode MatFactorGetSolverType_HT(Mat N, MatSolverType *type)
458 {
459   Mat A;
460 
461   PetscFunctionBegin;
462   PetscCall(MatShellGetContext(N, &A));
463   PetscCall(MatFactorGetSolverType(A, type));
464   PetscFunctionReturn(PETSC_SUCCESS);
465 }
466 
MatDuplicate_HT(Mat N,MatDuplicateOption op,Mat * m)467 static PetscErrorCode MatDuplicate_HT(Mat N, MatDuplicateOption op, Mat *m)
468 {
469   Mat A, C;
470 
471   PetscFunctionBegin;
472   PetscCall(MatShellGetContext(N, &A));
473   PetscCall(MatDuplicate(A, op, &C));
474   PetscCall(MatCreateHermitianTranspose(C, m));
475   if (op == MAT_COPY_VALUES) PetscCall(MatCopy(N, *m, SAME_NONZERO_PATTERN));
476   PetscCall(MatDestroy(&C));
477   PetscFunctionReturn(PETSC_SUCCESS);
478 }
479 
MatHasOperation_HT(Mat mat,MatOperation op,PetscBool * has)480 static PetscErrorCode MatHasOperation_HT(Mat mat, MatOperation op, PetscBool *has)
481 {
482   Mat A;
483 
484   PetscFunctionBegin;
485   PetscCall(MatShellGetContext(mat, &A));
486   *has = PETSC_FALSE;
487   if (op == MATOP_MULT || op == MATOP_MULT_ADD) {
488     PetscCall(MatHasOperation(A, MATOP_MULT_HERMITIAN_TRANSPOSE, has));
489     if (!*has) PetscCall(MatHasOperation(A, MATOP_MULT_TRANSPOSE, has));
490   } else if (op == MATOP_MULT_HERMITIAN_TRANSPOSE || op == MATOP_MULT_HERMITIAN_TRANS_ADD || op == MATOP_MULT_TRANSPOSE || op == MATOP_MULT_TRANSPOSE_ADD) {
491     PetscCall(MatHasOperation(A, MATOP_MULT, has));
492   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
493   PetscFunctionReturn(PETSC_SUCCESS);
494 }
495 
MatHermitianTransposeGetMat_HT(Mat N,Mat * M)496 static PetscErrorCode MatHermitianTransposeGetMat_HT(Mat N, Mat *M)
497 {
498   PetscFunctionBegin;
499   PetscCall(MatShellGetContext(N, M));
500   PetscFunctionReturn(PETSC_SUCCESS);
501 }
502 
503 /*@
504   MatHermitianTransposeGetMat - Gets the `Mat` object stored inside a `MATHERMITIANTRANSPOSEVIRTUAL`
505 
506   Logically Collective
507 
508   Input Parameter:
509 . A - the `MATHERMITIANTRANSPOSEVIRTUAL` matrix
510 
511   Output Parameter:
512 . M - the matrix object stored inside A
513 
514   Level: intermediate
515 
516 .seealso: [](ch_matrices), `Mat`, `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`
517 @*/
MatHermitianTransposeGetMat(Mat A,Mat * M)518 PetscErrorCode MatHermitianTransposeGetMat(Mat A, Mat *M)
519 {
520   PetscFunctionBegin;
521   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
522   PetscValidType(A, 1);
523   PetscAssertPointer(M, 2);
524   PetscUseMethod(A, "MatHermitianTransposeGetMat_C", (Mat, Mat *), (A, M));
525   PetscFunctionReturn(PETSC_SUCCESS);
526 }
527 
MatGetDiagonal_HT(Mat N,Vec v)528 static PetscErrorCode MatGetDiagonal_HT(Mat N, Vec v)
529 {
530   Mat A;
531 
532   PetscFunctionBegin;
533   PetscCall(MatShellGetContext(N, &A));
534   PetscCall(MatGetDiagonal(A, v));
535   PetscCall(VecConjugate(v));
536   PetscFunctionReturn(PETSC_SUCCESS);
537 }
538 
MatCopy_HT(Mat A,Mat B,MatStructure str)539 static PetscErrorCode MatCopy_HT(Mat A, Mat B, MatStructure str)
540 {
541   Mat a, b;
542 
543   PetscFunctionBegin;
544   PetscCall(MatShellGetContext(A, &a));
545   PetscCall(MatShellGetContext(B, &b));
546   PetscCall(MatCopy(a, b, str));
547   PetscFunctionReturn(PETSC_SUCCESS);
548 }
549 
MatConvert_HT(Mat N,MatType newtype,MatReuse reuse,Mat * newmat)550 static PetscErrorCode MatConvert_HT(Mat N, MatType newtype, MatReuse reuse, Mat *newmat)
551 {
552   Mat         A;
553   PetscScalar vscale = 1.0, vshift = 0.0;
554   PetscBool   flg;
555 
556   PetscFunctionBegin;
557   PetscCall(MatShellGetContext(N, &A));
558   PetscCall(MatHasOperation(A, MATOP_HERMITIAN_TRANSPOSE, &flg));
559   if (flg || N->ops->getrow) { /* if this condition is false, MatConvert_Shell() will be called in MatConvert_Basic(), so the following checks are not needed */
560     PetscCall(MatShellGetScalingShifts(N, &vshift, &vscale, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
561   }
562   if (flg) {
563     Mat B;
564 
565     PetscCall(MatHermitianTranspose(A, MAT_INITIAL_MATRIX, &B));
566     if (reuse != MAT_INPLACE_MATRIX) {
567       PetscCall(MatConvert(B, newtype, reuse, newmat));
568       PetscCall(MatDestroy(&B));
569     } else {
570       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
571       PetscCall(MatHeaderReplace(N, &B));
572     }
573   } else { /* use basic converter as fallback */
574     flg = (PetscBool)(N->ops->getrow != NULL);
575     PetscCall(MatConvert_Basic(N, newtype, reuse, newmat));
576   }
577   if (flg) {
578     PetscCall(MatScale(*newmat, vscale));
579     PetscCall(MatShift(*newmat, vshift));
580   }
581   PetscFunctionReturn(PETSC_SUCCESS);
582 }
583 
/*MC
   MATHERMITIANTRANSPOSEVIRTUAL - "hermitiantranspose" - A matrix type that represents a virtual Hermitian transpose of a matrix

  Level: advanced

  Developer Notes:
  This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code

  Users cannot call `MatShellSetOperation()` on this class; there is some error checking for that incorrect usage

.seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`
M*/
596 
/*@
  MatCreateHermitianTranspose - Creates a new matrix object of `MatType` `MATHERMITIANTRANSPOSEVIRTUAL` that behaves like A^H,
  the Hermitian (conjugate) transpose of A

  Collective

  Input Parameter:
. A - the (possibly rectangular) matrix

  Output Parameter:
. N - the matrix that represents A^H

  Level: intermediate

  Note:
  The Hermitian transpose A^H is NOT actually formed! Rather the new matrix
  object performs the matrix-vector product, `MatMult()`, by using the `MatMultHermitianTranspose()` on
  the original matrix

.seealso: [](ch_matrices), `Mat`, `MatCreateNormal()`, `MatMult()`, `MatMultHermitianTranspose()`, `MatCreate()`,
          `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`, `MatHermitianTransposeGetMat()`, `MATNORMAL`, `MATNORMALHERMITIAN`
@*/
MatCreateHermitianTranspose(Mat A,Mat * N)618 PetscErrorCode MatCreateHermitianTranspose(Mat A, Mat *N)
619 {
620   VecType vtype;
621 
622   PetscFunctionBegin;
623   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
624   PetscCall(PetscLayoutReference(A->rmap, &((*N)->cmap)));
625   PetscCall(PetscLayoutReference(A->cmap, &((*N)->rmap)));
626   PetscCall(MatSetType(*N, MATSHELL));
627   PetscCall(MatShellSetContext(*N, A));
628   PetscCall(PetscObjectReference((PetscObject)A));
629 
630   PetscCall(MatSetBlockSizes(*N, A->cmap->bs, A->rmap->bs));
631   PetscCall(MatGetVecType(A, &vtype));
632   PetscCall(MatSetVecType(*N, vtype));
633 #if defined(PETSC_HAVE_DEVICE)
634   PetscCall(MatBindToCPU(*N, A->boundtocpu));
635 #endif
636   PetscCall(MatSetUp(*N));
637 
638   PetscCall(MatShellSetOperation(*N, MATOP_DESTROY, (PetscErrorCodeFn *)MatDestroy_HT));
639   PetscCall(MatShellSetOperation(*N, MATOP_MULT, (PetscErrorCodeFn *)MatMult_HT));
640   PetscCall(MatShellSetOperation(*N, MATOP_MULT_HERMITIAN_TRANSPOSE, (PetscErrorCodeFn *)MatMultHermitianTranspose_HT));
641 #if !defined(PETSC_USE_COMPLEX)
642   PetscCall(MatShellSetOperation(*N, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultHermitianTranspose_HT));
643 #endif
644   PetscCall(MatShellSetOperation(*N, MATOP_LUFACTOR, (PetscErrorCodeFn *)MatLUFactor_HT));
645   PetscCall(MatShellSetOperation(*N, MATOP_CHOLESKYFACTOR, (PetscErrorCodeFn *)MatCholeskyFactor_HT));
646   PetscCall(MatShellSetOperation(*N, MATOP_GET_FACTOR, (PetscErrorCodeFn *)MatGetFactor_HT));
647   PetscCall(MatShellSetOperation(*N, MATOP_GETINFO, (PetscErrorCodeFn *)MatGetInfo_HT));
648   PetscCall(MatShellSetOperation(*N, MATOP_DUPLICATE, (PetscErrorCodeFn *)MatDuplicate_HT));
649   PetscCall(MatShellSetOperation(*N, MATOP_HAS_OPERATION, (PetscErrorCodeFn *)MatHasOperation_HT));
650   PetscCall(MatShellSetOperation(*N, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_HT));
651   PetscCall(MatShellSetOperation(*N, MATOP_COPY, (PetscErrorCodeFn *)MatCopy_HT));
652   PetscCall(MatShellSetOperation(*N, MATOP_CONVERT, (PetscErrorCodeFn *)MatConvert_HT));
653 
654   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatHermitianTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
655 #if !defined(PETSC_USE_COMPLEX)
656   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
657 #endif
658   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_HT));
659   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatFactorGetSolverType_C", MatFactorGetSolverType_HT));
660   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContext_C", MatShellSetContext_Immutable));
661   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
662   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
663   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATHERMITIANTRANSPOSEVIRTUAL));
664   PetscFunctionReturn(PETSC_SUCCESS);
665 }
666