#ifndef PETSCMATMPIDENSECUPM_HPP
#define PETSCMATMPIDENSECUPM_HPP

#include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
#include <../src/mat/impls/dense/mpi/mpidense.h>

#include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
#include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>

namespace Petsc
{

namespace mat
{

namespace cupm
{

namespace impl
{

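// ==========================================================================================
// MatDense_MPI_CUPM
//
// CRTP implementation of the MPI parallel dense device matrix (MATMPIDENSECUDA or
// MATMPIDENSEHIP, depending on T). The machinery shared with the sequential case lives in
// MatDense_CUPM; this class supplies the MPI-specific pieces, mostly by delegating to the
// on-rank sequential block (Mat_MPIDense::A) handled by MatDense_Seq_CUPM.
// ==========================================================================================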
template <device::cupm::DeviceType T>
class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
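  // C-callable adapters for the composed-function table: PetscObjectComposeFunction()
  // cannot carry the defaulted PetscDeviceContext parameter of GetArray()/RestoreArray(),
  // so these pin it to nullptr and expose the plain (Mat, PetscScalar **) signature used
  // by the composed ops.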
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;
};

} // namespace impl

namespace
{

// Declare this here so that the functions below can make use of it
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
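
// A minimal usage sketch (hypothetical caller, assuming a CUDA build; the HIP case is
// analogous):
//
//   Mat A;
//
//   PetscCall(MatCreateMPIDenseCUPM<device::cupm::DeviceType::CUDA>(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 100, nullptr, &A));
//   PetscCall(MatSetValue(A, 0, 0, 1.0, INSERT_VALUES));
//   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
//   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
//   PetscCall(MatDestroy(&A));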

} // anonymous namespace

namespace impl
{

// ==========================================================================================
// MatDense_MPI_CUPM -- Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}

template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}

// ==========================================================================================

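// Allocate the on-rank dense block: if the sequential matrix Mat_MPIDense::A already
// exists it is retyped to MATSEQDENSECUPM and preallocated (optionally wrapping the
// caller-provided device_array), otherwise it is created outright with the local row and
// global column counts.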
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  if (auto &mimplA = MatIMPLCast(A)->A) {
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

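// Shared body of the host <-> device conversions below. MAT_INITIAL_MATRIX duplicates M
// and MAT_REUSE_MATRIX copies it into *newmat (MAT_INPLACE_MATRIX falls through and
// mutates M itself); the type name, default vector type, composed functions, operation
// table, and offload mask are then switched to the side selected by to_host. Note that
// MatComposeOp_CUPM() composes the host callback (here nullptr, which removes the
// composed function) when to_host is true, and the device callback otherwise.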
template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    if (to_host) {
      PetscCall(BindToCPU(B, PETSC_TRUE));
    } else {
      PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
    }

    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);

    if (to_host) {
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, getdiagonal, MatGetDiagonal_MPIDense, GetDiagonal);
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// MatDense_MPI_CUPM -- Public API
// ==========================================================================================

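// Names of the device-specific composed functions: each helper spells the corresponding
// MatConvert/MatProductSetFromOptions hook for the concrete backend (CUDA or HIP)
// selected by T.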
template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
}

// ==========================================================================================

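// Constructor for MATMPIDENSECUPM: build a plain host MATMPIDENSE, then convert it in
// place to the device type.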
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

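// Toggle between host and device execution. Refuses to run while a column vector or
// sub-matrix is checked out, propagates the request to the on-rank block, swaps the
// default random type, and points the column-vector accessors (and MatShift) at either
// the MPIDense host routines or the device implementations above. When moving to the
// device, cached helpers (cvec/cmat) of the wrong type are destroyed so they can be
// recreated lazily with the correct type.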
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

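// Device-array access forwards to the sequential on-rank block, creating it on demand if
// the matrix has not been preallocated yet.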
template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext dctx) noexcept
{
  auto &mimplA = MatIMPLCast(A)->A;

  PetscFunctionBegin;
  if (!mimplA) PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, nullptr, &mimplA, dctx));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimplA, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

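// Expose column col of the on-rank block as a Vec without copying: the raw device array
// is checked out with the requested access mode and placed, at offset col * lda, into the
// cached column vector; read-only access additionally takes a read lock.
//
// A minimal sketch of the borrow protocol these implement (caller side, through the
// public API):
//
//   Vec col;
//
//   PetscCall(MatDenseGetColumnVec(A, 0, &col));
//   // ... col aliases column 0 of A; use it like any Vec ...
//   PetscCall(MatDenseRestoreColumnVec(A, 0, &col));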
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  mimpl->vecinuse = col + 1;

  if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));

  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
  *v = mimpl->cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0;

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

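// Device-side analogues of MatDensePlaceArray(), MatDenseReplaceArray(), and
// MatDenseResetArray(): PlaceArray() temporarily substitutes the given device array for
// the on-rank block's storage (undone by ResetArray()), while ReplaceArray() hands
// ownership of the array to the matrix. None of these may be called while a column vector
// or sub-matrix is checked out.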
template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace impl

namespace
{

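// Communicator-size dispatch for dense device-matrix creation: parallel communicators get
// the MPI implementation above, single-rank communicators the sequential one (with
// PETSC_DECIDE resolved locally, since the sequential constructor takes no global sizes).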
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscAssertPointer(A, 7);
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, m, n, M, N, data, A, dctx));
  } else {
    if (m == PETSC_DECIDE) m = M;
    if (n == PETSC_DECIDE) n = N;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, m, n, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // anonymous namespace

} // namespace cupm

} // namespace mat

} // namespace Petsc

#endif // PETSCMATMPIDENSECUPM_HPP