xref: /petsc/src/ksp/pc/impls/h2opus/pch2opus.c (revision d2522c19e8fa9bca20aaca277941d9a63e71db6a)
1 #include <petsc/private/pcimpl.h>
2 #include <petsc/private/matimpl.h>
3 #include <h2opusconf.h>
4 
5 /* Use GPU only if H2OPUS is configured for GPU */
6 #if defined(PETSC_HAVE_CUDA) && defined(H2OPUS_USE_GPU)
7 #define PETSC_H2OPUS_USE_GPU
8 #endif
9 
/* Private data of the PCH2OPUS preconditioner: an H2OPUS approximate inverse of the
   operator, computed in PCSetUp_H2OPUS by Newton-Schultz / hyperpower iterations */
typedef struct {
  Mat         A;  /* MATH2OPUS operator used during setup: pc->mat/pc->pmat itself, or an H2OPUS approximation of it */
  Mat         M;  /* current approximate inverse of A; applied by PCApply_H2OPUS */
  PetscScalar s0; /* initial scaling 1 / ||A A^T||_1 (see PCH2OpusSetUpInit) */

  /* sampler for Newton-Schultz */
  Mat      S;          /* shell matrix sampled to build the next iterate of M */
  PetscInt hyperorder; /* >= 2 selects the hyperpower sampler, otherwise basic Newton-Schultz */
  Vec      wns[4];     /* work vectors for the samplers */
  Mat      wnsmat[4];  /* work matrices for the MatMat versions of the samplers */

  /* convergence testing */
  Mat T; /* shell matrix applying M*A - I */
  Vec w; /* work vector for T */

  /* Support for PCSetCoordinates */
  PetscInt   sdim;   /* spatial dimension of the coordinates (0 when none were given) */
  PetscInt   nlocc;  /* number of local points */
  PetscReal *coords; /* sdim * nlocc coordinate values */

  /* Newton-Schultz customization */
  PetscInt  maxits;              /* maximum number of iterations */
  PetscReal rtol, atol;          /* relative / absolute tolerances on ||M*A - I|| */
  PetscBool monitor;             /* print the error at each iteration */
  PetscBool useapproximatenorms; /* NOTE(review): not referenced by the code in this file */
  NormType  normtype;            /* norm used for convergence testing */

  /* Used when pmat != MATH2OPUS */
  PetscReal eta;      /* admissibility parameter passed to MatCreateH2OpusFromMat */
  PetscInt  leafsize; /* leaf size passed to MatCreateH2OpusFromMat */
  PetscInt  max_rank; /* maximum rank passed to MatCreateH2OpusFromMat */
  PetscInt  bs;       /* number of concurrent samples passed to MatCreateH2OpusFromMat */
  PetscReal mrtol;    /* sampling tolerance passed to MatCreateH2OpusFromMat */

  /* CPU/GPU */
  PetscBool forcecpu;   /* force construction on the CPU */
  PetscBool boundtocpu; /* whether work objects are bound to the CPU */
} PC_H2OPUS;
48 
/* Defined elsewhere; installed below as MATOP_NORM of the shell matrices used for norm estimates */
PETSC_EXTERN PetscErrorCode MatNorm_H2OPUS(Mat, NormType, PetscReal *);
50 
51 static PetscErrorCode PCReset_H2OPUS(PC pc) {
52   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
53 
54   PetscFunctionBegin;
55   pch2opus->sdim  = 0;
56   pch2opus->nlocc = 0;
57   PetscCall(PetscFree(pch2opus->coords));
58   PetscCall(MatDestroy(&pch2opus->A));
59   PetscCall(MatDestroy(&pch2opus->M));
60   PetscCall(MatDestroy(&pch2opus->T));
61   PetscCall(VecDestroy(&pch2opus->w));
62   PetscCall(MatDestroy(&pch2opus->S));
63   PetscCall(VecDestroy(&pch2opus->wns[0]));
64   PetscCall(VecDestroy(&pch2opus->wns[1]));
65   PetscCall(VecDestroy(&pch2opus->wns[2]));
66   PetscCall(VecDestroy(&pch2opus->wns[3]));
67   PetscCall(MatDestroy(&pch2opus->wnsmat[0]));
68   PetscCall(MatDestroy(&pch2opus->wnsmat[1]));
69   PetscCall(MatDestroy(&pch2opus->wnsmat[2]));
70   PetscCall(MatDestroy(&pch2opus->wnsmat[3]));
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode PCSetCoordinates_H2OPUS(PC pc, PetscInt sdim, PetscInt nlocc, PetscReal *coords) {
75   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
76   PetscBool  reset    = PETSC_TRUE;
77 
78   PetscFunctionBegin;
79   if (pch2opus->sdim && sdim == pch2opus->sdim && nlocc == pch2opus->nlocc) {
80     PetscCall(PetscArraycmp(pch2opus->coords, coords, sdim * nlocc, &reset));
81     reset = (PetscBool)!reset;
82   }
83   PetscCall(MPIU_Allreduce(MPI_IN_PLACE, &reset, 1, MPIU_BOOL, MPI_LOR, PetscObjectComm((PetscObject)pc)));
84   if (reset) {
85     PetscCall(PCReset_H2OPUS(pc));
86     PetscCall(PetscMalloc1(sdim * nlocc, &pch2opus->coords));
87     PetscCall(PetscArraycpy(pch2opus->coords, coords, sdim * nlocc));
88     pch2opus->sdim  = sdim;
89     pch2opus->nlocc = nlocc;
90   }
91   PetscFunctionReturn(0);
92 }
93 
94 static PetscErrorCode PCDestroy_H2OPUS(PC pc) {
95   PetscFunctionBegin;
96   PetscCall(PCReset_H2OPUS(pc));
97   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCSetCoordinates_C", NULL));
98   PetscCall(PetscFree(pc->data));
99   PetscFunctionReturn(0);
100 }
101 
102 static PetscErrorCode PCSetFromOptions_H2OPUS(PC pc, PetscOptionItems *PetscOptionsObject) {
103   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
104 
105   PetscFunctionBegin;
106   PetscOptionsHeadBegin(PetscOptionsObject, "H2OPUS options");
107   PetscCall(PetscOptionsInt("-pc_h2opus_maxits", "Maximum number of iterations for Newton-Schultz", NULL, pch2opus->maxits, &pch2opus->maxits, NULL));
108   PetscCall(PetscOptionsBool("-pc_h2opus_monitor", "Monitor Newton-Schultz convergence", NULL, pch2opus->monitor, &pch2opus->monitor, NULL));
109   PetscCall(PetscOptionsReal("-pc_h2opus_atol", "Absolute tolerance", NULL, pch2opus->atol, &pch2opus->atol, NULL));
110   PetscCall(PetscOptionsReal("-pc_h2opus_rtol", "Relative tolerance", NULL, pch2opus->rtol, &pch2opus->rtol, NULL));
111   PetscCall(PetscOptionsEnum("-pc_h2opus_norm_type", "Norm type for convergence monitoring", NULL, NormTypes, (PetscEnum)pch2opus->normtype, (PetscEnum *)&pch2opus->normtype, NULL));
112   PetscCall(PetscOptionsInt("-pc_h2opus_hyperorder", "Hyper power order of sampling", NULL, pch2opus->hyperorder, &pch2opus->hyperorder, NULL));
113   PetscCall(PetscOptionsInt("-pc_h2opus_leafsize", "Leaf size when constructed from kernel", NULL, pch2opus->leafsize, &pch2opus->leafsize, NULL));
114   PetscCall(PetscOptionsReal("-pc_h2opus_eta", "Admissibility condition tolerance", NULL, pch2opus->eta, &pch2opus->eta, NULL));
115   PetscCall(PetscOptionsInt("-pc_h2opus_maxrank", "Maximum rank when constructed from matvecs", NULL, pch2opus->max_rank, &pch2opus->max_rank, NULL));
116   PetscCall(PetscOptionsInt("-pc_h2opus_samples", "Number of samples to be taken concurrently when constructing from matvecs", NULL, pch2opus->bs, &pch2opus->bs, NULL));
117   PetscCall(PetscOptionsReal("-pc_h2opus_mrtol", "Relative tolerance for construction from sampling", NULL, pch2opus->mrtol, &pch2opus->mrtol, NULL));
118   PetscCall(PetscOptionsBool("-pc_h2opus_forcecpu", "Force construction on CPU", NULL, pch2opus->forcecpu, &pch2opus->forcecpu, NULL));
119   PetscOptionsHeadEnd();
120   PetscFunctionReturn(0);
121 }
122 
/* Context of the shell matrix applying A * A^T (see MatMult_AAt and PCH2OpusSetUpInit) */
typedef struct {
  Mat A; /* operator whose product A * A^T is applied */
  Mat M; /* stored by PCH2OpusSetUpInit but not referenced by MatMult_AAt */
  Vec w; /* work vector holding A^T x */
} AAtCtx;
128 
129 static PetscErrorCode MatMult_AAt(Mat A, Vec x, Vec y) {
130   AAtCtx *aat;
131 
132   PetscFunctionBegin;
133   PetscCall(MatShellGetContext(A, (void *)&aat));
134   /* PetscCall(MatMultTranspose(aat->M,x,aat->w)); */
135   PetscCall(MatMultTranspose(aat->A, x, aat->w));
136   PetscCall(MatMult(aat->A, aat->w, y));
137   PetscFunctionReturn(0);
138 }
139 
140 static PetscErrorCode PCH2OpusSetUpInit(PC pc) {
141   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
142   Mat        A        = pc->useAmat ? pc->mat : pc->pmat, AAt;
143   PetscInt   M, m;
144   VecType    vtype;
145   PetscReal  n;
146   AAtCtx     aat;
147 
148   PetscFunctionBegin;
149   aat.A = A;
150   aat.M = pch2opus->M; /* unused so far */
151   PetscCall(MatCreateVecs(A, NULL, &aat.w));
152   PetscCall(MatGetSize(A, &M, NULL));
153   PetscCall(MatGetLocalSize(A, &m, NULL));
154   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), m, m, M, M, &aat, &AAt));
155   PetscCall(MatBindToCPU(AAt, pch2opus->boundtocpu));
156   PetscCall(MatShellSetOperation(AAt, MATOP_MULT, (void (*)(void))MatMult_AAt));
157   PetscCall(MatShellSetOperation(AAt, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMult_AAt));
158   PetscCall(MatShellSetOperation(AAt, MATOP_NORM, (void (*)(void))MatNorm_H2OPUS));
159   PetscCall(MatGetVecType(A, &vtype));
160   PetscCall(MatShellSetVecType(AAt, vtype));
161   PetscCall(MatNorm(AAt, NORM_1, &n));
162   pch2opus->s0 = 1. / n;
163   PetscCall(VecDestroy(&aat.w));
164   PetscCall(MatDestroy(&AAt));
165   PetscFunctionReturn(0);
166 }
167 
168 static PetscErrorCode PCApplyKernel_H2OPUS(PC pc, Vec x, Vec y, PetscBool t) {
169   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
170 
171   PetscFunctionBegin;
172   if (t) PetscCall(MatMultTranspose(pch2opus->M, x, y));
173   else PetscCall(MatMult(pch2opus->M, x, y));
174   PetscFunctionReturn(0);
175 }
176 
177 static PetscErrorCode PCApplyMatKernel_H2OPUS(PC pc, Mat X, Mat Y, PetscBool t) {
178   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
179 
180   PetscFunctionBegin;
181   if (t) PetscCall(MatTransposeMatMult(pch2opus->M, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y));
182   else PetscCall(MatMatMult(pch2opus->M, X, MAT_REUSE_MATRIX, PETSC_DEFAULT, &Y));
183   PetscFunctionReturn(0);
184 }
185 
186 static PetscErrorCode PCApplyMat_H2OPUS(PC pc, Mat X, Mat Y) {
187   PetscFunctionBegin;
188   PetscCall(PCApplyMatKernel_H2OPUS(pc, X, Y, PETSC_FALSE));
189   PetscFunctionReturn(0);
190 }
191 
192 static PetscErrorCode PCApplyTransposeMat_H2OPUS(PC pc, Mat X, Mat Y) {
193   PetscFunctionBegin;
194   PetscCall(PCApplyMatKernel_H2OPUS(pc, X, Y, PETSC_TRUE));
195   PetscFunctionReturn(0);
196 }
197 
198 static PetscErrorCode PCApply_H2OPUS(PC pc, Vec x, Vec y) {
199   PetscFunctionBegin;
200   PetscCall(PCApplyKernel_H2OPUS(pc, x, y, PETSC_FALSE));
201   PetscFunctionReturn(0);
202 }
203 
204 static PetscErrorCode PCApplyTranspose_H2OPUS(PC pc, Vec x, Vec y) {
205   PetscFunctionBegin;
206   PetscCall(PCApplyKernel_H2OPUS(pc, x, y, PETSC_TRUE));
207   PetscFunctionReturn(0);
208 }
209 
210 /* used to test the norm of (M^-1 A - I) */
211 static PetscErrorCode MatMultKernel_MAmI(Mat M, Vec x, Vec y, PetscBool t) {
212   PC         pc;
213   Mat        A;
214   PC_H2OPUS *pch2opus;
215   PetscBool  sideleft = PETSC_TRUE;
216 
217   PetscFunctionBegin;
218   PetscCall(MatShellGetContext(M, (void *)&pc));
219   pch2opus = (PC_H2OPUS *)pc->data;
220   if (!pch2opus->w) PetscCall(MatCreateVecs(pch2opus->M, &pch2opus->w, NULL));
221   A = pch2opus->A;
222   PetscCall(VecBindToCPU(pch2opus->w, pch2opus->boundtocpu));
223   if (t) {
224     if (sideleft) {
225       PetscCall(PCApplyTranspose_H2OPUS(pc, x, pch2opus->w));
226       PetscCall(MatMultTranspose(A, pch2opus->w, y));
227     } else {
228       PetscCall(MatMultTranspose(A, x, pch2opus->w));
229       PetscCall(PCApplyTranspose_H2OPUS(pc, pch2opus->w, y));
230     }
231   } else {
232     if (sideleft) {
233       PetscCall(MatMult(A, x, pch2opus->w));
234       PetscCall(PCApply_H2OPUS(pc, pch2opus->w, y));
235     } else {
236       PetscCall(PCApply_H2OPUS(pc, x, pch2opus->w));
237       PetscCall(MatMult(A, pch2opus->w, y));
238     }
239   }
240   PetscCall(VecAXPY(y, -1.0, x));
241   PetscFunctionReturn(0);
242 }
243 
244 static PetscErrorCode MatMult_MAmI(Mat A, Vec x, Vec y) {
245   PetscFunctionBegin;
246   PetscCall(MatMultKernel_MAmI(A, x, y, PETSC_FALSE));
247   PetscFunctionReturn(0);
248 }
249 
250 static PetscErrorCode MatMultTranspose_MAmI(Mat A, Vec x, Vec y) {
251   PetscFunctionBegin;
252   PetscCall(MatMultKernel_MAmI(A, x, y, PETSC_TRUE));
253   PetscFunctionReturn(0);
254 }
255 
256 /* HyperPower kernel:
257 Y = R = x
258 for i = 1 . . . l - 1 do
259   R = (I - A * Xk) * R
260   Y = Y + R
261 Y = Xk * Y
262 */
263 static PetscErrorCode MatMultKernel_Hyper(Mat M, Vec x, Vec y, PetscBool t) {
264   PC         pc;
265   Mat        A;
266   PC_H2OPUS *pch2opus;
267 
268   PetscFunctionBegin;
269   PetscCall(MatShellGetContext(M, (void *)&pc));
270   pch2opus = (PC_H2OPUS *)pc->data;
271   A        = pch2opus->A;
272   PetscCall(MatCreateVecs(pch2opus->M, pch2opus->wns[0] ? NULL : &pch2opus->wns[0], pch2opus->wns[1] ? NULL : &pch2opus->wns[1]));
273   PetscCall(MatCreateVecs(pch2opus->M, pch2opus->wns[2] ? NULL : &pch2opus->wns[2], pch2opus->wns[3] ? NULL : &pch2opus->wns[3]));
274   PetscCall(VecBindToCPU(pch2opus->wns[0], pch2opus->boundtocpu));
275   PetscCall(VecBindToCPU(pch2opus->wns[1], pch2opus->boundtocpu));
276   PetscCall(VecBindToCPU(pch2opus->wns[2], pch2opus->boundtocpu));
277   PetscCall(VecBindToCPU(pch2opus->wns[3], pch2opus->boundtocpu));
278   PetscCall(VecCopy(x, pch2opus->wns[0]));
279   PetscCall(VecCopy(x, pch2opus->wns[3]));
280   if (t) {
281     for (PetscInt i = 0; i < pch2opus->hyperorder - 1; i++) {
282       PetscCall(MatMultTranspose(A, pch2opus->wns[0], pch2opus->wns[1]));
283       PetscCall(PCApplyTranspose_H2OPUS(pc, pch2opus->wns[1], pch2opus->wns[2]));
284       PetscCall(VecAXPY(pch2opus->wns[0], -1., pch2opus->wns[2]));
285       PetscCall(VecAXPY(pch2opus->wns[3], 1., pch2opus->wns[0]));
286     }
287     PetscCall(PCApplyTranspose_H2OPUS(pc, pch2opus->wns[3], y));
288   } else {
289     for (PetscInt i = 0; i < pch2opus->hyperorder - 1; i++) {
290       PetscCall(PCApply_H2OPUS(pc, pch2opus->wns[0], pch2opus->wns[1]));
291       PetscCall(MatMult(A, pch2opus->wns[1], pch2opus->wns[2]));
292       PetscCall(VecAXPY(pch2opus->wns[0], -1., pch2opus->wns[2]));
293       PetscCall(VecAXPY(pch2opus->wns[3], 1., pch2opus->wns[0]));
294     }
295     PetscCall(PCApply_H2OPUS(pc, pch2opus->wns[3], y));
296   }
297   PetscFunctionReturn(0);
298 }
299 
300 static PetscErrorCode MatMult_Hyper(Mat M, Vec x, Vec y) {
301   PetscFunctionBegin;
302   PetscCall(MatMultKernel_Hyper(M, x, y, PETSC_FALSE));
303   PetscFunctionReturn(0);
304 }
305 
306 static PetscErrorCode MatMultTranspose_Hyper(Mat M, Vec x, Vec y) {
307   PetscFunctionBegin;
308   PetscCall(MatMultKernel_Hyper(M, x, y, PETSC_TRUE));
309   PetscFunctionReturn(0);
310 }
311 
312 /* Hyper power kernel, MatMat version */
313 static PetscErrorCode MatMatMultKernel_Hyper(Mat M, Mat X, Mat Y, PetscBool t) {
314   PC         pc;
315   Mat        A;
316   PC_H2OPUS *pch2opus;
317   PetscInt   i;
318 
319   PetscFunctionBegin;
320   PetscCall(MatShellGetContext(M, (void *)&pc));
321   pch2opus = (PC_H2OPUS *)pc->data;
322   A        = pch2opus->A;
323   if (pch2opus->wnsmat[0] && pch2opus->wnsmat[0]->cmap->N != X->cmap->N) {
324     PetscCall(MatDestroy(&pch2opus->wnsmat[0]));
325     PetscCall(MatDestroy(&pch2opus->wnsmat[1]));
326   }
327   if (!pch2opus->wnsmat[0]) {
328     PetscCall(MatDuplicate(X, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[0]));
329     PetscCall(MatDuplicate(Y, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[1]));
330   }
331   if (pch2opus->wnsmat[2] && pch2opus->wnsmat[2]->cmap->N != X->cmap->N) {
332     PetscCall(MatDestroy(&pch2opus->wnsmat[2]));
333     PetscCall(MatDestroy(&pch2opus->wnsmat[3]));
334   }
335   if (!pch2opus->wnsmat[2]) {
336     PetscCall(MatDuplicate(X, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[2]));
337     PetscCall(MatDuplicate(Y, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[3]));
338   }
339   PetscCall(MatCopy(X, pch2opus->wnsmat[0], SAME_NONZERO_PATTERN));
340   PetscCall(MatCopy(X, pch2opus->wnsmat[3], SAME_NONZERO_PATTERN));
341   if (t) {
342     for (i = 0; i < pch2opus->hyperorder - 1; i++) {
343       PetscCall(MatTransposeMatMult(A, pch2opus->wnsmat[0], MAT_REUSE_MATRIX, PETSC_DEFAULT, &pch2opus->wnsmat[1]));
344       PetscCall(PCApplyTransposeMat_H2OPUS(pc, pch2opus->wnsmat[1], pch2opus->wnsmat[2]));
345       PetscCall(MatAXPY(pch2opus->wnsmat[0], -1., pch2opus->wnsmat[2], SAME_NONZERO_PATTERN));
346       PetscCall(MatAXPY(pch2opus->wnsmat[3], 1., pch2opus->wnsmat[0], SAME_NONZERO_PATTERN));
347     }
348     PetscCall(PCApplyTransposeMat_H2OPUS(pc, pch2opus->wnsmat[3], Y));
349   } else {
350     for (i = 0; i < pch2opus->hyperorder - 1; i++) {
351       PetscCall(PCApplyMat_H2OPUS(pc, pch2opus->wnsmat[0], pch2opus->wnsmat[1]));
352       PetscCall(MatMatMult(A, pch2opus->wnsmat[1], MAT_REUSE_MATRIX, PETSC_DEFAULT, &pch2opus->wnsmat[2]));
353       PetscCall(MatAXPY(pch2opus->wnsmat[0], -1., pch2opus->wnsmat[2], SAME_NONZERO_PATTERN));
354       PetscCall(MatAXPY(pch2opus->wnsmat[3], 1., pch2opus->wnsmat[0], SAME_NONZERO_PATTERN));
355     }
356     PetscCall(PCApplyMat_H2OPUS(pc, pch2opus->wnsmat[3], Y));
357   }
358   PetscFunctionReturn(0);
359 }
360 
361 static PetscErrorCode MatMatMultNumeric_Hyper(Mat M, Mat X, Mat Y, void *ctx) {
362   PetscFunctionBegin;
363   PetscCall(MatMatMultKernel_Hyper(M, X, Y, PETSC_FALSE));
364   PetscFunctionReturn(0);
365 }
366 
367 /* Basic Newton-Schultz sampler: (2 * I - M * A)*M */
368 static PetscErrorCode MatMultKernel_NS(Mat M, Vec x, Vec y, PetscBool t) {
369   PC         pc;
370   Mat        A;
371   PC_H2OPUS *pch2opus;
372 
373   PetscFunctionBegin;
374   PetscCall(MatShellGetContext(M, (void *)&pc));
375   pch2opus = (PC_H2OPUS *)pc->data;
376   A        = pch2opus->A;
377   PetscCall(MatCreateVecs(pch2opus->M, pch2opus->wns[0] ? NULL : &pch2opus->wns[0], pch2opus->wns[1] ? NULL : &pch2opus->wns[1]));
378   PetscCall(VecBindToCPU(pch2opus->wns[0], pch2opus->boundtocpu));
379   PetscCall(VecBindToCPU(pch2opus->wns[1], pch2opus->boundtocpu));
380   if (t) {
381     PetscCall(PCApplyTranspose_H2OPUS(pc, x, y));
382     PetscCall(MatMultTranspose(A, y, pch2opus->wns[1]));
383     PetscCall(PCApplyTranspose_H2OPUS(pc, pch2opus->wns[1], pch2opus->wns[0]));
384     PetscCall(VecAXPBY(y, -1., 2., pch2opus->wns[0]));
385   } else {
386     PetscCall(PCApply_H2OPUS(pc, x, y));
387     PetscCall(MatMult(A, y, pch2opus->wns[0]));
388     PetscCall(PCApply_H2OPUS(pc, pch2opus->wns[0], pch2opus->wns[1]));
389     PetscCall(VecAXPBY(y, -1., 2., pch2opus->wns[1]));
390   }
391   PetscFunctionReturn(0);
392 }
393 
394 static PetscErrorCode MatMult_NS(Mat M, Vec x, Vec y) {
395   PetscFunctionBegin;
396   PetscCall(MatMultKernel_NS(M, x, y, PETSC_FALSE));
397   PetscFunctionReturn(0);
398 }
399 
400 static PetscErrorCode MatMultTranspose_NS(Mat M, Vec x, Vec y) {
401   PetscFunctionBegin;
402   PetscCall(MatMultKernel_NS(M, x, y, PETSC_TRUE));
403   PetscFunctionReturn(0);
404 }
405 
406 /* Basic Newton-Schultz sampler: (2 * I - M * A)*M, MatMat version */
407 static PetscErrorCode MatMatMultKernel_NS(Mat M, Mat X, Mat Y, PetscBool t) {
408   PC         pc;
409   Mat        A;
410   PC_H2OPUS *pch2opus;
411 
412   PetscFunctionBegin;
413   PetscCall(MatShellGetContext(M, (void *)&pc));
414   pch2opus = (PC_H2OPUS *)pc->data;
415   A        = pch2opus->A;
416   if (pch2opus->wnsmat[0] && pch2opus->wnsmat[0]->cmap->N != X->cmap->N) {
417     PetscCall(MatDestroy(&pch2opus->wnsmat[0]));
418     PetscCall(MatDestroy(&pch2opus->wnsmat[1]));
419   }
420   if (!pch2opus->wnsmat[0]) {
421     PetscCall(MatDuplicate(X, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[0]));
422     PetscCall(MatDuplicate(Y, MAT_SHARE_NONZERO_PATTERN, &pch2opus->wnsmat[1]));
423   }
424   if (t) {
425     PetscCall(PCApplyTransposeMat_H2OPUS(pc, X, Y));
426     PetscCall(MatTransposeMatMult(A, Y, MAT_REUSE_MATRIX, PETSC_DEFAULT, &pch2opus->wnsmat[1]));
427     PetscCall(PCApplyTransposeMat_H2OPUS(pc, pch2opus->wnsmat[1], pch2opus->wnsmat[0]));
428     PetscCall(MatScale(Y, 2.));
429     PetscCall(MatAXPY(Y, -1., pch2opus->wnsmat[0], SAME_NONZERO_PATTERN));
430   } else {
431     PetscCall(PCApplyMat_H2OPUS(pc, X, Y));
432     PetscCall(MatMatMult(A, Y, MAT_REUSE_MATRIX, PETSC_DEFAULT, &pch2opus->wnsmat[0]));
433     PetscCall(PCApplyMat_H2OPUS(pc, pch2opus->wnsmat[0], pch2opus->wnsmat[1]));
434     PetscCall(MatScale(Y, 2.));
435     PetscCall(MatAXPY(Y, -1., pch2opus->wnsmat[1], SAME_NONZERO_PATTERN));
436   }
437   PetscFunctionReturn(0);
438 }
439 
440 static PetscErrorCode MatMatMultNumeric_NS(Mat M, Mat X, Mat Y, void *ctx) {
441   PetscFunctionBegin;
442   PetscCall(MatMatMultKernel_NS(M, X, Y, PETSC_FALSE));
443   PetscFunctionReturn(0);
444 }
445 
446 static PetscErrorCode PCH2OpusSetUpSampler_Private(PC pc) {
447   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
448   Mat        A        = pc->useAmat ? pc->mat : pc->pmat;
449 
450   PetscFunctionBegin;
451   if (!pch2opus->S) {
452     PetscInt M, N, m, n;
453 
454     PetscCall(MatGetSize(A, &M, &N));
455     PetscCall(MatGetLocalSize(A, &m, &n));
456     PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), m, n, M, N, pc, &pch2opus->S));
457     PetscCall(MatSetBlockSizesFromMats(pch2opus->S, A, A));
458 #if defined(PETSC_H2OPUS_USE_GPU)
459     PetscCall(MatShellSetVecType(pch2opus->S, VECCUDA));
460 #endif
461   }
462   if (pch2opus->hyperorder >= 2) {
463     PetscCall(MatShellSetOperation(pch2opus->S, MATOP_MULT, (void (*)(void))MatMult_Hyper));
464     PetscCall(MatShellSetOperation(pch2opus->S, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Hyper));
465     PetscCall(MatShellSetMatProductOperation(pch2opus->S, MATPRODUCT_AB, NULL, MatMatMultNumeric_Hyper, NULL, MATDENSE, MATDENSE));
466     PetscCall(MatShellSetMatProductOperation(pch2opus->S, MATPRODUCT_AB, NULL, MatMatMultNumeric_Hyper, NULL, MATDENSECUDA, MATDENSECUDA));
467   } else {
468     PetscCall(MatShellSetOperation(pch2opus->S, MATOP_MULT, (void (*)(void))MatMult_NS));
469     PetscCall(MatShellSetOperation(pch2opus->S, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_NS));
470     PetscCall(MatShellSetMatProductOperation(pch2opus->S, MATPRODUCT_AB, NULL, MatMatMultNumeric_NS, NULL, MATDENSE, MATDENSE));
471     PetscCall(MatShellSetMatProductOperation(pch2opus->S, MATPRODUCT_AB, NULL, MatMatMultNumeric_NS, NULL, MATDENSECUDA, MATDENSECUDA));
472   }
473   PetscCall(MatPropagateSymmetryOptions(A, pch2opus->S));
474   PetscCall(MatBindToCPU(pch2opus->S, pch2opus->boundtocpu));
475   /* XXX */
476   PetscCall(MatSetOption(pch2opus->S, MAT_SYMMETRIC, PETSC_TRUE));
477   PetscFunctionReturn(0);
478 }
479 
/* Set up the preconditioner:
   1. obtain an H2OPUS operator A (a reference to the user matrix when it is MATH2OPUS,
      otherwise an approximation built by MatCreateH2OpusFromMat);
   2. choose the initial approximate inverse M: the previous M if ||M*A - I|| < 1 in the
      selected norm, otherwise s0 * A;
   3. refine M by sampling the Newton-Schultz / hyperpower operator S until ||M*A - I||
      drops below atol or rtol * (initial error), maxits is reached, or the error is Inf/NaN. */
