xref: /petsc/src/ksp/ksp/utils/lmvm/blas_cyclic/blas_cyclic.c (revision 017deb10d530c1b6d9744fcd772cd96c5fcd74f2)
1 #include "blas_cyclic.h"
2 #if PetscDefined(HAVE_CXX)
3   #include "cupm/blas_cyclic_cupm.h"
4 #endif
5 #include <petsc/private/vecimpl.h>
6 #include <petsc/private/matimpl.h>
7 #include <petscblaslapack.h>
8 
9 PetscLogEvent AXPBY_Cyc, DMV_Cyc, DSV_Cyc, TRSV_Cyc, GEMV_Cyc, HEMV_Cyc;
10 
// Raise PETSC_ERR_ARG_SIZ unless the first entry of the vector layout's range table past zero
// already equals the global size, i.e. every entry of `a` is stored on the first rank.
#define VecCheckAllEntriesFirstRank(a, arg) \
  PetscCheck((a)->map->N == (a)->map->range[1], PetscObjectComm((PetscObject)(a)), PETSC_ERR_ARG_SIZ, "Vector argument # %d does not have all of its entries on the first rank", (arg))

// Same check for a matrix: both the row layout and the column layout must be entirely on the first rank.
#define MatCheckAllEntriesFirstRank(a, arg) \
  PetscCheck(((a)->rmap->N == (a)->rmap->range[1]) && ((a)->cmap->N == (a)->cmap->range[1]), PetscObjectComm((PetscObject)(a)), PETSC_ERR_ARG_SIZ, "Matrix argument # %d does not have all of its entries on the first rank", (arg))
