xref: /petsc/src/ksp/ksp/utils/lmvm/dense/cd_utils.c (revision 58bddbc0aeb8e2276be3739270a4176cb222ba3a)
1 #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h> /*I "petscksp.h" I*/
2 #include <petscblaslapack.h>
3 #include <petscmat.h>
4 #include <petscsys.h>
5 #include <petscsystypes.h>
6 #include <petscis.h>
7 #include <petscoptions.h>
8 #include <petscdevice.h>
9 #include <petsc/private/deviceimpl.h>
10 
/* Name table for MatLMVMDenseType: the enum value names in enum order, followed by the enum type
   name, the enumerator prefix, and a NULL terminator (the usual layout PETSc expects for enum
   name arrays passed to the options routines). */
const char *const MatLMVMDenseTypes[] = {
  "reorder",          /* MAT_LMVM_DENSE_REORDER */
  "inplace",          /* MAT_LMVM_DENSE_INPLACE */
  "MatLMVMDenseType", /* enum type name */
  "MAT_LMVM_DENSE_",  /* enumerator prefix */
  NULL,
};
12 
VecCyclicShift(Mat B,Vec X,PetscInt d,Vec cyclic_work_vec)13 PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
14 {
15   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
16   PetscInt           m    = lmvm->m;
17   PetscInt           n;
18   const PetscScalar *src;
19   PetscScalar       *dest;
20   PetscMemType       src_memtype;
21   PetscMemType       dest_memtype;
22 
23   PetscFunctionBegin;
24   PetscCall(VecGetLocalSize(X, &n));
25   if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
26   PetscCall(VecCopy(X, cyclic_work_vec));
27   PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
28   PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
29   if (n == 0) { /* no work on this process */
30     PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
31     PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
32     PetscFunctionReturn(PETSC_SUCCESS);
33   }
34   PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
35   if (PetscMemTypeHost(src_memtype)) {
36     PetscCall(PetscArraycpy(dest, &src[d], m - d));
37     PetscCall(PetscArraycpy(&dest[m - d], src, d));
38   } else {
39     PetscDeviceContext dctx;
40 
41     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
42     PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
43     PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
44     PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
45     PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
46   }
47   PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
48   PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
49   PetscFunctionReturn(PETSC_SUCCESS);
50 }
51 
/* Slot of the length-m cyclic (recycle-order) buffer that holds update number idx. */
static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
{
  const PetscInt slot = idx % m;

  return slot;
}
56 
/* Number of the oldest update still retained after idx updates with a history of size m:
   idx - m once more than m updates have been made, and 0 before that. */
static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
{
  const PetscInt excess = idx - m;

  if (excess > 0) return excess;
  return 0;
}
61 
VecRecycleOrderToHistoryOrder(Mat B,Vec X,PetscInt num_updates,Vec cyclic_work_vec)62 PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
63 {
64   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
65   PetscInt  m    = lmvm->m;
66   PetscInt  oldest_index;
67 
68   PetscFunctionBegin;
69   oldest_index = recycle_index(m, oldest_update(m, num_updates));
70   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
71   PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
72   PetscFunctionReturn(PETSC_SUCCESS);
73 }
74 
VecHistoryOrderToRecycleOrder(Mat B,Vec X,PetscInt num_updates,Vec cyclic_work_vec)75 PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
76 {
77   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
78   PetscInt  m    = lmvm->m;
79   PetscInt  oldest_index;
80 
81   PetscFunctionBegin;
82   oldest_index = recycle_index(m, oldest_update(m, num_updates));
83   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
84   PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
85   PetscFunctionReturn(PETSC_SUCCESS);
86 }
87 
MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type,PetscMemType memtype,PetscBool hermitian_transpose,PetscInt m,PetscInt oldest,PetscInt next,const PetscScalar A[],PetscInt lda,PetscScalar x[],PetscInt stride)88 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
89 {
90   PetscInt oldest_index = oldest % m;
91   PetscInt next_index   = (next - 1) % m + 1;
92   PetscInt N            = next - oldest;
93 
94   PetscFunctionBegin;
95   /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
96   if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
97   switch (type) {
98   case MAT_LMVM_DENSE_REORDER:
99     if (PetscMemTypeHost(memtype)) {
100       PetscBLASInt n, lda_blas, one = 1;
101       PetscCall(PetscBLASIntCast(N, &n));
102       PetscCall(PetscBLASIntCast(lda, &lda_blas));
103       PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
104       PetscCall(PetscLogFlops(1.0 * n * n));
105 #if defined(PETSC_HAVE_CUPM)
106     } else if (PetscMemTypeDevice(memtype)) {
107       PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
108 #endif
109     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
110     break;
111   case MAT_LMVM_DENSE_INPLACE:
112     if (PetscMemTypeHost(memtype)) {
113       PetscBLASInt n_old, n_new, lda_blas, one = 1;
114       PetscScalar  minus_one = -1.0;
115       PetscScalar  sone      = 1.0;
116       PetscCall(PetscBLASIntCast(m - oldest_index, &n_old));
117       PetscCall(PetscBLASIntCast(next_index, &n_new));
118       PetscCall(PetscBLASIntCast(lda, &lda_blas));
119       if (!hermitian_transpose) {
120         if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
121         if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
122         if (n_old > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
123       } else {
124         if (n_old > 0) {
125           PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
126           if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
127         }
128         if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
129       }
130       PetscCall(PetscLogFlops(1.0 * N * N));
131 #if defined(PETSC_HAVE_CUPM)
132     } else if (PetscMemTypeDevice(memtype)) {
133       PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, m, oldest, next, A, lda, x, stride));
134 #endif
135     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
136     break;
137   default:
138     PetscUnreachable();
139   }
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
MatUpperTriangularSolveInPlace(Mat B,Mat Amat,Vec X,PetscBool hermitian_transpose,PetscInt num_updates,MatLMVMDenseType strategy)143 PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
144 {
145   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
146   PetscInt           m    = lmvm->m;
147   PetscInt           h, local_n;
148   PetscInt           lda;
149   PetscScalar       *x;
150   PetscMemType       memtype_r, memtype_x;
151   const PetscScalar *A;
152 
153   PetscFunctionBegin;
154   h = num_updates - oldest_update(m, num_updates);
155   if (!h) PetscFunctionReturn(PETSC_SUCCESS);
156   PetscCall(VecGetLocalSize(X, &local_n));
157   PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
158   PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
159   if (!local_n) {
160     PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
161     PetscCall(VecRestoreArrayAndMemType(X, &x));
162     PetscFunctionReturn(PETSC_SUCCESS);
163   }
164   PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
165   PetscCall(MatDenseGetLDA(Amat, &lda));
166   PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, m, oldest_update(m, num_updates), num_updates, A, lda, x, 1));
167   PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
168   PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
169   PetscFunctionReturn(PETSC_SUCCESS);
170 }
171 
172 /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */
173 
MatMove_LR3(Mat B,Mat R,PetscInt m_keep)174 PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
175 {
176   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
177   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
178   PetscInt  M;
179   Mat       mat_local, local_sub, local_temp, temp_sub;
180 
181   PetscFunctionBegin;
182   if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
183   PetscCall(MatGetLocalSize(R, &M, NULL));
184   if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);
185 
186   PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
187   PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
188   PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
189   PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
190   PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
191   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
192   PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
193   PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
194   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
195   PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
196   PetscFunctionReturn(PETSC_SUCCESS);
197 }
198