static PetscErrorCode PCSetUp_H2OPUS(PC pc) {
  PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
  Mat        A        = pc->useAmat ? pc->mat : pc->pmat;
  NormType   norm     = pch2opus->normtype;
  PetscReal  initerr  = 0.0, err;
  PetscBool  ish2opus;

  PetscFunctionBegin;
  /* shell T = M*A - I, used for convergence testing; created once */
  if (!pch2opus->T) {
    PetscInt    M, N, m, n;
    const char *prefix;

    PetscCall(PCGetOptionsPrefix(pc, &prefix));
    PetscCall(MatGetSize(A, &M, &N));
    PetscCall(MatGetLocalSize(A, &m, &n));
    PetscCall(MatCreateShell(PetscObjectComm((PetscObject)pc->pmat), m, n, M, N, pc, &pch2opus->T));
    PetscCall(MatSetBlockSizesFromMats(pch2opus->T, A, A));
    PetscCall(MatShellSetOperation(pch2opus->T, MATOP_MULT, (void (*)(void))MatMult_MAmI));
    PetscCall(MatShellSetOperation(pch2opus->T, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_MAmI));
    PetscCall(MatShellSetOperation(pch2opus->T, MATOP_NORM, (void (*)(void))MatNorm_H2OPUS));
#if defined(PETSC_H2OPUS_USE_GPU)
    PetscCall(MatShellSetVecType(pch2opus->T, VECCUDA));
#endif
    PetscCall(PetscLogObjectParent((PetscObject)pc, (PetscObject)pch2opus->T));
    PetscCall(MatSetOptionsPrefix(pch2opus->T, prefix));
    PetscCall(MatAppendOptionsPrefix(pch2opus->T, "pc_h2opus_est_"));
  }
  /* (re)build the H2OPUS operator A */
  PetscCall(MatDestroy(&pch2opus->A));
  PetscCall(PetscObjectTypeCompare((PetscObject)A, MATH2OPUS, &ish2opus));
  if (ish2opus) {
    PetscCall(PetscObjectReference((PetscObject)A));
    pch2opus->A = A;
  } else {
    const char *prefix;
    PetscCall(MatCreateH2OpusFromMat(A, pch2opus->sdim, pch2opus->coords, PETSC_FALSE, pch2opus->eta, pch2opus->leafsize, pch2opus->max_rank, pch2opus->bs, pch2opus->mrtol, &pch2opus->A));
    /* XXX: the approximation is flagged symmetric regardless of the input matrix (NOTE(review): confirm for nonsymmetric problems) */
    PetscCall(MatSetOption(pch2opus->A, MAT_SYMMETRIC, PETSC_TRUE));
    PetscCall(PCGetOptionsPrefix(pc, &prefix));
    PetscCall(MatSetOptionsPrefix(pch2opus->A, prefix));
    PetscCall(MatAppendOptionsPrefix(pch2opus->A, "pc_h2opus_init_"));
    PetscCall(MatSetFromOptions(pch2opus->A));
    PetscCall(MatAssemblyBegin(pch2opus->A, MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(pch2opus->A, MAT_FINAL_ASSEMBLY));
    /* XXX: symmetric flag set again after assembly */
    PetscCall(MatSetOption(pch2opus->A, MAT_SYMMETRIC, PETSC_TRUE));

    /* always perform construction on the GPU unless forcecpu is true */
    PetscCall(MatBindToCPU(pch2opus->A, pch2opus->forcecpu));
  }
#if defined(PETSC_H2OPUS_USE_GPU)
  pch2opus->boundtocpu = pch2opus->forcecpu ? PETSC_TRUE : pch2opus->A->boundtocpu;
#endif
  PetscCall(MatBindToCPU(pch2opus->T, pch2opus->boundtocpu));
  if (pch2opus->M) { /* see if we can reuse M as initial guess */
    PetscReal reuse;

    PetscCall(MatBindToCPU(pch2opus->M, pch2opus->boundtocpu));
    PetscCall(MatNorm(pch2opus->T, norm, &reuse));
    /* keep the previous M only if ||M*A - I|| < 1 */
    if (reuse >= 1.0) PetscCall(MatDestroy(&pch2opus->M));
  }
  if (!pch2opus->M) {
    const char *prefix;
    /* initial guess M0 = s0 * A */
    PetscCall(MatDuplicate(pch2opus->A, MAT_COPY_VALUES, &pch2opus->M));
    PetscCall(PCGetOptionsPrefix(pc, &prefix));
    PetscCall(MatSetOptionsPrefix(pch2opus->M, prefix));
    PetscCall(MatAppendOptionsPrefix(pch2opus->M, "pc_h2opus_inv_"));
    PetscCall(MatSetFromOptions(pch2opus->M));
    PetscCall(PCH2OpusSetUpInit(pc));
    PetscCall(MatScale(pch2opus->M, pch2opus->s0));
  }
  /* A and M have the same h2 matrix structure, save on reordering routines */
  PetscCall(MatH2OpusSetNativeMult(pch2opus->A, PETSC_TRUE));
  PetscCall(MatH2OpusSetNativeMult(pch2opus->M, PETSC_TRUE));
  /* NOTE(review): for any other norm type initerr stays 0, so with a positive atol the
     Newton-Schultz loop below is skipped and the monitor prints 0/0 */
  if (norm == NORM_1 || norm == NORM_2 || norm == NORM_INFINITY) PetscCall(MatNorm(pch2opus->T, norm, &initerr));
  if (PetscIsInfOrNanReal(initerr)) pc->failedreason = PC_SETUP_ERROR;
  err = initerr;
  if (pch2opus->monitor) { PetscCall(PetscPrintf(PetscObjectComm((PetscObject)pc), "%" PetscInt_FMT ": ||M*A - I|| NORM%s abs %g rel %g\n", 0, NormTypes[norm], (double)err, (double)(err / initerr))); }
  if (initerr > pch2opus->atol && !pc->failedreason) {
    PetscInt i;

    PetscCall(PCH2OpusSetUpSampler_Private(pc));
    for (i = 0; i < pch2opus->maxits; i++) {
      Mat         M;
      const char *prefix;

      /* build the next iterate by sampling S, which applies the update formula using the current M */
      PetscCall(MatDuplicate(pch2opus->M, MAT_SHARE_NONZERO_PATTERN, &M));
      PetscCall(MatGetOptionsPrefix(pch2opus->M, &prefix));
      PetscCall(MatSetOptionsPrefix(M, prefix));
      PetscCall(MatH2OpusSetSamplingMat(M, pch2opus->S, PETSC_DECIDE, PETSC_DECIDE));
      PetscCall(MatSetFromOptions(M));
      PetscCall(MatH2OpusSetNativeMult(M, PETSC_TRUE));
      PetscCall(MatAssemblyBegin(M, MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(M, MAT_FINAL_ASSEMBLY));
      /* replace the old iterate only after the new one is assembled, since S reads pch2opus->M */
      PetscCall(MatDestroy(&pch2opus->M));
      pch2opus->M = M;
      if (norm == NORM_1 || norm == NORM_2 || norm == NORM_INFINITY) PetscCall(MatNorm(pch2opus->T, norm, &err));
      if (pch2opus->monitor) { PetscCall(PetscPrintf(PetscObjectComm((PetscObject)pc), "%" PetscInt_FMT ": ||M*A - I|| NORM%s abs %g rel %g\n", i + 1, NormTypes[norm], (double)err, (double)(err / initerr))); }
      if (PetscIsInfOrNanReal(err)) pc->failedreason = PC_SETUP_ERROR;
      if (err < pch2opus->atol || err < pch2opus->rtol * initerr || pc->failedreason) break;
    }
  }
  /* cleanup setup workspace */
  PetscCall(MatH2OpusSetNativeMult(pch2opus->A, PETSC_FALSE));
  PetscCall(MatH2OpusSetNativeMult(pch2opus->M, PETSC_FALSE));
  PetscCall(VecDestroy(&pch2opus->wns[0]));
  PetscCall(VecDestroy(&pch2opus->wns[1]));
  PetscCall(VecDestroy(&pch2opus->wns[2]));
  PetscCall(VecDestroy(&pch2opus->wns[3]));
  PetscCall(MatDestroy(&pch2opus->wnsmat[0]));
  PetscCall(MatDestroy(&pch2opus->wnsmat[1]));
  PetscCall(MatDestroy(&pch2opus->wnsmat[2]));
  PetscCall(MatDestroy(&pch2opus->wnsmat[3]));
  PetscFunctionReturn(0);
}
594 
595 static PetscErrorCode PCView_H2OPUS(PC pc, PetscViewer viewer) {
596   PC_H2OPUS *pch2opus = (PC_H2OPUS *)pc->data;
597   PetscBool  isascii;
598 
599   PetscFunctionBegin;
600   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
601   if (isascii) {
602     if (pch2opus->A && pch2opus->A != pc->mat && pch2opus->A != pc->pmat) {
603       PetscCall(PetscViewerASCIIPrintf(viewer, "Initial approximation matrix\n"));
604       PetscCall(PetscViewerASCIIPushTab(viewer));
605       PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO_DETAIL));
606       PetscCall(MatView(pch2opus->A, viewer));
607       PetscCall(PetscViewerPopFormat(viewer));
608       PetscCall(PetscViewerASCIIPopTab(viewer));
609     }
610     if (pch2opus->M) {
611       PetscCall(PetscViewerASCIIPrintf(viewer, "Inner matrix constructed\n"));
612       PetscCall(PetscViewerASCIIPushTab(viewer));
613       PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO_DETAIL));
614       PetscCall(MatView(pch2opus->M, viewer));
615       PetscCall(PetscViewerPopFormat(viewer));
616       PetscCall(PetscViewerASCIIPopTab(viewer));
617     }
618     PetscCall(PetscViewerASCIIPrintf(viewer, "Initial scaling: %g\n", pch2opus->s0));
619   }
620   PetscFunctionReturn(0);
621 }
622 
623 PETSC_EXTERN PetscErrorCode PCCreate_H2OPUS(PC pc) {
624   PC_H2OPUS *pch2opus;
625 
626   PetscFunctionBegin;
627   PetscCall(PetscNewLog(pc, &pch2opus));
628   pc->data = (void *)pch2opus;
629 
630   pch2opus->atol       = 1.e-2;
631   pch2opus->rtol       = 1.e-6;
632   pch2opus->maxits     = 50;
633   pch2opus->hyperorder = 1; /* defaults to basic NewtonSchultz */
634   pch2opus->normtype   = NORM_2;
635 
636   /* these are needed when we are sampling the pmat */
637   pch2opus->eta      = PETSC_DECIDE;
638   pch2opus->leafsize = PETSC_DECIDE;
639   pch2opus->max_rank = PETSC_DECIDE;
640   pch2opus->bs       = PETSC_DECIDE;
641   pch2opus->mrtol    = PETSC_DECIDE;
642 #if defined(PETSC_H2OPUS_USE_GPU)
643   pch2opus->boundtocpu = PETSC_FALSE;
644 #else
645   pch2opus->boundtocpu = PETSC_TRUE;
646 #endif
647   pc->ops->destroy        = PCDestroy_H2OPUS;
648   pc->ops->setup          = PCSetUp_H2OPUS;
649   pc->ops->apply          = PCApply_H2OPUS;
650   pc->ops->matapply       = PCApplyMat_H2OPUS;
651   pc->ops->applytranspose = PCApplyTranspose_H2OPUS;
652   pc->ops->reset          = PCReset_H2OPUS;
653   pc->ops->setfromoptions = PCSetFromOptions_H2OPUS;
654   pc->ops->view           = PCView_H2OPUS;
655 
656   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCSetCoordinates_C", PCSetCoordinates_H2OPUS));
657   PetscFunctionReturn(0);
658 }
659