14 
15 // Takes y_stride argument because this is also used for updating a row of a MatDense
AXPBY_Private(PetscInt m,PetscScalar alpha,const PetscScalar x[],PetscScalar beta,PetscScalar y[],PetscInt y_stride)16 static inline void AXPBY_Private(PetscInt m, PetscScalar alpha, const PetscScalar x[], PetscScalar beta, PetscScalar y[], PetscInt y_stride)
17 {
18   for (PetscInt i = 0; i < m; i++) y[i * y_stride] = alpha * x[i] + beta * y[i * y_stride];
19 }
20 
AXPBYCylic_Private(PetscInt m,PetscInt oldest,PetscInt next,PetscScalar alpha,const PetscScalar x[],PetscScalar beta,PetscScalar y[],PetscInt y_stride)21 static PetscErrorCode AXPBYCylic_Private(PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar x[], PetscScalar beta, PetscScalar y[], PetscInt y_stride)
22 {
23   PetscInt i_oldest = oldest % m;
24   PetscInt i_next   = ((next - 1) % m) + 1;
25 
26   PetscFunctionBegin;
27   if (next - oldest == m) {
28     AXPBY_Private(m, alpha, x, beta, y, y_stride);
29   } else if (i_next > i_oldest) {
30     AXPBY_Private(i_next - i_oldest, alpha, &x[i_oldest], beta, &y[i_oldest * y_stride], y_stride);
31   } else {
32     AXPBY_Private(i_next, alpha, x, beta, y, y_stride);
33     AXPBY_Private(m - i_oldest, alpha, &x[i_oldest], beta, &y[i_oldest * y_stride], y_stride);
34   }
35   PetscFunctionReturn(PETSC_SUCCESS);
36 }
37 
VecAXPBYCyclic(PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Vec y)38 PETSC_INTERN PetscErrorCode VecAXPBYCyclic(PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
39 {
40   const PetscScalar *x_;
41   PetscScalar       *y_;
42   PetscInt           m, m_local;
43   PetscMemType       x_memtype, y_memtype;
44   PetscBool          on_host = PETSC_FALSE;
45 
46   PetscFunctionBegin;
47   PetscValidHeaderSpecific(x, VEC_CLASSID, 4);
48   PetscValidHeaderSpecific(y, VEC_CLASSID, 6);
49   PetscCheckSameComm(x, 4, y, 6);
50   VecCheckSameSize(x, 4, y, 6);
51   VecCheckAllEntriesFirstRank(x, 4);
52   VecCheckAllEntriesFirstRank(y, 6);
53   PetscCall(VecGetSize(x, &m));
54   PetscCall(VecGetLocalSize(x, &m_local));
55   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
56   PetscCall(PetscLogEventBegin(AXPBY_Cyc, NULL, NULL, NULL, NULL));
57   PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
58   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
59   if (PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
60 #if PetscDefined(HAVE_CUPM)
61     if (m_local == m) PetscCall(AXPBYCyclic_CUPM_Private(m, oldest, next, alpha, x_, beta, y_, 1));
62 #else
63     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
64 #endif
65   } else if (m_local == m) on_host = PETSC_TRUE;
66   PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
67   PetscCall(VecRestoreArrayAndMemType(y, &y_));
68   if (on_host) {
69     PetscCall(VecGetArrayRead(x, &x_));
70     PetscCall(VecGetArray(y, &y_));
71     PetscCall(AXPBYCylic_Private(m, oldest, next, alpha, x_, beta, y_, 1));
72     PetscCall(VecRestoreArray(y, &y_));
73     PetscCall(VecRestoreArrayRead(x, &x_));
74   }
75   PetscCall(PetscLogEventEnd(AXPBY_Cyc, NULL, NULL, NULL, NULL));
76   PetscFunctionReturn(PETSC_SUCCESS);
77 }
78 
// Diagonal matrix-vector product on m contiguous entries:
//   y[i] <- alpha * d[i] * x[i] + beta * y[i]
// where d[i] is A[i], or PetscConj(A[i]) when hermitian_transpose is set.
static inline void DMV_Private(PetscBool hermitian_transpose, PetscInt m, PetscScalar alpha, const PetscScalar A[], const PetscScalar x[], PetscScalar beta, PetscScalar y[])
{
  if (hermitian_transpose) {
    for (PetscInt k = 0; k < m; k++) y[k] = alpha * PetscConj(A[k]) * x[k] + beta * y[k];
    return;
  }
  for (PetscInt k = 0; k < m; k++) y[k] = alpha * A[k] * x[k] + beta * y[k];
}
87 
DMVCylic_Private(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,PetscScalar alpha,const PetscScalar A[],const PetscScalar x[],PetscScalar beta,PetscScalar y[])88 static PetscErrorCode DMVCylic_Private(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], const PetscScalar x[], PetscScalar beta, PetscScalar y[])
89 {
90   PetscInt i_oldest = oldest % m;
91   PetscInt i_next   = ((next - 1) % m) + 1;
92 
93   PetscFunctionBegin;
94   if (next - oldest == m) {
95     DMV_Private(hermitian_transpose, m, alpha, A, x, beta, y);
96   } else if (i_next > i_oldest) {
97     DMV_Private(hermitian_transpose, i_next - i_oldest, alpha, &A[i_oldest], &x[i_oldest], beta, &y[i_oldest]);
98   } else {
99     DMV_Private(hermitian_transpose, i_next, alpha, A, x, beta, y);
100     DMV_Private(hermitian_transpose, m - i_oldest, alpha, &A[i_oldest], &x[i_oldest], beta, &y[i_oldest]);
101   }
102   PetscFunctionReturn(PETSC_SUCCESS);
103 }
104 
VecDMVCyclic(PetscBool hermitian_transpose,PetscInt oldest,PetscInt next,PetscScalar alpha,Vec A,Vec x,PetscScalar beta,Vec y)105 PETSC_INTERN PetscErrorCode VecDMVCyclic(PetscBool hermitian_transpose, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec A, Vec x, PetscScalar beta, Vec y)
106 {
107   const PetscScalar *A_;
108   const PetscScalar *x_;
109   PetscScalar       *y_;
110   PetscInt           m, m_local;
111   PetscMemType       A_memtype, x_memtype, y_memtype;
112   PetscBool          on_host = PETSC_FALSE;
113 
114   PetscFunctionBegin;
115   PetscValidHeaderSpecific(A, VEC_CLASSID, 5);
116   PetscValidHeaderSpecific(x, VEC_CLASSID, 6);
117   PetscValidHeaderSpecific(y, VEC_CLASSID, 8);
118   PetscCheckSameComm(A, 5, x, 6);
119   PetscCheckSameComm(A, 5, y, 8);
120   VecCheckSameSize(A, 5, x, 6);
121   VecCheckSameSize(A, 5, y, 8);
122   VecCheckAllEntriesFirstRank(A, 5);
123   VecCheckAllEntriesFirstRank(x, 6);
124   VecCheckAllEntriesFirstRank(y, 8);
125   PetscCall(VecGetSize(A, &m));
126   PetscCall(VecGetLocalSize(A, &m_local));
127   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
128   PetscCall(PetscLogEventBegin(DMV_Cyc, NULL, NULL, NULL, NULL));
129   PetscCall(VecGetArrayReadAndMemType(A, &A_, &A_memtype));
130   PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
131   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
132   if (PetscMemTypeDevice(A_memtype) && PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
133 #if PetscDefined(HAVE_CUPM)
134     if (m_local == m) PetscCall(DMVCyclic_CUPM_Private(hermitian_transpose, m, oldest, next, alpha, A_, x_, beta, y_));
135 #else
136     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
137 #endif
138   } else if (m_local == m) on_host = PETSC_TRUE;
139   PetscCall(VecRestoreArrayAndMemType(y, &y_));
140   PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
141   PetscCall(VecRestoreArrayReadAndMemType(A, &A_));
142   if (on_host) {
143     PetscCall(VecGetArrayRead(A, &A_));
144     PetscCall(VecGetArrayRead(x, &x_));
145     PetscCall(VecGetArray(y, &y_));
146     PetscCall(DMVCylic_Private(hermitian_transpose, m, oldest, next, alpha, A_, x_, beta, y_));
147     PetscCall(VecRestoreArray(y, &y_));
148     PetscCall(VecRestoreArrayRead(x, &x_));
149     PetscCall(VecRestoreArrayRead(A, &A_));
150   }
151   PetscCall(PetscLogEventEnd(DMV_Cyc, NULL, NULL, NULL, NULL));
152   PetscFunctionReturn(PETSC_SUCCESS);
153 }
154 
// Diagonal solve on m contiguous entries: y[i] <- x[i] / d[i], where d[i] is A[i], or
// PetscConj(A[i]) when hermitian_transpose is set.  x may be the same array as y (in-place solve).
static inline void DSV_Private(PetscBool hermitian_transpose, PetscInt m, const PetscScalar A[], const PetscScalar x[], PetscScalar y[])
{
  // For an in-place solve the right-hand side is read through y itself.
  const PetscScalar *rhs = (x != y) ? x : y;

  if (hermitian_transpose) {
    for (PetscInt k = 0; k < m; k++) y[k] = rhs[k] / PetscConj(A[k]);
  } else {
    for (PetscInt k = 0; k < m; k++) y[k] = rhs[k] / A[k];
  }
}
171 
DSVCyclic_Private(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],const PetscScalar x[],PetscScalar y[])172 static PetscErrorCode DSVCyclic_Private(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], const PetscScalar x[], PetscScalar y[])
173 {
174   PetscInt i_oldest = oldest % m;
175   PetscInt i_next   = ((next - 1) % m) + 1;
176 
177   PetscFunctionBegin;
178   if (next - oldest == m) {
179     DSV_Private(hermitian_transpose, m, A, x, y);
180   } else if (i_next > i_oldest) {
181     DSV_Private(hermitian_transpose, i_next - i_oldest, &A[i_oldest], &x[i_oldest], &y[i_oldest]);
182   } else {
183     DSV_Private(hermitian_transpose, i_next, A, x, y);
184     DSV_Private(hermitian_transpose, m - i_oldest, &A[i_oldest], &x[i_oldest], &y[i_oldest]);
185   }
186   PetscFunctionReturn(PETSC_SUCCESS);
187 }
188 
VecDSVCyclic(PetscBool hermitian_transpose,PetscInt oldest,PetscInt next,Vec A,Vec x,Vec y)189 PETSC_INTERN PetscErrorCode VecDSVCyclic(PetscBool hermitian_transpose, PetscInt oldest, PetscInt next, Vec A, Vec x, Vec y)
190 {
191   const PetscScalar *A_;
192   const PetscScalar *x_ = NULL;
193   PetscScalar       *y_;
194   PetscInt           m, m_local;
195   PetscMemType       A_memtype, x_memtype, y_memtype;
196   PetscBool          on_host = PETSC_FALSE;
197 
198   PetscFunctionBegin;
199   PetscValidHeaderSpecific(A, VEC_CLASSID, 4);
200   PetscValidHeaderSpecific(x, VEC_CLASSID, 5);
201   PetscValidHeaderSpecific(y, VEC_CLASSID, 6);
202   PetscCheckSameComm(A, 4, x, 5);
203   PetscCheckSameComm(A, 4, y, 6);
204   VecCheckSameSize(A, 4, x, 5);
205   VecCheckSameSize(A, 4, y, 6);
206   VecCheckAllEntriesFirstRank(A, 4);
207   VecCheckAllEntriesFirstRank(x, 5);
208   VecCheckAllEntriesFirstRank(y, 6);
209   PetscCall(VecGetSize(A, &m));
210   PetscCall(VecGetLocalSize(A, &m_local));
211   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
212   PetscCall(PetscLogEventBegin(DSV_Cyc, NULL, NULL, NULL, NULL));
213   PetscCall(VecGetArrayReadAndMemType(A, &A_, &A_memtype));
214   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
215   if (x == y) {
216     x_        = y_;
217     x_memtype = y_memtype;
218   } else {
219     PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
220   }
221   if (PetscMemTypeDevice(A_memtype) && PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
222 #if PetscDefined(HAVE_CUPM)
223     if (m_local == m) PetscCall(DSVCyclic_CUPM_Private(hermitian_transpose, m, oldest, next, A_, x_, y_));
224 #else
225     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
226 #endif
227   } else if (m_local == m) on_host = PETSC_TRUE;
228   if (x != y) PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
229   PetscCall(VecRestoreArrayAndMemType(y, &y_));
230   PetscCall(VecRestoreArrayReadAndMemType(A, &A_));
231   if (on_host) {
232     PetscCall(VecGetArrayRead(A, &A_));
233     PetscCall(VecGetArray(y, &y_));
234     if (x == y) {
235       x_ = y_;
236     } else {
237       PetscCall(VecGetArrayRead(x, &x_));
238     }
239     PetscCall(DSVCyclic_Private(hermitian_transpose, m, oldest, next, A_, x_, y_));
240     if (x != y) PetscCall(VecRestoreArrayRead(x, &x_));
241     PetscCall(VecRestoreArray(y, &y_));
242     PetscCall(VecRestoreArrayRead(A, &A_));
243   }
244   PetscCall(PetscLogEventEnd(DSV_Cyc, NULL, NULL, NULL, NULL));
245   PetscFunctionReturn(PETSC_SUCCESS);
246 }
247 
TRSVCyclic_Private(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,const PetscScalar x[],PetscScalar y[])248 static PetscErrorCode TRSVCyclic_Private(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar y[])
249 {
250   PetscBLASInt b_one = 1, blda, bm;
251   PetscBLASInt i_oldest, i_next;
252   PetscScalar  minus_one = -1.0, one = 1.0;
253 
254   PetscFunctionBegin;
255   PetscCall(PetscBLASIntCast(lda, &blda));
256   PetscCall(PetscBLASIntCast(m, &bm));
257   PetscCall(PetscBLASIntCast(oldest % m, &i_oldest));
258   PetscCall(PetscBLASIntCast(((next - 1) % m) + 1, &i_next));
259   if (i_next > i_oldest) {
260     PetscBLASInt bn    = i_next - i_oldest;
261     const char  *trans = hermitian_transpose ? "C" : "N";
262 
263     if (x != y) PetscCall(PetscArraycpy(&y[i_oldest], &x[i_oldest], bn));
264     PetscCallBLAS("BLAStrsv", BLAStrsv_("U", trans, "N", &bn, &A[i_oldest * (lda + 1)], &blda, &y[i_oldest], &b_one));
265   } else {
266     PetscBLASInt bn = bm - i_oldest;
267     if (x != y) {
268       PetscCall(PetscArraycpy(y, x, i_next));
269       PetscCall(PetscArraycpy(&y[i_oldest], &x[i_oldest], bn));
270     }
271     if (!hermitian_transpose) {
272       if (i_next > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "N", &i_next, A, &blda, y, &b_one));
273       if (i_next > 0 && bn > 0) PetscCallBLAS("BLASgemv", BLASgemv_("N", &bn, &i_next, &minus_one, &A[i_oldest], &blda, y, &b_one, &one, &y[i_oldest], &b_one));
274       if (bn > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "N", &bn, &A[i_oldest * (lda + 1)], &blda, &y[i_oldest], &b_one));
275     } else {
276       if (bn > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "N", &bn, &A[i_oldest * (lda + 1)], &blda, &y[i_oldest], &b_one));
277       if (i_next > 0 && bn > 0) PetscCallBLAS("BLASgemv", BLASgemv_("C", &bn, &i_next, &minus_one, &A[i_oldest], &blda, &y[i_oldest], &b_one, &one, y, &b_one));
278       if (i_next > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "N", &i_next, A, &blda, y, &b_one));
279     }
280   }
281   PetscFunctionReturn(PETSC_SUCCESS);
282 }
283 
MatSeqDenseTRSVCyclic(PetscBool hermitian_transpose,PetscInt oldest,PetscInt next,Mat A,Vec x,Vec y)284 PETSC_INTERN PetscErrorCode MatSeqDenseTRSVCyclic(PetscBool hermitian_transpose, PetscInt oldest, PetscInt next, Mat A, Vec x, Vec y)
285 {
286   const PetscScalar *A_;
287   const PetscScalar *x_ = NULL;
288   PetscScalar       *y_;
289   PetscInt           m, m_local, lda;
290   PetscMemType       A_memtype, x_memtype, y_memtype;
291   PetscBool          on_host = PETSC_FALSE;
292 
293   PetscFunctionBegin;
294   PetscValidHeaderSpecific(A, MAT_CLASSID, 4);
295   PetscValidHeaderSpecific(x, VEC_CLASSID, 5);
296   PetscValidHeaderSpecific(y, VEC_CLASSID, 6);
297   PetscCheckSameComm(A, 4, x, 5);
298   PetscCheckSameComm(A, 4, y, 6);
299   VecCheckMatCompatible(A, x, 5, y, 6);
300   MatCheckAllEntriesFirstRank(A, 4);
301   VecCheckAllEntriesFirstRank(x, 5);
302   VecCheckAllEntriesFirstRank(y, 6);
303   PetscCall(VecGetSize(x, &m));
304   PetscCall(VecGetLocalSize(x, &m_local));
305   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
306   PetscCall(PetscLogEventBegin(TRSV_Cyc, NULL, NULL, NULL, NULL));
307   PetscCall(MatDenseGetLDA(A, &lda));
308   PetscCall(MatDenseGetArrayReadAndMemType(A, &A_, &A_memtype));
309   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
310   if (x == y) {
311     x_        = y_;
312     x_memtype = y_memtype;
313   } else {
314     PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
315   }
316   if (PetscMemTypeDevice(A_memtype) && PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
317 #if PetscDefined(HAVE_CUPM)
318     if (m_local == m) PetscCall(TRSVCyclic_CUPM_Private(hermitian_transpose, m, oldest, next, A_, lda, x_, y_));
319 #else
320     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
321 #endif
322   } else if (m_local == m) on_host = PETSC_TRUE;
323   if (x != y) PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
324   PetscCall(VecRestoreArrayAndMemType(y, &y_));
325   PetscCall(MatDenseRestoreArrayReadAndMemType(A, &A_));
326   if (on_host) {
327     PetscCall(MatDenseGetArrayRead(A, &A_));
328     PetscCall(VecGetArray(y, &y_));
329     if (x == y) {
330       x_ = y_;
331     } else {
332       PetscCall(VecGetArrayRead(x, &x_));
333     }
334     PetscCall(TRSVCyclic_Private(hermitian_transpose, m, oldest, next, A_, lda, x_, y_));
335     if (x != y) PetscCall(VecRestoreArrayRead(x, &x_));
336     PetscCall(VecRestoreArray(y, &y_));
337     PetscCall(MatDenseRestoreArrayRead(A, &A_));
338   }
339   PetscCall(PetscLogEventEnd(TRSV_Cyc, NULL, NULL, NULL, NULL));
340   PetscFunctionReturn(PETSC_SUCCESS);
341 }
342 
HEMVCyclic_Private(PetscInt m,PetscInt oldest,PetscInt next,PetscScalar alpha,const PetscScalar A[],PetscInt lda,const PetscScalar x[],PetscScalar beta,PetscScalar y[])343 static PetscErrorCode HEMVCyclic_Private(PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar beta, PetscScalar y[])
344 {
345   PetscBLASInt b_one = 1, blda, bm;
346   PetscBLASInt i_oldest, i_next;
347   PetscScalar  one = 1.0;
348 
349   PetscFunctionBegin;
350   PetscCall(PetscBLASIntCast(lda, &blda));
351   PetscCall(PetscBLASIntCast(m, &bm));
352   PetscCall(PetscBLASIntCast(oldest % m, &i_oldest));
353   PetscCall(PetscBLASIntCast(((next - 1) % m) + 1, &i_next));
354   if (i_next > i_oldest) {
355     PetscBLASInt bn = i_next - i_oldest;
356 
357     PetscCallBLAS("BLAShemv", BLAShemv_("U", &bn, &alpha, &A[i_oldest * (lda + 1)], &blda, &x[i_oldest], &b_one, &beta, &y[i_oldest], &b_one));
358   } else {
359     PetscBLASInt bn = bm - i_oldest;
360     if (i_next > 0) PetscCallBLAS("BLAShemv", BLAShemv_("U", &i_next, &alpha, A, &blda, x, &b_one, &beta, y, &b_one));
361     if (bn > 0) PetscCallBLAS("BLAShemv", BLAShemv_("U", &bn, &alpha, &A[i_oldest * (lda + 1)], &blda, &x[i_oldest], &b_one, &beta, &y[i_oldest], &b_one));
362     if (i_next > 0 && bn > 0) {
363       PetscCallBLAS("BLASgemv", BLASgemv_("N", &bn, &i_next, &alpha, &A[i_oldest], &blda, x, &b_one, &one, &y[i_oldest], &b_one));
364       PetscCallBLAS("BLASgemv", BLASgemv_("C", &bn, &i_next, &alpha, &A[i_oldest], &blda, &x[i_oldest], &b_one, &one, y, &b_one));
365     }
366   }
367   PetscFunctionReturn(PETSC_SUCCESS);
368 }
369 
MatSeqDenseHEMVCyclic(PetscInt oldest,PetscInt next,PetscScalar alpha,Mat A,Vec x,PetscScalar beta,Vec y)370 PETSC_INTERN PetscErrorCode MatSeqDenseHEMVCyclic(PetscInt oldest, PetscInt next, PetscScalar alpha, Mat A, Vec x, PetscScalar beta, Vec y)
371 {
372   const PetscScalar *A_;
373   const PetscScalar *x_ = NULL;
374   PetscScalar       *y_;
375   PetscInt           m, m_local, lda;
376   PetscMemType       A_memtype, x_memtype, y_memtype;
377   PetscBool          on_host = PETSC_FALSE;
378 
379   PetscFunctionBegin;
380   PetscValidHeaderSpecific(A, MAT_CLASSID, 4);
381   PetscValidHeaderSpecific(x, VEC_CLASSID, 5);
382   PetscValidHeaderSpecific(y, VEC_CLASSID, 7);
383   PetscCheckSameComm(A, 4, x, 5);
384   PetscCheckSameComm(A, 4, y, 7);
385   VecCheckMatCompatible(A, x, 5, y, 7);
386   MatCheckAllEntriesFirstRank(A, 4);
387   VecCheckAllEntriesFirstRank(x, 5);
388   VecCheckAllEntriesFirstRank(y, 7);
389   PetscCall(VecGetSize(x, &m));
390   PetscCall(VecGetLocalSize(x, &m_local));
391   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
392   PetscCall(PetscLogEventBegin(HEMV_Cyc, NULL, NULL, NULL, NULL));
393   PetscCall(MatDenseGetLDA(A, &lda));
394   PetscCall(MatDenseGetArrayReadAndMemType(A, &A_, &A_memtype));
395   PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
396   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
397   if (PetscMemTypeDevice(A_memtype) && PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
398 #if PetscDefined(HAVE_CUPM)
399     if (m_local == m) PetscCall(HEMVCyclic_CUPM_Private(m, oldest, next, alpha, A_, lda, x_, beta, y_));
400 #else
401     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
402 #endif
403   } else if (m_local == m) on_host = PETSC_TRUE;
404   PetscCall(VecRestoreArrayAndMemType(y, &y_));
405   PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
406   PetscCall(MatDenseRestoreArrayReadAndMemType(A, &A_));
407   if (on_host) {
408     PetscCall(MatDenseGetArrayRead(A, &A_));
409     PetscCall(VecGetArrayRead(x, &x_));
410     PetscCall(VecGetArray(y, &y_));
411     PetscCall(HEMVCyclic_Private(m, oldest, next, alpha, A_, lda, x_, beta, y_));
412     PetscCall(VecRestoreArray(y, &y_));
413     PetscCall(VecRestoreArrayRead(x, &x_));
414     PetscCall(MatDenseRestoreArrayRead(A, &A_));
415   }
416   PetscCall(PetscLogEventEnd(HEMV_Cyc, NULL, NULL, NULL, NULL));
417   PetscFunctionReturn(PETSC_SUCCESS);
418 }
419 
GEMVCyclic_Private(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,PetscScalar alpha,const PetscScalar A[],PetscInt lda,const PetscScalar x[],PetscScalar beta,PetscScalar y[])420 static PetscErrorCode GEMVCyclic_Private(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar beta, PetscScalar y[])
421 {
422   PetscBLASInt b_one = 1, blda, bm;
423   PetscBLASInt i_oldest, i_next;
424   PetscScalar  one   = 1.0;
425   const char  *trans = hermitian_transpose ? "C" : "N";
426 
427   PetscFunctionBegin;
428   PetscCall(PetscBLASIntCast(lda, &blda));
429   PetscCall(PetscBLASIntCast(m, &bm));
430   PetscCall(PetscBLASIntCast(oldest % m, &i_oldest));
431   PetscCall(PetscBLASIntCast(((next - 1) % m) + 1, &i_next));
432   if (next - oldest == m) {
433     PetscCallBLAS("BLASgemv", BLASgemv_(trans, &bm, &bm, &alpha, A, &blda, x, &b_one, &beta, y, &b_one));
434   } else if (i_next > i_oldest) {
435     PetscBLASInt bn = i_next - i_oldest;
436 
437     PetscCallBLAS("BLASgemv", BLASgemv_(trans, &bn, &bn, &alpha, &A[i_oldest * (lda + 1)], &blda, &x[i_oldest], &b_one, &beta, &y[i_oldest], &b_one));
438   } else {
439     PetscBLASInt bn = bm - i_oldest;
440     if (i_next > 0) PetscCallBLAS("BLASgemv", BLASgemv_(trans, &i_next, &i_next, &alpha, A, &blda, x, &b_one, &beta, y, &b_one));
441     if (bn > 0) PetscCallBLAS("BLASgemv", BLASgemv_(trans, &bn, &bn, &alpha, &A[i_oldest * (lda + 1)], &blda, &x[i_oldest], &b_one, &beta, &y[i_oldest], &b_one));
442     if (i_next > 0 && bn > 0) {
443       if (!hermitian_transpose) {
444         PetscCallBLAS("BLASgemv", BLASgemv_("N", &bn, &i_next, &alpha, &A[i_oldest], &blda, x, &b_one, &one, &y[i_oldest], &b_one));
445         PetscCallBLAS("BLASgemv", BLASgemv_("N", &i_next, &bn, &alpha, &A[i_oldest * lda], &blda, &x[i_oldest], &b_one, &one, y, &b_one));
446       } else {
447         PetscCallBLAS("BLASgemv", BLASgemv_("C", &i_next, &bn, &alpha, &A[i_oldest * lda], &blda, x, &b_one, &one, &y[i_oldest], &b_one));
448         PetscCallBLAS("BLASgemv", BLASgemv_("C", &bn, &i_next, &alpha, &A[i_oldest], &blda, &x[i_oldest], &b_one, &one, y, &b_one));
449       }
450     }
451   }
452   PetscFunctionReturn(PETSC_SUCCESS);
453 }
454 
MatSeqDenseGEMVCyclic(PetscBool hermitian_transpose,PetscInt oldest,PetscInt next,PetscScalar alpha,Mat A,Vec x,PetscScalar beta,Vec y)455 PETSC_INTERN PetscErrorCode MatSeqDenseGEMVCyclic(PetscBool hermitian_transpose, PetscInt oldest, PetscInt next, PetscScalar alpha, Mat A, Vec x, PetscScalar beta, Vec y)
456 {
457   const PetscScalar *A_;
458   const PetscScalar *x_ = NULL;
459   PetscScalar       *y_;
460   PetscInt           m, m_local, lda;
461   PetscMemType       A_memtype, x_memtype, y_memtype;
462   PetscBool          on_host = PETSC_FALSE;
463 
464   PetscFunctionBegin;
465   PetscValidHeaderSpecific(A, MAT_CLASSID, 5);
466   PetscValidHeaderSpecific(x, VEC_CLASSID, 6);
467   PetscValidHeaderSpecific(y, VEC_CLASSID, 8);
468   PetscCheckSameComm(A, 5, x, 6);
469   PetscCheckSameComm(A, 5, y, 8);
470   VecCheckMatCompatible(A, x, 6, y, 8);
471   MatCheckAllEntriesFirstRank(A, 5);
472   VecCheckAllEntriesFirstRank(x, 6);
473   VecCheckAllEntriesFirstRank(y, 8);
474   PetscCall(VecGetSize(x, &m));
475   PetscCall(VecGetLocalSize(x, &m_local));
476   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
477   PetscCall(PetscLogEventBegin(GEMV_Cyc, NULL, NULL, NULL, NULL));
478   PetscCall(MatDenseGetLDA(A, &lda));
479   PetscCall(MatDenseGetArrayReadAndMemType(A, &A_, &A_memtype));
480   PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
481   PetscCall(VecGetArrayAndMemType(y, &y_, &y_memtype));
482   if (PetscMemTypeDevice(A_memtype) && PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
483 #if PetscDefined(HAVE_CUPM)
484     if (m_local == m) PetscCall(GEMVCyclic_CUPM_Private(hermitian_transpose, m, oldest, next, alpha, A_, lda, x_, beta, y_));
485 #else
486     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
487 #endif
488   } else if (m_local == m) on_host = PETSC_TRUE;
489   PetscCall(VecRestoreArrayAndMemType(y, &y_));
490   PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
491   PetscCall(MatDenseRestoreArrayReadAndMemType(A, &A_));
492   if (on_host) {
493     PetscCall(MatDenseGetArrayRead(A, &A_));
494     PetscCall(VecGetArrayRead(x, &x_));
495     PetscCall(VecGetArray(y, &y_));
496     PetscCall(GEMVCyclic_Private(hermitian_transpose, m, oldest, next, alpha, A_, lda, x_, beta, y_));
497     PetscCall(VecRestoreArray(y, &y_));
498     PetscCall(VecRestoreArrayRead(x, &x_));
499     PetscCall(MatDenseRestoreArrayRead(A, &A_));
500   }
501   PetscCall(PetscLogEventEnd(GEMV_Cyc, NULL, NULL, NULL, NULL));
502   PetscFunctionReturn(PETSC_SUCCESS);
503 }
504 
MatSeqDenseRowAXPBYCyclic(PetscInt oldest,PetscInt next,PetscScalar alpha,Vec x,PetscScalar beta,Mat Y,PetscInt row)505 PETSC_INTERN PetscErrorCode MatSeqDenseRowAXPBYCyclic(PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Mat Y, PetscInt row)
506 {
507   const PetscScalar *x_ = NULL;
508   PetscScalar       *y_;
509   PetscInt           m, m_local, ldy;
510   PetscMemType       x_memtype, y_memtype;
511   PetscBool          on_host = PETSC_FALSE;
512 
513   PetscFunctionBegin;
514   PetscValidHeaderSpecific(x, VEC_CLASSID, 4);
515   PetscValidHeaderSpecific(Y, MAT_CLASSID, 6);
516   PetscCheckSameComm(x, 4, Y, 6);
517   VecCheckMatCompatible(Y, x, 4, x, 4);
518   VecCheckAllEntriesFirstRank(x, 4);
519   MatCheckAllEntriesFirstRank(Y, 6);
520   PetscCall(VecGetSize(x, &m));
521   PetscCall(VecGetLocalSize(x, &m_local));
522   if (!m) PetscFunctionReturn(PETSC_SUCCESS);
523   PetscCall(PetscLogEventBegin(AXPBY_Cyc, NULL, NULL, NULL, NULL));
524   PetscCall(MatDenseGetLDA(Y, &ldy));
525   PetscCall(VecGetArrayReadAndMemType(x, &x_, &x_memtype));
526   PetscCall(MatDenseGetArrayAndMemType(Y, &y_, &y_memtype));
527   if (PetscMemTypeDevice(x_memtype) && PetscMemTypeDevice(y_memtype)) {
528 #if PetscDefined(HAVE_CUPM)
529     if (m_local == m) PetscCall(AXPBYCyclic_CUPM_Private(m, oldest, next, alpha, x_, beta, &y_[row % m], ldy));
530 #else
531     SETERRQ(PetscObjectComm((PetscObject)x), PETSC_ERR_PLIB, "PetscMemTypeDevice needs either CUDA or HIP support");
532 #endif
533   } else if (m_local == m) on_host = PETSC_TRUE;
534   PetscCall(MatDenseRestoreArrayAndMemType(Y, &y_));
535   PetscCall(VecRestoreArrayReadAndMemType(x, &x_));
536   if (on_host) {
537     PetscCall(VecGetArrayRead(x, &x_));
538     PetscCall(MatDenseGetArray(Y, &y_));
539     PetscCall(AXPBYCylic_Private(m, oldest, next, alpha, x_, beta, &y_[row % m], ldy));
540     PetscCall(MatDenseRestoreArray(Y, &y_));
541     PetscCall(VecRestoreArrayRead(x, &x_));
542   }
543   PetscCall(PetscLogEventEnd(AXPBY_Cyc, NULL, NULL, NULL, NULL));
544   PetscFunctionReturn(PETSC_SUCCESS);
545 }
546 
MatMultColumnRange(Mat A,Vec xx,Vec yy,PetscInt c_start,PetscInt c_end)547 PETSC_INTERN PetscErrorCode MatMultColumnRange(Mat A, Vec xx, Vec yy, PetscInt c_start, PetscInt c_end)
548 {
549   PetscFunctionBegin;
550   PetscCall(PetscLogEventBegin(MAT_Mult, A, NULL, NULL, NULL));
551   PetscUseMethod(A, "MatMultColumnRange_C", (Mat, Vec, Vec, PetscInt, PetscInt), (A, xx, yy, c_start, c_end));
552   PetscCall(PetscLogEventEnd(MAT_Mult, A, NULL, NULL, NULL));
553   PetscFunctionReturn(PETSC_SUCCESS);
554 }
555 
MatMultAddColumnRange(Mat A,Vec xx,Vec zz,Vec yy,PetscInt c_start,PetscInt c_end)556 PETSC_INTERN PetscErrorCode MatMultAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
557 {
558   PetscFunctionBegin;
559   PetscCall(PetscLogEventBegin(MAT_MultAdd, A, NULL, NULL, NULL));
560   PetscUseMethod(A, "MatMultAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
561   PetscCall(PetscLogEventEnd(MAT_MultAdd, A, NULL, NULL, NULL));
562   PetscFunctionReturn(PETSC_SUCCESS);
563 }
564 
MatMultHermitianTransposeColumnRange(Mat A,Vec xx,Vec yy,PetscInt c_start,PetscInt c_end)565 PETSC_INTERN PetscErrorCode MatMultHermitianTransposeColumnRange(Mat A, Vec xx, Vec yy, PetscInt c_start, PetscInt c_end)
566 {
567   PetscFunctionBegin;
568   PetscCall(PetscLogEventBegin(MAT_MultHermitianTranspose, A, NULL, NULL, NULL));
569   PetscUseMethod(A, "MatMultHermitianTransposeColumnRange_C", (Mat, Vec, Vec, PetscInt, PetscInt), (A, xx, yy, c_start, c_end));
570   PetscCall(PetscLogEventEnd(MAT_MultHermitianTranspose, A, NULL, NULL, NULL));
571   PetscFunctionReturn(PETSC_SUCCESS);
572 }
573 
MatMultHermitianTransposeAddColumnRange(Mat A,Vec xx,Vec zz,Vec yy,PetscInt c_start,PetscInt c_end)574 PETSC_INTERN PetscErrorCode MatMultHermitianTransposeAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
575 {
576   PetscFunctionBegin;
577   PetscCall(PetscLogEventBegin(MAT_MultHermitianTransposeAdd, A, NULL, NULL, NULL));
578   PetscUseMethod(A, "MatMultHermitianTransposeAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
579   PetscCall(PetscLogEventEnd(MAT_MultHermitianTransposeAdd, A, NULL, NULL, NULL));
580   PetscFunctionReturn(PETSC_SUCCESS);
581 }
582