xref: /petsc/src/ksp/ksp/impls/cg/pipecg2/pipecg2.c (revision d8e47b638cf8f604a99e9678e1df24f82d959cd7)
1 #include <petsc/private/kspimpl.h>
2 
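3 /* The auxiliary functions below are merged (fused) vector operations that load each vector from memory once and use
4    the data multiple times by performing the vector operations element-wise. They work only on the process-local
5    part of each vector (obtained with VecGetArray()/VecGetArrayRead()); the callers perform the global MPI reductions.
6 */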
7 /*   VecMergedDot_Private function merges the dot products for gamma, delta and dp */
VecMergedDot_Private(Vec U,Vec W,Vec R,PetscInt normtype,PetscScalar * ru,PetscScalar * wu,PetscScalar * uu)8 static PetscErrorCode VecMergedDot_Private(Vec U, Vec W, Vec R, PetscInt normtype, PetscScalar *ru, PetscScalar *wu, PetscScalar *uu)
9 {
10   const PetscScalar *PETSC_RESTRICT PU, *PETSC_RESTRICT PW, *PETSC_RESTRICT PR;
11   PetscScalar sumru = 0.0, sumwu = 0.0, sumuu = 0.0;
12   PetscInt    j, n;
13 
14   PetscFunctionBegin;
15   PetscCall(VecGetArrayRead(U, (const PetscScalar **)&PU));
16   PetscCall(VecGetArrayRead(W, (const PetscScalar **)&PW));
17   PetscCall(VecGetArrayRead(R, (const PetscScalar **)&PR));
18   PetscCall(VecGetLocalSize(U, &n));
19 
20   if (normtype == KSP_NORM_PRECONDITIONED) {
21     PetscPragmaSIMD
22     for (j = 0; j < n; j++) {
23       sumwu += PW[j] * PetscConj(PU[j]);
24       sumru += PR[j] * PetscConj(PU[j]);
25       sumuu += PU[j] * PetscConj(PU[j]);
26     }
27   } else if (normtype == KSP_NORM_UNPRECONDITIONED) {
28     PetscPragmaSIMD
29     for (j = 0; j < n; j++) {
30       sumwu += PW[j] * PetscConj(PU[j]);
31       sumru += PR[j] * PetscConj(PU[j]);
32       sumuu += PR[j] * PetscConj(PR[j]);
33     }
34   } else if (normtype == KSP_NORM_NATURAL) {
35     PetscPragmaSIMD
36     for (j = 0; j < n; j++) {
37       sumwu += PW[j] * PetscConj(PU[j]);
38       sumru += PR[j] * PetscConj(PU[j]);
39     }
40     sumuu = sumru;
41   }
42 
43   *ru = sumru;
44   *wu = sumwu;
45   *uu = sumuu;
46 
47   PetscCall(VecRestoreArrayRead(U, (const PetscScalar **)&PU));
48   PetscCall(VecRestoreArrayRead(W, (const PetscScalar **)&PW));
49   PetscCall(VecRestoreArrayRead(R, (const PetscScalar **)&PR));
50   PetscFunctionReturn(PETSC_SUCCESS);
51 }
52 
53 /*   VecMergedDot2_Private function merges the dot products for lambda_1 and lambda_4 */
VecMergedDot2_Private(Vec N,Vec M,Vec W,PetscScalar * wm,PetscScalar * nm)54 static PetscErrorCode VecMergedDot2_Private(Vec N, Vec M, Vec W, PetscScalar *wm, PetscScalar *nm)
55 {
56   const PetscScalar *PETSC_RESTRICT PN, *PETSC_RESTRICT PM, *PETSC_RESTRICT PW;
57   PetscScalar sumwm = 0.0, sumnm = 0.0;
58   PetscInt    j, n;
59 
60   PetscFunctionBegin;
61   PetscCall(VecGetArrayRead(W, (const PetscScalar **)&PW));
62   PetscCall(VecGetArrayRead(N, (const PetscScalar **)&PN));
63   PetscCall(VecGetArrayRead(M, (const PetscScalar **)&PM));
64   PetscCall(VecGetLocalSize(N, &n));
65 
66   PetscPragmaSIMD
67   for (j = 0; j < n; j++) {
68     sumwm += PW[j] * PetscConj(PM[j]);
69     sumnm += PN[j] * PetscConj(PM[j]);
70   }
71 
72   *wm = sumwm;
73   *nm = sumnm;
74 
75   PetscCall(VecRestoreArrayRead(W, (const PetscScalar **)&PW));
76   PetscCall(VecRestoreArrayRead(N, (const PetscScalar **)&PN));
77   PetscCall(VecRestoreArrayRead(M, (const PetscScalar **)&PM));
78   PetscFunctionReturn(PETSC_SUCCESS);
79 }
80 
81 /*   VecMergedOpsShort_Private function merges the dot products, AXPY and SAXPY operations for all vectors for iteration 0  */
VecMergedOpsShort_Private(Vec vx,Vec vr,Vec vz,Vec vw,Vec vp,Vec vq,Vec vc,Vec vd,Vec vg0,Vec vh0,Vec vg1,Vec vh1,Vec vs,Vec va1,Vec vb1,Vec ve,Vec vf,Vec vm,Vec vn,Vec vu,PetscInt normtype,PetscScalar beta0,PetscScalar alpha0,PetscScalar beta1,PetscScalar alpha1,PetscScalar * lambda)82 static PetscErrorCode VecMergedOpsShort_Private(Vec vx, Vec vr, Vec vz, Vec vw, Vec vp, Vec vq, Vec vc, Vec vd, Vec vg0, Vec vh0, Vec vg1, Vec vh1, Vec vs, Vec va1, Vec vb1, Vec ve, Vec vf, Vec vm, Vec vn, Vec vu, PetscInt normtype, PetscScalar beta0, PetscScalar alpha0, PetscScalar beta1, PetscScalar alpha1, PetscScalar *lambda)
83 {
  /*
     Fused kernel for the first pass (iteration 0) of the PIPECG2 loop. In one sweep over the
     process-local entries it:
       1. initializes the directions by copies: z <- n, q <- m, s <- w, p <- u, c <- g0, d <- h0,
          a1 <- e, b1 <- f;
       2. takes the first step with alpha0: x += alpha0*p, and r, u, w, m, n, g0, h0 -= alpha0 times
          s, q, z, c, d, a1, b1 respectively;
       3. saves g1 <- g0 and h1 <- h0 (VecMergedOps_Private reads them back on later calls);
       4. updates the directions with beta1: z <- n + beta1*z, q <- m + beta1*q, s <- w + beta1*s,
          p <- u + beta1*p, c <- g0 + beta1*c, d <- h0 + beta1*d;
       5. takes the second step with alpha1 on x, r, u, w, m, n (same pattern as step 2);
       6. accumulates local partial sums lambda[k] += a[j]*conj(b[j]) with (k: a, b) =
          0:(s,u) 1:(w,m) 2:(w,q) 4:(s,q) 6:(n,m) 7:(n,q) 9:(z,q) 10:(r,u) 11:(w,u).
     lambda[3], lambda[5], lambda[8], lambda[13], lambda[14] are set to the conjugates of
     lambda[2], lambda[1], lambda[7], lambda[11], lambda[0]; conjugation commutes with summation, so
     they remain correct after the caller's MPI reduction. lambda[12] depends on normtype:
     sum u*conj(u) (PRECONDITIONED), sum r*conj(r) (UNPRECONDITIONED), or lambda[10] (NATURAL).

     beta0 is not used here: on iteration 0 the directions are initialized by the copies in step 1.
     NOTE: if normtype is none of the three values above (e.g. KSP_NORM_NONE), no vector is
     modified and lambda is returned all zero.
     No MPI reduction is performed here; the caller reduces lambda.
  */
84   PetscScalar *PETSC_RESTRICT px, *PETSC_RESTRICT pr, *PETSC_RESTRICT pz, *PETSC_RESTRICT pw;
85   PetscScalar *PETSC_RESTRICT pp, *PETSC_RESTRICT pq;
86   PetscScalar *PETSC_RESTRICT pc, *PETSC_RESTRICT pd, *PETSC_RESTRICT pg0, *PETSC_RESTRICT ph0, *PETSC_RESTRICT pg1, *PETSC_RESTRICT ph1, *PETSC_RESTRICT ps, *PETSC_RESTRICT pa1, *PETSC_RESTRICT pb1, *PETSC_RESTRICT pe, *PETSC_RESTRICT pf, *PETSC_RESTRICT pm, *PETSC_RESTRICT pn, *PETSC_RESTRICT pu;
87   PetscInt j, n;
88
89   PetscFunctionBegin;
90   PetscCall(VecGetArray(vx, (PetscScalar **)&px));
91   PetscCall(VecGetArray(vr, (PetscScalar **)&pr));
92   PetscCall(VecGetArray(vz, (PetscScalar **)&pz));
93   PetscCall(VecGetArray(vw, (PetscScalar **)&pw));
94   PetscCall(VecGetArray(vp, (PetscScalar **)&pp));
95   PetscCall(VecGetArray(vq, (PetscScalar **)&pq));
96   PetscCall(VecGetArray(vc, (PetscScalar **)&pc));
97   PetscCall(VecGetArray(vd, (PetscScalar **)&pd));
98   PetscCall(VecGetArray(vg0, (PetscScalar **)&pg0));
99   PetscCall(VecGetArray(vh0, (PetscScalar **)&ph0));
100   PetscCall(VecGetArray(vg1, (PetscScalar **)&pg1));
101   PetscCall(VecGetArray(vh1, (PetscScalar **)&ph1));
102   PetscCall(VecGetArray(vs, (PetscScalar **)&ps));
103   PetscCall(VecGetArray(va1, (PetscScalar **)&pa1));
104   PetscCall(VecGetArray(vb1, (PetscScalar **)&pb1));
105   PetscCall(VecGetArray(ve, (PetscScalar **)&pe));
106   PetscCall(VecGetArray(vf, (PetscScalar **)&pf));
107   PetscCall(VecGetArray(vm, (PetscScalar **)&pm));
108   PetscCall(VecGetArray(vn, (PetscScalar **)&pn));
109   PetscCall(VecGetArray(vu, (PetscScalar **)&pu));
110
111   PetscCall(VecGetLocalSize(vx, &n));
  /* Zero all 15 partial sums; they stay zero if normtype matches no branch below */
112   for (j = 0; j < 15; j++) lambda[j] = 0.0;
113
114   if (normtype == KSP_NORM_PRECONDITIONED) {
    /* lambda[12] <- sum u*conj(u) */
115     PetscPragmaSIMD
116     for (j = 0; j < n; j++) {
117       pz[j]  = pn[j];
118       pq[j]  = pm[j];
119       ps[j]  = pw[j];
120       pp[j]  = pu[j];
121       pc[j]  = pg0[j];
122       pd[j]  = ph0[j];
123       pa1[j] = pe[j];
124       pb1[j] = pf[j];
125
126       px[j]  = px[j] + alpha0 * pp[j];
127       pr[j]  = pr[j] - alpha0 * ps[j];
128       pu[j]  = pu[j] - alpha0 * pq[j];
129       pw[j]  = pw[j] - alpha0 * pz[j];
130       pm[j]  = pm[j] - alpha0 * pc[j];
131       pn[j]  = pn[j] - alpha0 * pd[j];
132       pg0[j] = pg0[j] - alpha0 * pa1[j];
133       ph0[j] = ph0[j] - alpha0 * pb1[j];
134
135       pg1[j] = pg0[j];
136       ph1[j] = ph0[j];
137
138       pz[j] = pn[j] + beta1 * pz[j];
139       pq[j] = pm[j] + beta1 * pq[j];
140       ps[j] = pw[j] + beta1 * ps[j];
141       pp[j] = pu[j] + beta1 * pp[j];
142       pc[j] = pg0[j] + beta1 * pc[j];
143       pd[j] = ph0[j] + beta1 * pd[j];
144
145       px[j] = px[j] + alpha1 * pp[j];
146       pr[j] = pr[j] - alpha1 * ps[j];
147       pu[j] = pu[j] - alpha1 * pq[j];
148       pw[j] = pw[j] - alpha1 * pz[j];
149       pm[j] = pm[j] - alpha1 * pc[j];
150       pn[j] = pn[j] - alpha1 * pd[j];
151
152       lambda[0] += ps[j] * PetscConj(pu[j]);
153       lambda[1] += pw[j] * PetscConj(pm[j]);
154       lambda[2] += pw[j] * PetscConj(pq[j]);
155       lambda[4] += ps[j] * PetscConj(pq[j]);
156       lambda[6] += pn[j] * PetscConj(pm[j]);
157       lambda[7] += pn[j] * PetscConj(pq[j]);
158       lambda[9] += pz[j] * PetscConj(pq[j]);
159       lambda[10] += pr[j] * PetscConj(pu[j]);
160       lambda[11] += pw[j] * PetscConj(pu[j]);
161       lambda[12] += pu[j] * PetscConj(pu[j]);
162     }
163     lambda[3]  = PetscConj(lambda[2]);
164     lambda[5]  = PetscConj(lambda[1]);
165     lambda[8]  = PetscConj(lambda[7]);
166     lambda[13] = PetscConj(lambda[11]);
167     lambda[14] = PetscConj(lambda[0]);
168
169   } else if (normtype == KSP_NORM_UNPRECONDITIONED) {
    /* Same updates as above; only lambda[12] differs: sum r*conj(r) */
170     PetscPragmaSIMD
171     for (j = 0; j < n; j++) {
172       pz[j]  = pn[j];
173       pq[j]  = pm[j];
174       ps[j]  = pw[j];
175       pp[j]  = pu[j];
176       pc[j]  = pg0[j];
177       pd[j]  = ph0[j];
178       pa1[j] = pe[j];
179       pb1[j] = pf[j];
180
181       px[j]  = px[j] + alpha0 * pp[j];
182       pr[j]  = pr[j] - alpha0 * ps[j];
183       pu[j]  = pu[j] - alpha0 * pq[j];
184       pw[j]  = pw[j] - alpha0 * pz[j];
185       pm[j]  = pm[j] - alpha0 * pc[j];
186       pn[j]  = pn[j] - alpha0 * pd[j];
187       pg0[j] = pg0[j] - alpha0 * pa1[j];
188       ph0[j] = ph0[j] - alpha0 * pb1[j];
189
190       pg1[j] = pg0[j];
191       ph1[j] = ph0[j];
192
193       pz[j] = pn[j] + beta1 * pz[j];
194       pq[j] = pm[j] + beta1 * pq[j];
195       ps[j] = pw[j] + beta1 * ps[j];
196       pp[j] = pu[j] + beta1 * pp[j];
197       pc[j] = pg0[j] + beta1 * pc[j];
198       pd[j] = ph0[j] + beta1 * pd[j];
199
200       px[j] = px[j] + alpha1 * pp[j];
201       pr[j] = pr[j] - alpha1 * ps[j];
202       pu[j] = pu[j] - alpha1 * pq[j];
203       pw[j] = pw[j] - alpha1 * pz[j];
204       pm[j] = pm[j] - alpha1 * pc[j];
205       pn[j] = pn[j] - alpha1 * pd[j];
206
207       lambda[0] += ps[j] * PetscConj(pu[j]);
208       lambda[1] += pw[j] * PetscConj(pm[j]);
209       lambda[2] += pw[j] * PetscConj(pq[j]);
210       lambda[4] += ps[j] * PetscConj(pq[j]);
211       lambda[6] += pn[j] * PetscConj(pm[j]);
212       lambda[7] += pn[j] * PetscConj(pq[j]);
213       lambda[9] += pz[j] * PetscConj(pq[j]);
214       lambda[10] += pr[j] * PetscConj(pu[j]);
215       lambda[11] += pw[j] * PetscConj(pu[j]);
216       lambda[12] += pr[j] * PetscConj(pr[j]);
217     }
218     lambda[3]  = PetscConj(lambda[2]);
219     lambda[5]  = PetscConj(lambda[1]);
220     lambda[8]  = PetscConj(lambda[7]);
221     lambda[13] = PetscConj(lambda[11]);
222     lambda[14] = PetscConj(lambda[0]);
223
224   } else if (normtype == KSP_NORM_NATURAL) {
    /* Same updates as above; lambda[12] is not summed but copied from lambda[10] (r'*u) after the loop */
225     PetscPragmaSIMD
226     for (j = 0; j < n; j++) {
227       pz[j]  = pn[j];
228       pq[j]  = pm[j];
229       ps[j]  = pw[j];
230       pp[j]  = pu[j];
231       pc[j]  = pg0[j];
232       pd[j]  = ph0[j];
233       pa1[j] = pe[j];
234       pb1[j] = pf[j];
235
236       px[j]  = px[j] + alpha0 * pp[j];
237       pr[j]  = pr[j] - alpha0 * ps[j];
238       pu[j]  = pu[j] - alpha0 * pq[j];
239       pw[j]  = pw[j] - alpha0 * pz[j];
240       pm[j]  = pm[j] - alpha0 * pc[j];
241       pn[j]  = pn[j] - alpha0 * pd[j];
242       pg0[j] = pg0[j] - alpha0 * pa1[j];
243       ph0[j] = ph0[j] - alpha0 * pb1[j];
244
245       pg1[j] = pg0[j];
246       ph1[j] = ph0[j];
247
248       pz[j] = pn[j] + beta1 * pz[j];
249       pq[j] = pm[j] + beta1 * pq[j];
250       ps[j] = pw[j] + beta1 * ps[j];
251       pp[j] = pu[j] + beta1 * pp[j];
252       pc[j] = pg0[j] + beta1 * pc[j];
253       pd[j] = ph0[j] + beta1 * pd[j];
254
255       px[j] = px[j] + alpha1 * pp[j];
256       pr[j] = pr[j] - alpha1 * ps[j];
257       pu[j] = pu[j] - alpha1 * pq[j];
258       pw[j] = pw[j] - alpha1 * pz[j];
259       pm[j] = pm[j] - alpha1 * pc[j];
260       pn[j] = pn[j] - alpha1 * pd[j];
261
262       lambda[0] += ps[j] * PetscConj(pu[j]);
263       lambda[1] += pw[j] * PetscConj(pm[j]);
264       lambda[2] += pw[j] * PetscConj(pq[j]);
265       lambda[4] += ps[j] * PetscConj(pq[j]);
266       lambda[6] += pn[j] * PetscConj(pm[j]);
267       lambda[7] += pn[j] * PetscConj(pq[j]);
268       lambda[9] += pz[j] * PetscConj(pq[j]);
269       lambda[10] += pr[j] * PetscConj(pu[j]);
270       lambda[11] += pw[j] * PetscConj(pu[j]);
271     }
272     lambda[3]  = PetscConj(lambda[2]);
273     lambda[5]  = PetscConj(lambda[1]);
274     lambda[8]  = PetscConj(lambda[7]);
275     lambda[13] = PetscConj(lambda[11]);
276     lambda[14] = PetscConj(lambda[0]);
277     lambda[12] = lambda[10];
278   }
279
280   PetscCall(VecRestoreArray(vx, (PetscScalar **)&px));
281   PetscCall(VecRestoreArray(vr, (PetscScalar **)&pr));
282   PetscCall(VecRestoreArray(vz, (PetscScalar **)&pz));
283   PetscCall(VecRestoreArray(vw, (PetscScalar **)&pw));
284   PetscCall(VecRestoreArray(vp, (PetscScalar **)&pp));
285   PetscCall(VecRestoreArray(vq, (PetscScalar **)&pq));
286   PetscCall(VecRestoreArray(vc, (PetscScalar **)&pc));
287   PetscCall(VecRestoreArray(vd, (PetscScalar **)&pd));
288   PetscCall(VecRestoreArray(vg0, (PetscScalar **)&pg0));
289   PetscCall(VecRestoreArray(vh0, (PetscScalar **)&ph0));
290   PetscCall(VecRestoreArray(vg1, (PetscScalar **)&pg1));
291   PetscCall(VecRestoreArray(vh1, (PetscScalar **)&ph1));
292   PetscCall(VecRestoreArray(vs, (PetscScalar **)&ps));
293   PetscCall(VecRestoreArray(va1, (PetscScalar **)&pa1));
294   PetscCall(VecRestoreArray(vb1, (PetscScalar **)&pb1));
295   PetscCall(VecRestoreArray(ve, (PetscScalar **)&pe));
296   PetscCall(VecRestoreArray(vf, (PetscScalar **)&pf));
297   PetscCall(VecRestoreArray(vm, (PetscScalar **)&pm));
298   PetscCall(VecRestoreArray(vn, (PetscScalar **)&pn));
299   PetscCall(VecRestoreArray(vu, (PetscScalar **)&pu));
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 /*   VecMergedOps_Private function merges the dot products, AXPY and SAXPY operations for all vectors for iteration > 0  */
VecMergedOps_Private(Vec vx,Vec vr,Vec vz,Vec vw,Vec vp,Vec vq,Vec vc,Vec vd,Vec vg0,Vec vh0,Vec vg1,Vec vh1,Vec vs,Vec va1,Vec vb1,Vec ve,Vec vf,Vec vm,Vec vn,Vec vu,PetscInt normtype,PetscScalar beta0,PetscScalar alpha0,PetscScalar beta1,PetscScalar alpha1,PetscScalar * lambda,PetscScalar alphaold)304 static PetscErrorCode VecMergedOps_Private(Vec vx, Vec vr, Vec vz, Vec vw, Vec vp, Vec vq, Vec vc, Vec vd, Vec vg0, Vec vh0, Vec vg1, Vec vh1, Vec vs, Vec va1, Vec vb1, Vec ve, Vec vf, Vec vm, Vec vn, Vec vu, PetscInt normtype, PetscScalar beta0, PetscScalar alpha0, PetscScalar beta1, PetscScalar alpha1, PetscScalar *lambda, PetscScalar alphaold)
305 {
  /*
     Fused kernel for every pass of the PIPECG2 loop after the first. In one sweep over the
     process-local entries it:
       0. rebuilds a1 <- (g1 - g0)/alphaold and b1 <- (h1 - h0)/alphaold from the copies g1, h1
          saved by the previous call (alphaold is assumed nonzero);
       1. updates the directions with beta0: z <- n + beta0*z, q <- m + beta0*q, s <- w + beta0*s,
          p <- u + beta0*p, c <- g0 + beta0*c, d <- h0 + beta0*d, a1 <- e + beta0*a1, b1 <- f + beta0*b1;
       2. takes the first step with alpha0: x += alpha0*p, and r, u, w, m, n, g0, h0 -= alpha0 times
          s, q, z, c, d, a1, b1 respectively;
       3. saves g1 <- g0 and h1 <- h0 for the next call;
       4. updates z, q, s, p, c, d with beta1 (same pattern as step 1);
       5. takes the second step with alpha1 on x, r, u, w, m, n (same pattern as step 2);
       6. accumulates local partial sums lambda[k] += a[j]*conj(b[j]) with (k: a, b) =
          0:(s,u) 1:(w,m) 2:(w,q) 4:(s,q) 6:(n,m) 7:(n,q) 9:(z,q) 10:(r,u) 11:(w,u).
     lambda[3], lambda[5], lambda[8], lambda[13], lambda[14] are set to the conjugates of
     lambda[2], lambda[1], lambda[7], lambda[11], lambda[0]; conjugation commutes with summation, so
     they remain correct after the caller's MPI reduction. lambda[12] depends on normtype:
     sum u*conj(u) (PRECONDITIONED), sum r*conj(r) (UNPRECONDITIONED), or lambda[10] (NATURAL).

     NOTE: if normtype is none of the three values above (e.g. KSP_NORM_NONE), no vector is
     modified and lambda is returned all zero.
     No MPI reduction is performed here; the caller reduces lambda.
  */
306   PetscScalar *PETSC_RESTRICT px, *PETSC_RESTRICT pr, *PETSC_RESTRICT pz, *PETSC_RESTRICT pw;
307   PetscScalar *PETSC_RESTRICT pp, *PETSC_RESTRICT pq;
308   PetscScalar *PETSC_RESTRICT pc, *PETSC_RESTRICT pd, *PETSC_RESTRICT pg0, *PETSC_RESTRICT ph0, *PETSC_RESTRICT pg1, *PETSC_RESTRICT ph1, *PETSC_RESTRICT ps, *PETSC_RESTRICT pa1, *PETSC_RESTRICT pb1, *PETSC_RESTRICT pe, *PETSC_RESTRICT pf, *PETSC_RESTRICT pm, *PETSC_RESTRICT pn, *PETSC_RESTRICT pu;
309   PetscInt j, n;
310
311   PetscFunctionBegin;
312   PetscCall(VecGetArray(vx, (PetscScalar **)&px));
313   PetscCall(VecGetArray(vr, (PetscScalar **)&pr));
314   PetscCall(VecGetArray(vz, (PetscScalar **)&pz));
315   PetscCall(VecGetArray(vw, (PetscScalar **)&pw));
316   PetscCall(VecGetArray(vp, (PetscScalar **)&pp));
317   PetscCall(VecGetArray(vq, (PetscScalar **)&pq));
318   PetscCall(VecGetArray(vc, (PetscScalar **)&pc));
319   PetscCall(VecGetArray(vd, (PetscScalar **)&pd));
320   PetscCall(VecGetArray(vg0, (PetscScalar **)&pg0));
321   PetscCall(VecGetArray(vh0, (PetscScalar **)&ph0));
322   PetscCall(VecGetArray(vg1, (PetscScalar **)&pg1));
323   PetscCall(VecGetArray(vh1, (PetscScalar **)&ph1));
324   PetscCall(VecGetArray(vs, (PetscScalar **)&ps));
325   PetscCall(VecGetArray(va1, (PetscScalar **)&pa1));
326   PetscCall(VecGetArray(vb1, (PetscScalar **)&pb1));
327   PetscCall(VecGetArray(ve, (PetscScalar **)&pe));
328   PetscCall(VecGetArray(vf, (PetscScalar **)&pf));
329   PetscCall(VecGetArray(vm, (PetscScalar **)&pm));
330   PetscCall(VecGetArray(vn, (PetscScalar **)&pn));
331   PetscCall(VecGetArray(vu, (PetscScalar **)&pu));
332
333   PetscCall(VecGetLocalSize(vx, &n));
  /* Zero all 15 partial sums; they stay zero if normtype matches no branch below */
334   for (j = 0; j < 15; j++) lambda[j] = 0.0;
335
336   if (normtype == KSP_NORM_PRECONDITIONED) {
    /* lambda[12] <- sum u*conj(u) */
337     PetscPragmaSIMD
338     for (j = 0; j < n; j++) {
339       pa1[j] = (pg1[j] - pg0[j]) / alphaold;
340       pb1[j] = (ph1[j] - ph0[j]) / alphaold;
341
342       pz[j]  = pn[j] + beta0 * pz[j];
343       pq[j]  = pm[j] + beta0 * pq[j];
344       ps[j]  = pw[j] + beta0 * ps[j];
345       pp[j]  = pu[j] + beta0 * pp[j];
346       pc[j]  = pg0[j] + beta0 * pc[j];
347       pd[j]  = ph0[j] + beta0 * pd[j];
348       pa1[j] = pe[j] + beta0 * pa1[j];
349       pb1[j] = pf[j] + beta0 * pb1[j];
350
351       px[j]  = px[j] + alpha0 * pp[j];
352       pr[j]  = pr[j] - alpha0 * ps[j];
353       pu[j]  = pu[j] - alpha0 * pq[j];
354       pw[j]  = pw[j] - alpha0 * pz[j];
355       pm[j]  = pm[j] - alpha0 * pc[j];
356       pn[j]  = pn[j] - alpha0 * pd[j];
357       pg0[j] = pg0[j] - alpha0 * pa1[j];
358       ph0[j] = ph0[j] - alpha0 * pb1[j];
359
360       pg1[j] = pg0[j];
361       ph1[j] = ph0[j];
362
363       pz[j] = pn[j] + beta1 * pz[j];
364       pq[j] = pm[j] + beta1 * pq[j];
365       ps[j] = pw[j] + beta1 * ps[j];
366       pp[j] = pu[j] + beta1 * pp[j];
367       pc[j] = pg0[j] + beta1 * pc[j];
368       pd[j] = ph0[j] + beta1 * pd[j];
369
370       px[j] = px[j] + alpha1 * pp[j];
371       pr[j] = pr[j] - alpha1 * ps[j];
372       pu[j] = pu[j] - alpha1 * pq[j];
373       pw[j] = pw[j] - alpha1 * pz[j];
374       pm[j] = pm[j] - alpha1 * pc[j];
375       pn[j] = pn[j] - alpha1 * pd[j];
376
377       lambda[0] += ps[j] * PetscConj(pu[j]);
378       lambda[1] += pw[j] * PetscConj(pm[j]);
379       lambda[2] += pw[j] * PetscConj(pq[j]);
380       lambda[4] += ps[j] * PetscConj(pq[j]);
381       lambda[6] += pn[j] * PetscConj(pm[j]);
382       lambda[7] += pn[j] * PetscConj(pq[j]);
383       lambda[9] += pz[j] * PetscConj(pq[j]);
384       lambda[10] += pr[j] * PetscConj(pu[j]);
385       lambda[11] += pw[j] * PetscConj(pu[j]);
386       lambda[12] += pu[j] * PetscConj(pu[j]);
387     }
388     lambda[3]  = PetscConj(lambda[2]);
389     lambda[5]  = PetscConj(lambda[1]);
390     lambda[8]  = PetscConj(lambda[7]);
391     lambda[13] = PetscConj(lambda[11]);
392     lambda[14] = PetscConj(lambda[0]);
393   } else if (normtype == KSP_NORM_UNPRECONDITIONED) {
    /* Same updates as above; only lambda[12] differs: sum r*conj(r) */
394     PetscPragmaSIMD
395     for (j = 0; j < n; j++) {
396       pa1[j] = (pg1[j] - pg0[j]) / alphaold;
397       pb1[j] = (ph1[j] - ph0[j]) / alphaold;
398
399       pz[j]  = pn[j] + beta0 * pz[j];
400       pq[j]  = pm[j] + beta0 * pq[j];
401       ps[j]  = pw[j] + beta0 * ps[j];
402       pp[j]  = pu[j] + beta0 * pp[j];
403       pc[j]  = pg0[j] + beta0 * pc[j];
404       pd[j]  = ph0[j] + beta0 * pd[j];
405       pa1[j] = pe[j] + beta0 * pa1[j];
406       pb1[j] = pf[j] + beta0 * pb1[j];
407
408       px[j]  = px[j] + alpha0 * pp[j];
409       pr[j]  = pr[j] - alpha0 * ps[j];
410       pu[j]  = pu[j] - alpha0 * pq[j];
411       pw[j]  = pw[j] - alpha0 * pz[j];
412       pm[j]  = pm[j] - alpha0 * pc[j];
413       pn[j]  = pn[j] - alpha0 * pd[j];
414       pg0[j] = pg0[j] - alpha0 * pa1[j];
415       ph0[j] = ph0[j] - alpha0 * pb1[j];
416
417       pg1[j] = pg0[j];
418       ph1[j] = ph0[j];
419
420       pz[j] = pn[j] + beta1 * pz[j];
421       pq[j] = pm[j] + beta1 * pq[j];
422       ps[j] = pw[j] + beta1 * ps[j];
423       pp[j] = pu[j] + beta1 * pp[j];
424       pc[j] = pg0[j] + beta1 * pc[j];
425       pd[j] = ph0[j] + beta1 * pd[j];
426
427       px[j] = px[j] + alpha1 * pp[j];
428       pr[j] = pr[j] - alpha1 * ps[j];
429       pu[j] = pu[j] - alpha1 * pq[j];
430       pw[j] = pw[j] - alpha1 * pz[j];
431       pm[j] = pm[j] - alpha1 * pc[j];
432       pn[j] = pn[j] - alpha1 * pd[j];
433
434       lambda[0] += ps[j] * PetscConj(pu[j]);
435       lambda[1] += pw[j] * PetscConj(pm[j]);
436       lambda[2] += pw[j] * PetscConj(pq[j]);
437       lambda[4] += ps[j] * PetscConj(pq[j]);
438       lambda[6] += pn[j] * PetscConj(pm[j]);
439       lambda[7] += pn[j] * PetscConj(pq[j]);
440       lambda[9] += pz[j] * PetscConj(pq[j]);
441       lambda[10] += pr[j] * PetscConj(pu[j]);
442       lambda[11] += pw[j] * PetscConj(pu[j]);
443       lambda[12] += pr[j] * PetscConj(pr[j]);
444     }
445     lambda[3]  = PetscConj(lambda[2]);
446     lambda[5]  = PetscConj(lambda[1]);
447     lambda[8]  = PetscConj(lambda[7]);
448     lambda[13] = PetscConj(lambda[11]);
449     lambda[14] = PetscConj(lambda[0]);
450   } else if (normtype == KSP_NORM_NATURAL) {
    /* Same updates as above; lambda[12] is not summed but copied from lambda[10] (r'*u) after the loop */
451     PetscPragmaSIMD
452     for (j = 0; j < n; j++) {
453       pa1[j] = (pg1[j] - pg0[j]) / alphaold;
454       pb1[j] = (ph1[j] - ph0[j]) / alphaold;
455
456       pz[j]  = pn[j] + beta0 * pz[j];
457       pq[j]  = pm[j] + beta0 * pq[j];
458       ps[j]  = pw[j] + beta0 * ps[j];
459       pp[j]  = pu[j] + beta0 * pp[j];
460       pc[j]  = pg0[j] + beta0 * pc[j];
461       pd[j]  = ph0[j] + beta0 * pd[j];
462       pa1[j] = pe[j] + beta0 * pa1[j];
463       pb1[j] = pf[j] + beta0 * pb1[j];
464
465       px[j]  = px[j] + alpha0 * pp[j];
466       pr[j]  = pr[j] - alpha0 * ps[j];
467       pu[j]  = pu[j] - alpha0 * pq[j];
468       pw[j]  = pw[j] - alpha0 * pz[j];
469       pm[j]  = pm[j] - alpha0 * pc[j];
470       pn[j]  = pn[j] - alpha0 * pd[j];
471       pg0[j] = pg0[j] - alpha0 * pa1[j];
472       ph0[j] = ph0[j] - alpha0 * pb1[j];
473
474       pg1[j] = pg0[j];
475       ph1[j] = ph0[j];
476
477       pz[j] = pn[j] + beta1 * pz[j];
478       pq[j] = pm[j] + beta1 * pq[j];
479       ps[j] = pw[j] + beta1 * ps[j];
480       pp[j] = pu[j] + beta1 * pp[j];
481       pc[j] = pg0[j] + beta1 * pc[j];
482       pd[j] = ph0[j] + beta1 * pd[j];
483
484       px[j] = px[j] + alpha1 * pp[j];
485       pr[j] = pr[j] - alpha1 * ps[j];
486       pu[j] = pu[j] - alpha1 * pq[j];
487       pw[j] = pw[j] - alpha1 * pz[j];
488       pm[j] = pm[j] - alpha1 * pc[j];
489       pn[j] = pn[j] - alpha1 * pd[j];
490
491       lambda[0] += ps[j] * PetscConj(pu[j]);
492       lambda[1] += pw[j] * PetscConj(pm[j]);
493       lambda[2] += pw[j] * PetscConj(pq[j]);
494       lambda[4] += ps[j] * PetscConj(pq[j]);
495       lambda[6] += pn[j] * PetscConj(pm[j]);
496       lambda[7] += pn[j] * PetscConj(pq[j]);
497       lambda[9] += pz[j] * PetscConj(pq[j]);
498       lambda[10] += pr[j] * PetscConj(pu[j]);
499       lambda[11] += pw[j] * PetscConj(pu[j]);
500     }
501     lambda[3]  = PetscConj(lambda[2]);
502     lambda[5]  = PetscConj(lambda[1]);
503     lambda[8]  = PetscConj(lambda[7]);
504     lambda[13] = PetscConj(lambda[11]);
505     lambda[14] = PetscConj(lambda[0]);
506     lambda[12] = lambda[10];
507   }
508
509   PetscCall(VecRestoreArray(vx, (PetscScalar **)&px));
510   PetscCall(VecRestoreArray(vr, (PetscScalar **)&pr));
511   PetscCall(VecRestoreArray(vz, (PetscScalar **)&pz));
512   PetscCall(VecRestoreArray(vw, (PetscScalar **)&pw));
513   PetscCall(VecRestoreArray(vp, (PetscScalar **)&pp));
514   PetscCall(VecRestoreArray(vq, (PetscScalar **)&pq));
515   PetscCall(VecRestoreArray(vc, (PetscScalar **)&pc));
516   PetscCall(VecRestoreArray(vd, (PetscScalar **)&pd));
517   PetscCall(VecRestoreArray(vg0, (PetscScalar **)&pg0));
518   PetscCall(VecRestoreArray(vh0, (PetscScalar **)&ph0));
519   PetscCall(VecRestoreArray(vg1, (PetscScalar **)&pg1));
520   PetscCall(VecRestoreArray(vh1, (PetscScalar **)&ph1));
521   PetscCall(VecRestoreArray(vs, (PetscScalar **)&ps));
522   PetscCall(VecRestoreArray(va1, (PetscScalar **)&pa1));
523   PetscCall(VecRestoreArray(vb1, (PetscScalar **)&pb1));
524   PetscCall(VecRestoreArray(ve, (PetscScalar **)&pe));
525   PetscCall(VecRestoreArray(vf, (PetscScalar **)&pf));
526   PetscCall(VecRestoreArray(vm, (PetscScalar **)&pm));
527   PetscCall(VecRestoreArray(vn, (PetscScalar **)&pn));
528   PetscCall(VecRestoreArray(vu, (PetscScalar **)&pu));
529   PetscFunctionReturn(PETSC_SUCCESS);
530 }
531 
532 /*
533      KSPSetUp_PIPECG2 - Sets up the workspace needed by the PIPECG method.
534 
535       This is called once, usually automatically by KSPSolve() or KSPSetUp()
536      but can be called directly by KSPSetUp()
537 */
KSPSetUp_PIPECG2(KSP ksp)538 static PetscErrorCode KSPSetUp_PIPECG2(KSP ksp)
539 {
540   PetscFunctionBegin;
541   /* get work vectors needed by PIPECG2 */
542   PetscCall(KSPSetWorkVecs(ksp, 20));
543   PetscFunctionReturn(PETSC_SUCCESS);
544 }
545 
546 /*
547  KSPSolve_PIPECG2 - This routine actually applies the PIPECG2 method
548 */
KSPSolve_PIPECG2(KSP ksp)549 static PetscErrorCode KSPSolve_PIPECG2(KSP ksp)
550 {
551   PetscInt    i, n;
552   PetscScalar alpha[2], beta[2], gamma[2], delta[2], lambda[15];
553   PetscScalar dps = 0.0, alphaold = 0.0;
554   PetscReal   dp = 0.0;
555   Vec         X, B, Z, P, W, Q, U, M, N, R, S, C, D, E, F, G[2], H[2], A1, B1;
556   Mat         Amat, Pmat;
557   PetscBool   diagonalscale;
558   MPI_Comm    pcomm;
559   MPI_Request req;
560   MPI_Status  stat;
561 
562   PetscFunctionBegin;
563   pcomm = PetscObjectComm((PetscObject)ksp);
564   PetscCall(PCGetDiagonalScale(ksp->pc, &diagonalscale));
565   PetscCheck(!diagonalscale, PetscObjectComm((PetscObject)ksp), PETSC_ERR_SUP, "Krylov method %s does not support diagonal scaling", ((PetscObject)ksp)->type_name);
566 
567   X    = ksp->vec_sol;
568   B    = ksp->vec_rhs;
569   M    = ksp->work[0];
570   Z    = ksp->work[1];
571   P    = ksp->work[2];
572   N    = ksp->work[3];
573   W    = ksp->work[4];
574   Q    = ksp->work[5];
575   U    = ksp->work[6];
576   R    = ksp->work[7];
577   S    = ksp->work[8];
578   C    = ksp->work[9];
579   D    = ksp->work[10];
580   E    = ksp->work[11];
581   F    = ksp->work[12];
582   G[0] = ksp->work[13];
583   H[0] = ksp->work[14];
584   G[1] = ksp->work[15];
585   H[1] = ksp->work[16];
586   A1   = ksp->work[17];
587   B1   = ksp->work[18];
588 
589   PetscCall(PetscMemzero(alpha, 2 * sizeof(PetscScalar)));
590   PetscCall(PetscMemzero(beta, 2 * sizeof(PetscScalar)));
591   PetscCall(PetscMemzero(gamma, 2 * sizeof(PetscScalar)));
592   PetscCall(PetscMemzero(delta, 2 * sizeof(PetscScalar)));
593   PetscCall(PetscMemzero(lambda, 15 * sizeof(PetscScalar)));
594 
595   PetscCall(VecGetLocalSize(B, &n));
596   PetscCall(PCGetOperators(ksp->pc, &Amat, &Pmat));
597 
598   ksp->its = 0;
599   if (!ksp->guess_zero) {
600     PetscCall(KSP_MatMult(ksp, Amat, X, R)); /*  r <- b - Ax  */
601     PetscCall(VecAYPX(R, -1.0, B));
602   } else {
603     PetscCall(VecCopy(B, R)); /*  r <- b (x is 0) */
604   }
605 
606   PetscCall(KSP_PCApply(ksp, R, U));       /*  u <- Br  */
607   PetscCall(KSP_MatMult(ksp, Amat, U, W)); /*  w <- Au  */
608 
609   PetscCall(VecMergedDot_Private(U, W, R, ksp->normtype, &gamma[0], &delta[0], &dps)); /*  gamma  <- r'*u , delta <- w'*u , dp <- u'*u or r'*r or r'*u depending on ksp_norm_type  */
610   lambda[10] = gamma[0];
611   lambda[11] = delta[0];
612   lambda[12] = dps;
613 
614 #if defined(PETSC_HAVE_MPI_NONBLOCKING_COLLECTIVES)
615   PetscCallMPI(MPI_Iallreduce(MPI_IN_PLACE, &lambda[10], 3, MPIU_SCALAR, MPIU_SUM, pcomm, &req));
616 #else
617   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &lambda[10], 3, MPIU_SCALAR, MPIU_SUM, pcomm));
618   req = MPI_REQUEST_NULL;
619 #endif
620 
621   PetscCall(KSP_PCApply(ksp, W, M));       /*  m <- Bw  */
622   PetscCall(KSP_MatMult(ksp, Amat, M, N)); /*  n <- Am  */
623 
624   PetscCall(KSP_PCApply(ksp, N, G[0]));          /*  g <- Bn  */
625   PetscCall(KSP_MatMult(ksp, Amat, G[0], H[0])); /*  h <- Ag  */
626 
627   PetscCall(KSP_PCApply(ksp, H[0], E));    /*  e <- Bh  */
628   PetscCall(KSP_MatMult(ksp, Amat, E, F)); /*  f <- Ae  */
629 
630   PetscCallMPI(MPI_Wait(&req, &stat));
631 
632   gamma[0] = lambda[10];
633   delta[0] = lambda[11];
634   dp       = PetscSqrtReal(PetscAbsScalar(lambda[12]));
635 
636   PetscCall(VecMergedDot2_Private(N, M, W, &lambda[1], &lambda[6])); /*  lambda_1 <- w'*m , lambda_4 <- n'*m  */
637   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &lambda[1], 1, MPIU_SCALAR, MPIU_SUM, pcomm));
638   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &lambda[6], 1, MPIU_SCALAR, MPIU_SUM, pcomm));
639 
640   lambda[5]  = PetscConj(lambda[1]);
641   lambda[13] = PetscConj(lambda[11]);
642 
643   PetscCall(KSPLogResidualHistory(ksp, dp));
644   PetscCall(KSPMonitor(ksp, 0, dp));
645   ksp->rnorm = dp;
646 
647   PetscCall((*ksp->converged)(ksp, 0, dp, &ksp->reason, ksp->cnvP)); /* test for convergence */
648   if (ksp->reason) PetscFunctionReturn(PETSC_SUCCESS);
649 
650   for (i = 2; i < ksp->max_it; i += 2) {
651     if (i == 2) {
652       beta[0]  = 0;
653       alpha[0] = gamma[0] / delta[0];
654 
655       gamma[1] = gamma[0] - alpha[0] * lambda[13] - alpha[0] * delta[0] + alpha[0] * alpha[0] * lambda[1];
656       delta[1] = delta[0] - alpha[0] * lambda[1] - alpha[0] * lambda[5] + alpha[0] * alpha[0] * lambda[6];
657 
658       beta[1]  = gamma[1] / gamma[0];
659       alpha[1] = gamma[1] / (delta[1] - beta[1] / alpha[0] * gamma[1]);
660 
661       PetscCall(VecMergedOpsShort_Private(X, R, Z, W, P, Q, C, D, G[0], H[0], G[1], H[1], S, A1, B1, E, F, M, N, U, ksp->normtype, beta[0], alpha[0], beta[1], alpha[1], lambda));
662     } else {
663       beta[0]  = gamma[1] / gamma[0];
664       alpha[0] = gamma[1] / (delta[1] - beta[0] / alpha[1] * gamma[1]);
665 
666       gamma[0] = gamma[1];
667       delta[0] = delta[1];
668 
669       gamma[1] = gamma[0] - alpha[0] * (lambda[13] + beta[0] * lambda[14]) - alpha[0] * (delta[0] + beta[0] * lambda[0]) + alpha[0] * alpha[0] * (lambda[1] + beta[0] * lambda[2] + beta[0] * lambda[3] + beta[0] * beta[0] * lambda[4]);
670 
671       delta[1] = delta[0] - alpha[0] * (lambda[1] + beta[0] * lambda[2]) - alpha[0] * (lambda[5] + beta[0] * lambda[3]) + alpha[0] * alpha[0] * (lambda[6] + beta[0] * lambda[7] + beta[0] * lambda[8] + beta[0] * beta[0] * lambda[9]);
672 
673       beta[1]  = gamma[1] / gamma[0];
674       alpha[1] = gamma[1] / (delta[1] - beta[1] / alpha[0] * gamma[1]);
675 
676       PetscCall(VecMergedOps_Private(X, R, Z, W, P, Q, C, D, G[0], H[0], G[1], H[1], S, A1, B1, E, F, M, N, U, ksp->normtype, beta[0], alpha[0], beta[1], alpha[1], lambda, alphaold));
677     }
678 
679     gamma[0] = gamma[1];
680     delta[0] = delta[1];
681 
682 #if defined(PETSC_HAVE_MPI_NONBLOCKING_COLLECTIVES)
683     PetscCallMPI(MPI_Iallreduce(MPI_IN_PLACE, lambda, 15, MPIU_SCALAR, MPIU_SUM, pcomm, &req));
684 #else
685     PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, lambda, 15, MPIU_SCALAR, MPIU_SUM, pcomm));
686     req = MPI_REQUEST_NULL;
687 #endif
688 
689     PetscCall(KSP_PCApply(ksp, N, G[0]));          /*  g <- Bn  */
690     PetscCall(KSP_MatMult(ksp, Amat, G[0], H[0])); /*  h <- Ag  */
691 
692     PetscCall(KSP_PCApply(ksp, H[0], E));    /*  e <- Bh  */
693     PetscCall(KSP_MatMult(ksp, Amat, E, F)); /*  f <- Ae */
694 
695     PetscCallMPI(MPI_Wait(&req, &stat));
696 
697     gamma[1] = lambda[10];
698     delta[1] = lambda[11];
699     dp       = PetscSqrtReal(PetscAbsScalar(lambda[12]));
700 
701     alphaold = alpha[1];
702     ksp->its = i;
703 
704     if (i > 0) {
705       if (ksp->normtype == KSP_NORM_NONE) dp = 0.0;
706       ksp->rnorm = dp;
707       PetscCall(KSPLogResidualHistory(ksp, dp));
708       PetscCall(KSPMonitor(ksp, i, dp));
709       PetscCall((*ksp->converged)(ksp, i, dp, &ksp->reason, ksp->cnvP));
710       if (ksp->reason) break;
711     }
712   }
713 
714   if (i >= ksp->max_it) ksp->reason = KSP_DIVERGED_ITS;
715   PetscFunctionReturn(PETSC_SUCCESS);
716 }
717 
718 /*MC
719    KSPPIPECG2 - Pipelined conjugate gradient method with a single non-blocking reduction per two iterations {cite}`tiwari2020pipelined`. [](sec_pipelineksp)
720 
721    Level: intermediate
722 
723    Notes:
724    This method has only a single non-blocking reduction per two iterations, compared to 2 blocking for standard `KSPCG`.  The
725    non-blocking reduction is overlapped by two matrix-vector products and two preconditioner applications.
726 
727    Each pass of the solver performs two CG steps (the solution is updated twice) but computes the residual norm only once.
728    Hence `KSPGetIterationNumber()` advances by two per pass and `KSPGetResidualHistory()` records one norm per pass, unlike `KSPCG`.
729 
730    MPI configuration may be necessary for reductions to make asynchronous progress, which is important for performance of pipelined methods.
731    See [](doc_faq_pipelined)
732 
733    Developer Note:
734    The implementation code contains a good amount of hand-tuned fusion of multiple inner products and similar computations on multiple vectors
735 
736    Contributed by:
737    Manasi Tiwari, Computational and Data Sciences, Indian Institute of Science, Bangalore
738 
739 .seealso: [](ch_ksp), [](doc_faq_pipelined), [](sec_pipelineksp), `KSPCreate()`, `KSPSetType()`, `KSPCG`, `KSPPIPECG`, `KSPGROPPCG`
740 M*/
KSPCreate_PIPECG2(KSP ksp)741 PETSC_EXTERN PetscErrorCode KSPCreate_PIPECG2(KSP ksp)
742 {
743   PetscFunctionBegin;
744   PetscCall(KSPSetSupportedNorm(ksp, KSP_NORM_UNPRECONDITIONED, PC_LEFT, 2));
745   PetscCall(KSPSetSupportedNorm(ksp, KSP_NORM_PRECONDITIONED, PC_LEFT, 2));
746   PetscCall(KSPSetSupportedNorm(ksp, KSP_NORM_NATURAL, PC_LEFT, 2));
747   PetscCall(KSPSetSupportedNorm(ksp, KSP_NORM_NONE, PC_LEFT, 1));
748 
749   ksp->ops->setup          = KSPSetUp_PIPECG2;
750   ksp->ops->solve          = KSPSolve_PIPECG2;
751   ksp->ops->destroy        = KSPDestroyDefault;
752   ksp->ops->view           = NULL;
753   ksp->ops->setfromoptions = NULL;
754   ksp->ops->buildsolution  = KSPBuildSolutionDefault;
755   ksp->ops->buildresidual  = KSPBuildResidualDefault;
756   PetscFunctionReturn(PETSC_SUCCESS);
757 }
758