1 #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h> /*I "petscksp.h" I*/
2 #include <petscblaslapack.h>
3 #include <petscmat.h>
4 #include <petscsys.h>
5 #include <petscsystypes.h>
6 #include <petscis.h>
7 #include <petscoptions.h>
8 #include <petscdevice.h>
9 #include <petsc/private/deviceimpl.h>
10
/*
  String table for the MatLMVMDenseType enum, laid out in the form used by PETSc enum
  option parsing: the value names, then the enum type name, then the common prefix of
  the enum constants, terminated by NULL.
*/
const char *const MatLMVMDenseTypes[] = {
  "reorder",
  "inplace",
  "MatLMVMDenseType",
  "MAT_LMVM_DENSE_",
  NULL,
};
12
VecCyclicShift(Mat B,Vec X,PetscInt d,Vec cyclic_work_vec)13 PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
14 {
15 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
16 PetscInt m = lmvm->m;
17 PetscInt n;
18 const PetscScalar *src;
19 PetscScalar *dest;
20 PetscMemType src_memtype;
21 PetscMemType dest_memtype;
22
23 PetscFunctionBegin;
24 PetscCall(VecGetLocalSize(X, &n));
25 if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
26 PetscCall(VecCopy(X, cyclic_work_vec));
27 PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
28 PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
29 if (n == 0) { /* no work on this process */
30 PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
31 PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
32 PetscFunctionReturn(PETSC_SUCCESS);
33 }
34 PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
35 if (PetscMemTypeHost(src_memtype)) {
36 PetscCall(PetscArraycpy(dest, &src[d], m - d));
37 PetscCall(PetscArraycpy(&dest[m - d], src, d));
38 } else {
39 PetscDeviceContext dctx;
40
41 PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
42 PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
43 PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
44 PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
45 PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
46 }
47 PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
48 PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
49 PetscFunctionReturn(PETSC_SUCCESS);
50 }
51
/* Reduces the update counter idx modulo m; m must be nonzero */
static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
{
  const PetscInt slot = idx % m;

  return slot;
}
56
/* Returns idx - m clamped below at zero */
static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
{
  const PetscInt first = idx - m;

  return (0 < first) ? first : 0;
}
61
VecRecycleOrderToHistoryOrder(Mat B,Vec X,PetscInt num_updates,Vec cyclic_work_vec)62 PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
63 {
64 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
65 PetscInt m = lmvm->m;
66 PetscInt oldest_index;
67
68 PetscFunctionBegin;
69 oldest_index = recycle_index(m, oldest_update(m, num_updates));
70 if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
71 PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
72 PetscFunctionReturn(PETSC_SUCCESS);
73 }
74
VecHistoryOrderToRecycleOrder(Mat B,Vec X,PetscInt num_updates,Vec cyclic_work_vec)75 PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
76 {
77 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
78 PetscInt m = lmvm->m;
79 PetscInt oldest_index;
80
81 PetscFunctionBegin;
82 oldest_index = recycle_index(m, oldest_update(m, num_updates));
83 if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
84 PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
85 PetscFunctionReturn(PETSC_SUCCESS);
86 }
87
MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type,PetscMemType memtype,PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)88 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
89 {
90 PetscInt oldest_index = oldest % m;
91 PetscInt next_index = (next - 1) % m + 1;
92 PetscInt N = next - oldest;
93
94 PetscFunctionBegin;
95 /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
96 if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
97 switch (type) {
98 case MAT_LMVM_DENSE_REORDER:
99 if (PetscMemTypeHost(memtype)) {
100 PetscBLASInt n, lda_blas, one = 1;
101 PetscCall(PetscBLASIntCast(N, &n));
102 PetscCall(PetscBLASIntCast(lda, &lda_blas));
103 PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
104 PetscCall(PetscLogFlops(1.0 * n * n));
105 #if defined(PETSC_HAVE_CUPM)
106 } else if (PetscMemTypeDevice(memtype)) {
107 PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
108 #endif
109 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
110 break;
111 case MAT_LMVM_DENSE_INPLACE:
112 if (PetscMemTypeHost(memtype)) {
113 PetscBLASInt n_old, n_new, lda_blas, one = 1;
114 PetscScalar minus_one = -1.0;
115 PetscScalar sone = 1.0;
116 PetscCall(PetscBLASIntCast(m - oldest_index, &n_old));
117 PetscCall(PetscBLASIntCast(next_index, &n_new));
118 PetscCall(PetscBLASIntCast(lda, &lda_blas));
119 if (!hermitian_transpose) {
120 if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
121 if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
122 if (n_old > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
123 } else {
124 if (n_old > 0) {
125 PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
126 if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
127 }
128 if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
129 }
130 PetscCall(PetscLogFlops(1.0 * N * N));
131 #if defined(PETSC_HAVE_CUPM)
132 } else if (PetscMemTypeDevice(memtype)) {
133 PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, m, oldest, next, A, lda, x, stride));
134 #endif
135 } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
136 break;
137 default:
138 PetscUnreachable();
139 }
140 PetscFunctionReturn(PETSC_SUCCESS);
141 }
142
MatUpperTriangularSolveInPlace(Mat B,Mat Amat,Vec X,PetscBool hermitian_transpose,PetscInt num_updates,MatLMVMDenseType strategy)143 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
144 {
145 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
146 PetscInt m = lmvm->m;
147 PetscInt h, local_n;
148 PetscInt lda;
149 PetscScalar *x;
150 PetscMemType memtype_r, memtype_x;
151 const PetscScalar *A;
152
153 PetscFunctionBegin;
154 h = num_updates - oldest_update(m, num_updates);
155 if (!h) PetscFunctionReturn(PETSC_SUCCESS);
156 PetscCall(VecGetLocalSize(X, &local_n));
157 PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
158 PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
159 if (!local_n) {
160 PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
161 PetscCall(VecRestoreArrayAndMemType(X, &x));
162 PetscFunctionReturn(PETSC_SUCCESS);
163 }
164 PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
165 PetscCall(MatDenseGetLDA(Amat, &lda));
166 PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, m, oldest_update(m, num_updates), num_updates, A, lda, x, 1));
167 PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
168 PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
169 PetscFunctionReturn(PETSC_SUCCESS);
170 }
171
172 /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */
173
MatMove_LR3(Mat B,Mat R,PetscInt m_keep)174 PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
175 {
176 Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
177 Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
178 PetscInt M;
179 Mat mat_local, local_sub, local_temp, temp_sub;
180
181 PetscFunctionBegin;
182 if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
183 PetscCall(MatGetLocalSize(R, &M, NULL));
184 if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);
185
186 PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
187 PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
188 PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
189 PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
190 PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
191 PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
192 PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
193 PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
194 PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
195 PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
196 PetscFunctionReturn(PETSC_SUCCESS);
197 }
198