xref: /petsc/src/mat/impls/dense/mpi/cupm/matmpidensecupm.hpp (revision bcee047adeeb73090d7e36cc71e39fc287cdbb97)
1 #ifndef PETSCMATMPIDENSECUPM_HPP
2 #define PETSCMATMPIDENSECUPM_HPP
3 
4 #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
5 #include <../src/mat/impls/dense/mpi/mpidense.h>
6 
7 #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
8 #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>
9 
10 namespace Petsc
11 {
12 
13 namespace mat
14 {
15 
16 namespace cupm
17 {
18 
19 namespace impl
20 {
21 
// Implementation class behind the MATMPIDENSECUPM matrix type: an MPI-parallel dense matrix
// whose local block lives on a CUDA or HIP device, selected by the template parameter T.
// The CRTP base MatDense_CUPM supplies the shared CUPM machinery, and MATDENSECUPM_HEADER()
// injects the usual aliases/static helpers (MatIMPLCast, MATMPIDENSECUPM(), SetPreallocation,
// GetHandles_, ...) -- see matdensecupmimpl.h for the exact list (assumption: confirm there).
template <device::cupm::DeviceType T>
class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  // Cast the opaque Mat::data pointer to the MPI dense implementation struct
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  // Type name of this implementation (device-specific spelling of MATMPIDENSECUPM)
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  // Allocate (or adopt device_array as) storage for the local sequential block
  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  // Shared body of both conversion routines below; to_host selects the direction
  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  // Names of the composed-function slots, spelled per backend ("...cuda..." or "...hip...")
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
  // Trampolines with the two-argument signature expected by PetscObjectComposeFunction();
  // they drop the optional PetscDeviceContext parameter of the public templates above
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  // Borrow/return a Vec view of one column of the local block (see GetColumnVec below)
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  // Temporarily substitute / permanently replace / restore the device backing array
  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;

  // A += alpha*I, executed on the device
  static PetscErrorCode Shift(Mat, PetscScalar) noexcept;
};
81 
82 } // namespace impl
83 
84 namespace
85 {
86 
// Declare this here so that the functions below can make use of it
//
// Convenience wrapper: create an MPI dense CUPM matrix of local size m x n (global M x N)
// on comm, optionally adopting a caller-supplied device buffer `data` (may be nullptr to let
// PETSc allocate) and optionally skipping preallocation. Forwards straight to the impl
// class's CreateIMPLDenseCUPM() (provided by MATDENSECUPM_HEADER -- confirm its contract
// in matdensecupmimpl.h).
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
95 
96 } // anonymous namespace
97 
98 namespace impl
99 {
100 
101 // ==========================================================================================
102 // MatDense_MPI_CUPM -- Private API
103 // ==========================================================================================
104 
// Reinterpret the opaque Mat::data pointer as the MPI dense implementation struct.
// Valid only once the matrix has been set to an mpidense-derived type.
template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}
110 
// Type name of this implementation, consumed by the CRTP base. MATMPIDENSECUPM() is
// macro/base-provided and presumably resolves to MATMPIDENSECUDA or MATMPIDENSEHIP
// depending on T -- confirm in matdensecupmimpl.h.
template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}
116 
117 // ==========================================================================================
118 
// Allocate (or adopt) device storage for the local sequential block mimpl->A.
//
// If the inner matrix already exists it is retyped to MATSEQDENSECUPM and preallocated in
// place; otherwise a fresh sequential CUPM dense matrix of local size rmap->n x cmap->N is
// created directly. device_array may be nullptr (PETSc allocates) or a caller-owned device
// buffer to wrap.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  // note: mimplA is a *reference* to the inner-matrix slot, so the else branch fills it in
  if (auto &mimplA = MatIMPLCast(A)->A) {
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
131 
// Shared body of Convert_MPIDenseCUPM_MPIDense (to_host == true) and
// Convert_MPIDense_MPIDenseCUPM (to_host == false).
//
// For MAT_INITIAL_MATRIX a full duplicate is made, for MAT_REUSE_MATRIX the values are
// copied into the existing *newmat; for MAT_INPLACE_MATRIX neither branch runs and *newmat
// is expected to alias M (see Create(), which calls with &A). The matrix is then rewired for
// its new residency: type name, default vector type, composed methods, the inner sequential
// block's type, and the offload mask.
template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    // B is the matrix being (re)typed -- possibly M itself in the inplace case
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    if (to_host) {
      PetscCall(BindToCPU(B, PETSC_TRUE));
    } else {
      // Moving to the device: make sure the CUPM backend is initialized first
      PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
    }

    // Default vector type and the object's type name follow the new residency
    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    // MatComposeOp_CUPM(to_host, obj, name, host_fn, device_fn) composes one of the two
    // functions under `name` depending on direction (nullptr removes the composition)
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);

    // Convert the local sequential block too (if it exists yet) and set the offload state
    if (to_host) {
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        // No local block allocated yet -- nothing resides anywhere
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
194 
195 // ==========================================================================================
196 // MatDense_MPI_CUPM -- Public API
197 // ==========================================================================================
198 
199 template <device::cupm::DeviceType T>
200 inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
201 {
202   return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
203 }
204 
205 template <device::cupm::DeviceType T>
206 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
207 {
208   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
209 }
210 
211 template <device::cupm::DeviceType T>
212 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
213 {
214   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
215 }
216 
217 template <device::cupm::DeviceType T>
218 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
219 {
220   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
221 }
222 
223 template <device::cupm::DeviceType T>
224 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
225 {
226   return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
227 }
228 
229 // ==========================================================================================
230 
// Constructor for MATMPIDENSECUPM: build a plain host MPIDense matrix first, then convert
// it in place to the device type (the conversion composes all device methods, see
// Convert_Dispatch_).
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
239 
240 // ==========================================================================================
241 
// Toggle A between host (usehost == PETSC_TRUE) and device execution: binds the local
// sequential block, swaps the default random-number-generator type, invalidates cached
// column-vector/sub-matrix objects of the wrong type, and (re)composes the column-vector
// and shift operations to the host or device variants.
//
// Must not be called while a column vector or sub-matrix is checked out.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  // Random numbers are drawn on whichever side the matrix now lives
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    // Moving to the device: drop cached column vector / cached matrix unless they are
    // already of the CUPM type; they get lazily recreated with the right type later
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  // Column-vector access: host implementations when bound to CPU, device ones otherwise
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  // Propagate the binding to the cached full-matrix view, if one exists
  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}
275 
// MatConvert() hook: MATMPIDENSECUPM -> MATMPIDENSE (device to host).
// Thin wrapper selecting the to_host direction of Convert_Dispatch_.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
283 
// MatConvert() hook: MATMPIDENSE -> MATMPIDENSECUPM (host to device).
// Thin wrapper selecting the to-device direction of Convert_Dispatch_.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}
291 
292 // ==========================================================================================
293 
// Fetch the raw array of the *local sequential block* with the access mode given by the
// second template parameter. The PetscMemType parameter and the PetscDeviceContext argument
// are accepted for interface symmetry with the seq variant but are unused here.
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
302 
// Return the array obtained with GetArray(); the access mode must match the one used to
// acquire it. PetscMemType / PetscDeviceContext are unused, as in GetArray().
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
311 
312 // ==========================================================================================
313 
// Hand out in *v a Vec that aliases column `col` of the local block (no copy). A single
// cached vector (mimpl->cvec) is reused for every request; its device array is "placed"
// over the column's storage at offset col*lda. Only one column may be checked out at a
// time -- RestoreColumnVec() must be called before the next Get. For read access the
// vector is additionally read-locked.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  // Record which column is checked out; stored as col+1 so that 0 means "none"
  mimpl->vecinuse = col + 1;

  // Lazily create the cached column vector on first use
  if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));

  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  // Acquire the block's device array with the requested access mode; the pointer is kept
  // in mimpl->ptrinuse so RestoreColumnVec() can hand it back
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
  *v = mimpl->cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}
