1 #include "../denseqn.h"
2 #include <petsc/private/cupminterface.hpp>
3 #include <petsc/private/cupmobject.hpp>
4
5 namespace Petsc
6 {
7
8 namespace device
9 {
10
11 namespace cupm
12 {
13
14 namespace impl
15 {
16
17 template <DeviceType T>
18 struct UpperTriangular : CUPMObject<T> {
19 PETSC_CUPMOBJECT_HEADER(T);
20
21 static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
22 static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
23 };
24
25 template <DeviceType T>
SolveInPlace(PetscDeviceContext dctx,PetscBool hermitian_transpose,PetscInt N,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)26 PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
27 {
28 cupmBlasInt_t n;
29 cupmBlasHandle_t handle;
30 auto A_ = cupmScalarPtrCast(A);
31 auto x_ = cupmScalarPtrCast(x);
32
33 PetscFunctionBegin;
34 if (!N) PetscFunctionReturn(PETSC_SUCCESS);
35 PetscCall(PetscCUPMBlasIntCast(N, &n));
36 PetscCall(GetHandlesFrom_(dctx, &handle));
37 PetscCall(PetscLogGpuTimeBegin());
38 PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, A_, lda, x_, stride));
39 PetscCall(PetscLogGpuTimeEnd());
40
41 PetscCall(PetscLogGpuFlops(1.0 * N * N));
42 PetscFunctionReturn(PETSC_SUCCESS);
43 }
44
45 template <DeviceType T>
SolveInPlaceCyclic(PetscDeviceContext dctx,PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)46 PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
47 {
48 PetscInt N = next - oldest;
49 PetscInt oldest_index = oldest % m;
50 PetscInt next_index = next % m;
51 cupmBlasInt_t n_old, n_new;
52 cupmBlasPointerMode_t pointer_mode;
53 cupmBlasHandle_t handle;
54 auto sone = cupmScalarCast(1.0);
55 auto minus_one = cupmScalarCast(-1.0);
56 auto A_ = cupmScalarPtrCast(A);
57 auto x_ = cupmScalarPtrCast(x);
58
59 PetscFunctionBegin;
60 if (!N) PetscFunctionReturn(PETSC_SUCCESS);
61 PetscCall(PetscCUPMBlasIntCast(m - oldest_index, &n_old));
62 PetscCall(PetscCUPMBlasIntCast(next_index, &n_new));
63 PetscCall(GetHandlesFrom_(dctx, &handle));
64 PetscCall(PetscLogGpuTimeBegin());
65 PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
66 PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
67 if (!hermitian_transpose) {
68 if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
69 if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &A_[oldest_index], lda, x_, stride, &sone, &x_[oldest_index], stride));
70 if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
71 } else {
72 if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
73 if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &A_[oldest_index], lda, &x_[oldest_index], stride, &sone, x_, stride));
74 if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
75 }
76 PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
77 PetscCall(PetscLogGpuTimeEnd());
78
79 PetscCall(PetscLogGpuFlops(1.0 * N * N));
80 PetscFunctionReturn(PETSC_SUCCESS);
81 }
82
// Explicit instantiations of UpperTriangular, one per device backend enabled in the PETSc
// configuration (HAVE_CUDA / HAVE_HIP)
#if PetscDefined(HAVE_CUDA)
template struct UpperTriangular<DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
template struct UpperTriangular<DeviceType::HIP>;
#endif
90
91 } // namespace impl
92
93 } // namespace cupm
94
95 } // namespace device
96
97 } // namespace Petsc
98
MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose,PetscInt n,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)99 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
100 {
101 using ::Petsc::device::cupm::impl::UpperTriangular;
102 using ::Petsc::device::cupm::DeviceType;
103 PetscDeviceContext dctx;
104 PetscDeviceType device_type;
105
106 PetscFunctionBegin;
107 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
108 PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
109 switch (device_type) {
110 #if PetscDefined(HAVE_CUDA)
111 case PETSC_DEVICE_CUDA:
112 PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
113 break;
114 #endif
115 #if PetscDefined(HAVE_HIP)
116 case PETSC_DEVICE_HIP:
117 PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
118 break;
119 #endif
120 default:
121 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
122 }
123 PetscFunctionReturn(PETSC_SUCCESS);
124 }
125
MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)126 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
127 {
128 using ::Petsc::device::cupm::impl::UpperTriangular;
129 using ::Petsc::device::cupm::DeviceType;
130 PetscDeviceContext dctx;
131 PetscDeviceType device_type;
132
133 PetscFunctionBegin;
134 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
135 PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
136 switch (device_type) {
137 #if PetscDefined(HAVE_CUDA)
138 case PETSC_DEVICE_CUDA:
139 PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
140 break;
141 #endif
142 #if PetscDefined(HAVE_HIP)
143 case PETSC_DEVICE_HIP:
144 PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
145 break;
146 #endif
147 default:
148 SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
149 }
150 PetscFunctionReturn(PETSC_SUCCESS);
151 }
152