#pragma once

#include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
#include <../src/mat/impls/dense/mpi/mpidense.h>

#include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
#include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>

namespace Petsc
{

namespace mat
{

namespace cupm
{

namespace impl
{

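// ==========================================================================================
// MatDense_MPI_CUPM
//
// CRTP implementation class for the parallel (MPI) dense device matrix types, i.e.
// MATMPIDENSECUDA when T is CUDA and MATMPIDENSEHIP when T is HIP. The constexpr *_C()
// members below return the device-specific names under which the composed operations are
// registered, while GetArrayC_()/RestoreArrayC_() adapt the templated accessors to the
// composed-function signature, which has no PetscDeviceContext parameter.
// ==========================================================================================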
template <device::cupm::DeviceType T>
class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;
};

} // namespace impl

namespace
{

// Declare this here so that the functions below can make use of it
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
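
// Internal usage sketch (comment only; hypothetical caller). Code inside PETSc that
// includes this header can create a distributed dense device matrix directly:
//
//   Mat                A;
//   PetscDeviceContext dctx = nullptr; // nullptr: use the current device context
//   // 100 x 100 global, locally-decided layout, PETSc-allocated device storage
//   PetscCall(MatCreateMPIDenseCUPM<device::cupm::DeviceType::CUDA>(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 100, nullptr, &A, dctx));
//   PetscCall(MatDestroy(&A));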

} // anonymous namespace

namespace impl
{

// ==========================================================================================
// MatDense_MPI_CUPM -- Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}

template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  if (auto &mimplA = MatIMPLCast(A)->A) {
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
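
// Note: the preallocation above is reached from user code through the composed op
// installed in Convert_Dispatch_() below; for the CUDA instantiation the public entry
// point is MatDenseCUDASetPreallocation(). Sketch (comment only):
//
//   PetscCall(MatDenseCUDASetPreallocation(A, NULL)); // NULL: let PETSc allocate on the device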

template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    if (to_host) {
      PetscCall(BindToCPU(B, PETSC_TRUE));
    } else {
      PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
    }

    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);

    if (to_host) {
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, getdiagonal, MatGetDiagonal_MPIDense, GetDiagonal);
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
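
// Conversion sketch (comment only). Both directions are reached through MatConvert();
// MAT_INPLACE_MATRIX converts the operand itself. For the CUDA instantiation:
//
//   PetscCall(MatConvert(A, MATMPIDENSECUDA, MAT_INPLACE_MATRIX, &A)); // host -> device
//   PetscCall(MatConvert(A, MATMPIDENSE, MAT_INPLACE_MATRIX, &A));     // device -> host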

// ==========================================================================================
// MatDense_MPI_CUPM -- Public API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
}
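
// The strings above are the keys under which Convert_Dispatch_() composes the callbacks
// on the PetscObject; other components locate them by name at runtime. Lookup sketch
// (comment only, CUDA instantiation assumed):
//
//   PetscErrorCode (*conv)(Mat, MatType, MatReuse, Mat *) = nullptr;
//   PetscCall(PetscObjectQueryFunction(PetscObjectCast(A), "MatConvert_mpidensecuda_mpidense_C", &conv));
//   if (conv) { /* A supports the direct device-to-host conversion */ }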

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
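
// Create() builds a plain MATMPIDENSE and then converts it in place, so a device matrix
// is what MatSetType() produces. Sketch of the user-facing path (comment only):
//
//   Mat A;
//   PetscCall(MatCreate(PETSC_COMM_WORLD, &A));
//   PetscCall(MatSetSizes(A, PETSC_DECIDE, PETSC_DECIDE, 100, 100));
//   PetscCall(MatSetType(A, MATMPIDENSECUDA)); // or MATMPIDENSEHIP for the HIP build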

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}
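
// BindToCPU() is installed as the matrix's bindtocpu hook, so the standard API applies
// (comment only):
//
//   PetscCall(MatBindToCPU(A, PETSC_TRUE));  // subsequent operations run on the host
//   PetscCall(MatBindToCPU(A, PETSC_FALSE)); // return execution to the device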

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext dctx) noexcept
{
  auto &mimplA = MatIMPLCast(A)->A;

  PetscFunctionBegin;
  if (!mimplA) PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, nullptr, &mimplA, dctx));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimplA, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
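
// Array-access sketch (comment only). The composed ops registered in Convert_Dispatch_()
// expose these accessors publicly; for the CUDA instantiation:
//
//   const PetscScalar *a;
//   PetscCall(MatDenseCUDAGetArrayRead(A, &a));     // borrow the device pointer, read-only
//   // ... hand `a` to a kernel or a library routine ...
//   PetscCall(MatDenseCUDARestoreArrayRead(A, &a));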

// ==========================================================================================

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  mimpl->vecinuse = col + 1;

  if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));

  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
  *v = mimpl->cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0;

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
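
// Column-vector sketch (comment only). The returned vector aliases the matrix storage at
// offset col * lda (no copy), so every Get must be paired with the matching Restore:
//
//   Vec       c;
//   PetscReal nrm;
//   PetscCall(MatDenseGetColumnVecRead(A, 0, &c));
//   PetscCall(VecNorm(c, NORM_2, &nrm));
//   PetscCall(MatDenseRestoreColumnVecRead(A, 0, &c));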

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
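
// Place/Replace/Reset sketch (comment only). PlaceArray() temporarily substitutes a
// caller-owned device buffer for the local storage, ResetArray() restores the original,
// and ReplaceArray() hands buffer ownership to the matrix. With `d_buf` a hypothetical
// device allocation of at least lda * N scalars, the CUDA instantiation reads:
//
//   PetscCall(MatDenseCUDAPlaceArray(A, d_buf));
//   // ... operate on A; its local values now live in d_buf ...
//   PetscCall(MatDenseCUDAResetArray(A));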

} // namespace impl

namespace
{

template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscAssertPointer(A, 7);
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
  } else {
    if (n == PETSC_DECIDE) n = N;
    if (m == PETSC_DECIDE) m = M;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
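
// This helper mirrors MatCreateDense(): it dispatches on the communicator size, yielding
// the sequential type on a single rank and the MPI type otherwise. Sketch (comment only,
// CUDA instantiation assumed; note the (n, m, N, M) parameter order above):
//
//   Mat A;
//   PetscCall(MatCreateDenseCUPM<device::cupm::DeviceType::CUDA>(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 100, nullptr, &A));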

} // anonymous namespace

} // namespace cupm

} // namespace mat

} // namespace Petsc