340 
// Release the column vector handed out by GetColumnVec(): restores the device array with
// the matching access mode, pops the read lock if one was pushed, detaches the placed
// array from the cached vector, and clears the caller's reference. The PetscInt (column)
// argument is ignored -- the checked-out column is tracked in mimpl->vecinuse.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0; // mark no column as checked out

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
362 
363 // ==========================================================================================
364 
365 template <device::cupm::DeviceType T>
366 inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
367 {
368   const auto mimpl = MatIMPLCast(A);
369 
370   PetscFunctionBegin;
371   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
372   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
373   PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
374   PetscFunctionReturn(PETSC_SUCCESS);
375 }
376 
377 template <device::cupm::DeviceType T>
378 inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
379 {
380   const auto mimpl = MatIMPLCast(A);
381 
382   PetscFunctionBegin;
383   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
384   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
385   PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
386   PetscFunctionReturn(PETSC_SUCCESS);
387 }
388 
389 template <device::cupm::DeviceType T>
390 inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
391 {
392   const auto mimpl = MatIMPLCast(A);
393 
394   PetscFunctionBegin;
395   PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
396   PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
397   PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
398   PetscFunctionReturn(PETSC_SUCCESS);
399 }
400 
401 // ==========================================================================================
402 
// A += alpha*I, performed on the device: DiagonalUnaryTransform applies the plus-equals
// functor to each locally owned diagonal entry (global rows rstart..rend, clipped against
// the cmap->N columns). GetHandles_ supplies the device context to run on.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(PetscInfo(A, "Performing Shift on backend\n"));
  PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
414 
415 } // namespace impl
416 
417 namespace
418 {
419 
// Convenience constructor mirroring MatCreateDense(): creates a sequential or an MPI dense
// CUPM matrix depending on the communicator size. Judging from how the arguments are
// forwarded, n/N are the local/global row counts and m/M the local/global column counts
// (note the unusual parameter order) -- confirm against CreateIMPLDenseCUPM's contract.
// `data` may be nullptr (PETSc allocates) or a caller-owned device buffer.
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscValidPointer(A, 7);
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
  } else {
    // Single-rank communicator: fall back to the sequential type with resolved local sizes
    if (n == PETSC_DECIDE) n = N;
    if (m == PETSC_DECIDE) m = M;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
439 
440 } // anonymous namespace
441 
442 } // namespace cupm
443 
444 } // namespace mat
445 
446 } // namespace Petsc
447 
448 #endif // PETSCMATMPIDENSECUPM_HPP
449