xref: /petsc/src/ksp/ksp/utils/lmvm/dense/cd_cupm/cd_cupm.cxx (revision 58bddbc0aeb8e2276be3739270a4176cb222ba3a)
1 #include "../denseqn.h"
2 #include <petsc/private/cupminterface.hpp>
3 #include <petsc/private/cupmobject.hpp>
4 
5 namespace Petsc
6 {
7 
8 namespace device
9 {
10 
11 namespace cupm
12 {
13 
14 namespace impl
15 {
16 
17 template <DeviceType T>
18 struct UpperTriangular : CUPMObject<T> {
19   PETSC_CUPMOBJECT_HEADER(T);
20 
21   static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
22   static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
23 };
24 
25 template <DeviceType T>
SolveInPlace(PetscDeviceContext dctx,PetscBool hermitian_transpose,PetscInt N,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)26 PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
27 {
28   cupmBlasInt_t    n;
29   cupmBlasHandle_t handle;
30   auto             A_ = cupmScalarPtrCast(A);
31   auto             x_ = cupmScalarPtrCast(x);
32 
33   PetscFunctionBegin;
34   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
35   PetscCall(PetscCUPMBlasIntCast(N, &n));
36   PetscCall(GetHandlesFrom_(dctx, &handle));
37   PetscCall(PetscLogGpuTimeBegin());
38   PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, A_, lda, x_, stride));
39   PetscCall(PetscLogGpuTimeEnd());
40 
41   PetscCall(PetscLogGpuFlops(1.0 * N * N));
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 template <DeviceType T>
SolveInPlaceCyclic(PetscDeviceContext dctx,PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)46 PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
47 {
48   PetscInt              N            = next - oldest;
49   PetscInt              oldest_index = oldest % m;
50   PetscInt              next_index   = next % m;
51   cupmBlasInt_t         n_old, n_new;
52   cupmBlasPointerMode_t pointer_mode;
53   cupmBlasHandle_t      handle;
54   auto                  sone      = cupmScalarCast(1.0);
55   auto                  minus_one = cupmScalarCast(-1.0);
56   auto                  A_        = cupmScalarPtrCast(A);
57   auto                  x_        = cupmScalarPtrCast(x);
58 
59   PetscFunctionBegin;
60   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
61   PetscCall(PetscCUPMBlasIntCast(m - oldest_index, &n_old));
62   PetscCall(PetscCUPMBlasIntCast(next_index, &n_new));
63   PetscCall(GetHandlesFrom_(dctx, &handle));
64   PetscCall(PetscLogGpuTimeBegin());
65   PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
66   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
67   if (!hermitian_transpose) {
68     if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
69     if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &A_[oldest_index], lda, x_, stride, &sone, &x_[oldest_index], stride));
70     if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
71   } else {
72     if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
73     if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &A_[oldest_index], lda, &x_[oldest_index], stride, &sone, x_, stride));
74     if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
75   }
76   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
77   PetscCall(PetscLogGpuTimeEnd());
78 
79   PetscCall(PetscLogGpuFlops(1.0 * N * N));
80   PetscFunctionReturn(PETSC_SUCCESS);
81 }
82 
83 #if PetscDefined(HAVE_CUDA)
84 template struct UpperTriangular<DeviceType::CUDA>;
85 #endif
86 
87 #if PetscDefined(HAVE_HIP)
88 template struct UpperTriangular<DeviceType::HIP>;
89 #endif
90 
91 } // namespace impl
92 
93 } // namespace cupm
94 
95 } // namespace device
96 
97 } // namespace Petsc
98 
MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose,PetscInt n,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)99 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
100 {
101   using ::Petsc::device::cupm::impl::UpperTriangular;
102   using ::Petsc::device::cupm::DeviceType;
103   PetscDeviceContext dctx;
104   PetscDeviceType    device_type;
105 
106   PetscFunctionBegin;
107   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
108   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
109   switch (device_type) {
110 #if PetscDefined(HAVE_CUDA)
111   case PETSC_DEVICE_CUDA:
112     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
113     break;
114 #endif
115 #if PetscDefined(HAVE_HIP)
116   case PETSC_DEVICE_HIP:
117     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
118     break;
119 #endif
120   default:
121     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
122   }
123   PetscFunctionReturn(PETSC_SUCCESS);
124 }
125 
MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)126 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
127 {
128   using ::Petsc::device::cupm::impl::UpperTriangular;
129   using ::Petsc::device::cupm::DeviceType;
130   PetscDeviceContext dctx;
131   PetscDeviceType    device_type;
132 
133   PetscFunctionBegin;
134   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
135   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
136   switch (device_type) {
137 #if PetscDefined(HAVE_CUDA)
138   case PETSC_DEVICE_CUDA:
139     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
140     break;
141 #endif
142 #if PetscDefined(HAVE_HIP)
143   case PETSC_DEVICE_HIP:
144     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
145     break;
146 #endif
147   default:
148     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
149   }
150   PetscFunctionReturn(PETSC_SUCCESS);
151 }
152