xref: /petsc/include/petsc/private/veccupmimpl.h (revision bbfde98d24007149e2e73427fb24bfdd802107d3)
1 #ifndef PETSCVECCUPMIMPL_H
2 #define PETSCVECCUPMIMPL_H
3 
4 #include <petsc/private/vecimpl.h>
5 #include <../src/vec/vec/impls/dvecimpl.h> // for Vec_Seq
6 
7 #if PetscDefined(HAVE_NVSHMEM)
8 PETSC_INTERN PetscErrorCode PetscNvshmemInitializeCheck(void);
9 PETSC_INTERN PetscErrorCode PetscNvshmemMalloc(size_t, void **);
10 PETSC_INTERN PetscErrorCode PetscNvshmemCalloc(size_t, void **);
11 PETSC_INTERN PetscErrorCode PetscNvshmemFree_Private(void *);
12   #define PetscNvshmemFree(ptr) ((PetscErrorCode)((ptr) && (PetscNvshmemFree_Private(ptr) || ((ptr) = PETSC_NULLPTR, PETSC_SUCCESS))))
13 PETSC_INTERN PetscErrorCode PetscNvshmemSum(PetscInt, PetscScalar *, const PetscScalar *);
14 PETSC_INTERN PetscErrorCode PetscNvshmemMax(PetscInt, PetscReal *, const PetscReal *);
15 PETSC_INTERN PetscErrorCode VecAllocateNVSHMEM_SeqCUDA(Vec);
16 #else
17   #define PetscNvshmemFree(ptr) PETSC_SUCCESS
18 #endif
19 
20 #if defined(__cplusplus) && PetscDefined(HAVE_DEVICE)
21   #include <petsc/private/deviceimpl.h>
22   #include <petsc/private/cupmobject.hpp>
23   #include <petsc/private/cupmblasinterface.hpp>
24 
25   #include <petsc/private/cpp/functional.hpp>
26 
27   #include <limits> // std::numeric_limits
28 
29 namespace Petsc
30 {
31 
32 namespace vec
33 {
34 
35 namespace cupm
36 {
37 
38 namespace impl
39 {
40 
41 namespace
42 {
43 
44 struct no_op {
45   template <typename... T>
46   constexpr PetscErrorCode operator()(T &&...) const noexcept
47   {
48     return PETSC_SUCCESS;
49   }
50 };
51 
// Aggregate tying together a matched device/host pointer pair (held by reference, so helpers
// that allocate or free through a CooPair update the caller's pointers in place) plus the
// common element count. Member order matters: instances are brace-initialized as
// {device, host, size} (see make_coo_pair()).
template <typename T>
struct CooPair {
  using value_type = T;
  using size_type  = PetscCount;

  value_type *&device; // reference to the device-side array pointer
  value_type *&host;   // reference to the host-side array pointer
  size_type    size;   // number of entries in both arrays
};
61 
62 template <typename U>
63 static constexpr CooPair<U> make_coo_pair(U *&device, U *&host, PetscCount size) noexcept
64 {
65   return {device, host, size};
66 }
67 
68 } // anonymous namespace
69 
70 // forward declarations
71 template <device::cupm::DeviceType>
72 class VecSeq_CUPM;
73 template <device::cupm::DeviceType>
74 class VecMPI_CUPM;
75 
76 // ==========================================================================================
77 // Vec_CUPMBase
78 //
79 // Base class for the VecSeq and VecMPI CUPM implementations. On top of the usual DeviceType
80 // template parameter it also uses CRTP to be able to use values/calls specific to either
81 // VecSeq or VecMPI. This is in effect "inside-out" polymorphism.
82 // ==========================================================================================
template <device::cupm::DeviceType T, typename Derived>
class Vec_CUPMBase : protected device::cupm::impl::CUPMObject<T> {
public:
  PETSC_CUPMOBJECT_HEADER(T);

  // ==========================================================================================
  // Vec_CUPMBase::VectorArray
  //
  // RAII versions of the get/restore array routines. Determines constness of the pointer type,
  // holds the pointer itself and provides the implicit conversion operator
  // ==========================================================================================
  template <PetscMemType, PetscMemoryAccessMode>
  class VectorArray;

protected:
  // Debugging aid: dump the vector's bookkeeping (size, offload mask, host/device pointers)
  // to stdout on the vector's communicator
  static PetscErrorCode VecView_Debug(Vec v, const char *message = "") noexcept
  {
    const auto   pobj  = PetscObjectCast(v);
    const auto   vimpl = VecIMPLCast(v);
    const auto   vcu   = VecCUPMCast(v);
    PetscMemType mtype;
    MPI_Comm     comm;

    PetscFunctionBegin;
    PetscValidPointer(vimpl, 1);
    PetscValidPointer(vcu, 1);
    PetscCall(PetscObjectGetComm(pobj, &comm));
    PetscCall(PetscPrintf(comm, "---------- %s ----------\n", message));
    PetscCall(PetscObjectPrintClassNamePrefixType(pobj, PETSC_VIEWER_STDOUT_(comm)));
    PetscCall(PetscPrintf(comm, "Address:             %p\n", v));
    PetscCall(PetscPrintf(comm, "Size:                %" PetscInt_FMT "\n", v->map->n));
    PetscCall(PetscPrintf(comm, "Offload mask:        %s\n", PetscOffloadMaskToString(v->offloadmask)));
    PetscCall(PetscPrintf(comm, "Host ptr:            %p\n", vimpl->array));
    PetscCall(PetscPrintf(comm, "Device ptr:          %p\n", vcu->array_d));
    PetscCall(PetscPrintf(comm, "Device alloced ptr:  %p\n", vcu->array_allocated_d));
    PetscCall(PetscCUPMGetMemType(vcu->array_d, &mtype));
    PetscCall(PetscPrintf(comm, "dptr is device mem?  %s\n", PetscBools[static_cast<PetscBool>(PetscMemTypeDevice(mtype))]));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  // Delete the allocated device array if required and replace it with the given array
  static PetscErrorCode ResetAllocatedDevicePtr_(PetscDeviceContext, Vec, PetscScalar * = nullptr) noexcept;
  // Check either the host or device impl pointer is allocated and allocate it if
  // isn't. CastFunctionType casts the Vec to the required type and returns the pointer
  template <typename CastFunctionType>
  static PetscErrorCode VecAllocateCheck_(Vec, void *&, CastFunctionType &&) noexcept;
  // Check the CUPM part (v->spptr) is allocated, otherwise allocate it
  static PetscErrorCode VecCUPMAllocateCheck_(Vec) noexcept;
  // Check the Host part (v->data) is allocated, otherwise allocate it
  static PetscErrorCode VecIMPLAllocateCheck_(Vec) noexcept;
  // Check the Host array is allocated, otherwise allocate it
  static PetscErrorCode HostAllocateCheck_(PetscDeviceContext, Vec) noexcept;
  // Check the CUPM array is allocated, otherwise allocate it
  static PetscErrorCode DeviceAllocateCheck_(PetscDeviceContext, Vec) noexcept;
  // Copy HTOD, allocating device if necessary
  static PetscErrorCode CopyToDevice_(PetscDeviceContext, Vec, bool = false) noexcept;
  // Copy DTOH, allocating host if necessary
  static PetscErrorCode CopyToHost_(PetscDeviceContext, Vec, bool = false) noexcept;
  static PetscErrorCode DestroyDevice_(Vec) noexcept;
  static PetscErrorCode DestroyHost_(Vec) noexcept;

public:
  // Device-side payload hung off v->spptr
  struct Vec_CUPM {
    PetscScalar *array_d;           // gpu data
    PetscScalar *array_allocated_d; // gpu array allocated by PETSc; non-null means PETSc owns
                                    // it and must free it (see ResetAllocatedDevicePtr_())
    PetscBool    nvshmem;           // is array allocated in nvshmem? It is used to allocate
                                    // Mvctx->lvec in nvshmem

    // COO stuff
    PetscCount *jmap1_d; // [m+1]: i-th entry of the vector has jmap1[i+1]-jmap1[i] repeats
                         // in COO arrays
    PetscCount *perm1_d; // [tot1]: permutation array for local entries
    PetscCount *imap2_d; // [nnz2]: i-th unique entry in recvbuf is imap2[i]-th entry in
                         // the vector
    PetscCount *jmap2_d; // [nnz2+1]
    PetscCount *perm2_d; // [recvlen]
    PetscCount *Cperm_d; // [sendlen]: permutation array to fill sendbuf[]. 'C' for
                         // communication

    // Buffers for remote values in VecSetValuesCOO()
    PetscScalar *sendbuf_d;
    PetscScalar *recvbuf_d;
  };

  // Cast the Vec to its Vec_CUPM struct, i.e. return the result of (Vec_CUPM *)v->spptr
  PETSC_NODISCARD static Vec_CUPM *VecCUPMCast(Vec) noexcept;
  // Cast the Vec to its host struct, i.e. return the result of (Vec_Seq *)v->data
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr auto VecIMPLCast(Vec v) noexcept -> decltype(U::VecIMPLCast_(v));
  // Get the PetscLogEvents for HTOD and DTOH
  PETSC_NODISCARD static constexpr PetscLogEvent VEC_CUPMCopyToGPU() noexcept;
  PETSC_NODISCARD static constexpr PetscLogEvent VEC_CUPMCopyFromGPU() noexcept;
  // Get the VecTypes
  PETSC_NODISCARD static constexpr VecType VECSEQCUPM() noexcept;
  PETSC_NODISCARD static constexpr VecType VECMPICUPM() noexcept;
  PETSC_NODISCARD static constexpr VecType VECCUPM() noexcept;

  // Get the device VecType of the calling vector
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM() noexcept;
  // Get the host VecType of the calling vector
  template <typename U = Derived>
  PETSC_NODISCARD static constexpr VecType VECIMPL() noexcept;

  // Call the host destroy function, i.e. VecDestroy_Seq()
  static PetscErrorCode VecDestroy_IMPL(Vec) noexcept;
  // Call the host reset function, i.e. VecResetArray_Seq()
  static PetscErrorCode VecResetArray_IMPL(Vec) noexcept;
  // ... you get the idea
  static PetscErrorCode VecPlaceArray_IMPL(Vec, const PetscScalar *) noexcept;
  // Call the host creation function, i.e. VecCreate_Seq(), and also initialize the CUPM part
  // along with it if needed
  static PetscErrorCode VecCreate_IMPL_Private(Vec, PetscBool *, PetscInt = 0, PetscScalar * = nullptr) noexcept;

  // Shorthand for creating VectorArray's. Need functions to create them, otherwise using them
  // as an unnamed temporary leads to most vexing parse
  PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, v});
  PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, v});
  PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, v});
  PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, v});
  PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, v});
  PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, v});

  // ops-table functions
  static PetscErrorCode Create(Vec) noexcept;
  static PetscErrorCode Destroy(Vec) noexcept;
  template <PetscMemType, PetscMemoryAccessMode, bool = false>
  static PetscErrorCode GetArray(Vec, PetscScalar **, PetscDeviceContext) noexcept;
  template <PetscMemType, PetscMemoryAccessMode, bool = false>
  static PetscErrorCode GetArray(Vec, PetscScalar **) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Vec, PetscScalar **, PetscDeviceContext) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Vec, PetscScalar **) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetArrayAndMemtype(Vec, PetscScalar **, PetscMemType *, PetscDeviceContext) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetArrayAndMemtype(Vec, PetscScalar **, PetscMemType *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreArrayAndMemtype(Vec, PetscScalar **, PetscDeviceContext) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreArrayAndMemtype(Vec, PetscScalar **) noexcept;
  template <PetscMemType>
  static PetscErrorCode ReplaceArray(Vec, const PetscScalar *) noexcept;
  template <PetscMemType>
  static PetscErrorCode ResetArray(Vec) noexcept;
  template <PetscMemType>
  static PetscErrorCode PlaceArray(Vec, const PetscScalar *) noexcept;

  // common ops shared between Seq and MPI
  static PetscErrorCode Create_CUPM(Vec) noexcept;
  static PetscErrorCode Create_CUPMBase(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool, PetscLayout /*reference*/ = nullptr) noexcept;
  static PetscErrorCode Initialize_CUPMBase(Vec, PetscBool, PetscScalar *, PetscScalar *, PetscDeviceContext) noexcept;
  template <typename SetupFunctionT = no_op>
  static PetscErrorCode Duplicate_CUPMBase(Vec, Vec *, PetscDeviceContext, SetupFunctionT && = SetupFunctionT{}) noexcept;
  static PetscErrorCode BindToCPU_CUPMBase(Vec, PetscBool, PetscDeviceContext) noexcept;
  static PetscErrorCode GetArrays_CUPMBase(Vec, const PetscScalar **, const PetscScalar **, PetscOffloadMask *, PetscDeviceContext) noexcept;
  static PetscErrorCode ResetPreallocationCOO_CUPMBase(Vec, PetscDeviceContext) noexcept;
  template <std::size_t NCount = 0, std::size_t NScal = 0>
  static PetscErrorCode SetPreallocationCOO_CUPMBase(Vec, PetscCount, const PetscInt[], PetscDeviceContext, const std::array<CooPair<PetscCount>, NCount> & = {}, const std::array<CooPair<PetscScalar>, NScal> & = {}) noexcept;

  static PetscErrorCode Convert_IMPL_IMPLCUPM(Vec) noexcept;
};
246 
247 // ==========================================================================================
248 // Vec_CUPMBase::VectorArray
249 //
250 // RAII versions of the get/restore array routines. Determines constness of the pointer type,
251 // holds the pointer itself and provides the implicit conversion operator.
252 //
253 // On construction this calls the moral equivalent of Vec[CUPM]GetArray[Read|Write]()
254 // (depending on PetscMemoryAccessMode) and on destruction automatically restores the array
255 // for you
256 // ==========================================================================================
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
class Vec_CUPMBase<T, D>::VectorArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
  using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;

public:
  // acquires the array on construction, restores it on destruction (see definitions below)
  VectorArray(PetscDeviceContext, Vec) noexcept;
  ~VectorArray() noexcept;

private:
  Vec v_ = nullptr; // the vector whose array is held; needed again at restore time
};
269 
270 // ==========================================================================================
271 // Vec_CUPMBase::VectorArray - Public API
272 // ==========================================================================================
273 
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline Vec_CUPMBase<T, D>::VectorArray<MT, MA>::VectorArray(PetscDeviceContext dctx, Vec v) noexcept : base_type{dctx}, v_{v}
{
  PetscFunctionBegin;
  // the trailing 'true' is GetArray()'s force flag: any data movement is enqueued
  // asynchronously on dctx's stream without the end-of-call synchronization
  PetscCallAbort(PETSC_COMM_SELF, Vec_CUPMBase<T, D>::template GetArray<MT, MA, true>(v, &this->ptr_, dctx));
  PetscFunctionReturnVoid();
}
282 
template <device::cupm::DeviceType T, typename D>
template <PetscMemType MT, PetscMemoryAccessMode MA>
inline Vec_CUPMBase<T, D>::VectorArray<MT, MA>::~VectorArray() noexcept
{
  PetscFunctionBegin;
  // restore on the same device context the array was acquired with
  PetscCallAbort(PETSC_COMM_SELF, Vec_CUPMBase<T, D>::template RestoreArray<MT, MA>(v_, &this->ptr_, this->dctx_));
  PetscFunctionReturnVoid();
}
291 
292 // ==========================================================================================
293 // Vec_CUPMBase - Protected API
294 // ==========================================================================================
295 
// Free the PETSc-owned device allocation (if any) and install new_value (default nullptr) as
// the owned pointer. Note this touches array_allocated_d only; array_d is left alone.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::ResetAllocatedDevicePtr_(PetscDeviceContext dctx, Vec v, PetscScalar *new_value) noexcept
{
  auto &device_array = VecCUPMCast(v)->array_allocated_d;

  PetscFunctionBegin;
  if (device_array) {
    if (PetscDefined(HAVE_NVSHMEM) && VecCUPMCast(v)->nvshmem) {
      // nvshmem allocations have their own free path (PetscNvshmemFree also nulls the pointer)
      PetscCall(PetscNvshmemFree(device_array));
    } else {
      cupmStream_t stream;

      // stream-ordered free so outstanding work on dctx's stream completes first
      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPM(cupmFreeAsync(device_array, stream));
    }
  }
  device_array = new_value;
  PetscFunctionReturn(PETSC_SUCCESS);
}
315 
316 namespace
317 {
318 
// Query -vec_pinned_memory_min from the options database and, if set, update the vector's
// pinned-memory threshold; optionally reports through 'set' whether the option was given
inline PetscErrorCode VecCUPMCheckMinimumPinnedMemory_Internal(Vec v, PetscBool *set = nullptr) noexcept
{
  // NOTE(review): minimum_bytes_pinned_memory is narrowed to PetscInt here so it can be used
  // with PetscOptionsRangeInt — assumes the stored value fits in PetscInt; verify
  auto      mem = static_cast<PetscInt>(v->minimum_bytes_pinned_memory);
  PetscBool flg;

  PetscFunctionBegin;
  PetscObjectOptionsBegin(PetscObjectCast(v));
  PetscCall(PetscOptionsRangeInt("-vec_pinned_memory_min", "Minimum size (in bytes) for an allocation to use pinned memory on host", "VecSetPinnedMemoryMin", mem, &mem, &flg, 0, std::numeric_limits<decltype(mem)>::max()));
  if (flg) v->minimum_bytes_pinned_memory = mem;
  PetscOptionsEnd();
  if (set) *set = flg;
  PetscFunctionReturn(PETSC_SUCCESS);
}
332 
333 } // anonymous namespace
334 
// If dest (v->data or v->spptr) is already allocated do nothing, otherwise allocate a zeroed
// impl struct and store it in dest. 'cast' is called purely so 'impl' deduces the right
// pointer type for PetscNew(); the value it returns is null here (dest is null).
template <device::cupm::DeviceType T, typename D>
template <typename CastFunctionType>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecAllocateCheck_(Vec v, void *&dest, CastFunctionType &&cast) noexcept
{
  PetscFunctionBegin;
  if (PetscLikely(dest)) PetscFunctionReturn(PETSC_SUCCESS);
  // do the check here so we don't have to do it in every function
  PetscCall(checkCupmBlasIntCast(v->map->n));
  {
    auto impl = cast(v);

    PetscCall(PetscNew(&impl));
    dest = impl;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
351 
// Ensure the host-side impl struct (v->data, e.g. Vec_Seq) exists
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecIMPLAllocateCheck_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecAllocateCheck_(v, v->data, VecIMPLCast<D>));
  PetscFunctionReturn(PETSC_SUCCESS);
}
359 
// allocate the Vec_CUPM struct (v->spptr). this is normally done through
// DeviceAllocateCheck_(), but in certain circumstances (such as when the user places the
// device array) we do not want to do the full DeviceAllocateCheck_() as it also allocates
// the array
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecCUPMAllocateCheck_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecAllocateCheck_(v, v->spptr, VecCUPMCast));
  PetscFunctionReturn(PETSC_SUCCESS);
}
370 
// Ensure the host array exists: allocates it (pinned if large enough), publishes it as the
// working host array if none was set, and initializes the offload mask
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::HostAllocateCheck_(PetscDeviceContext, Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecIMPLAllocateCheck_(v));
  // 'alloc' is a reference declared in the condition: non-null means already allocated
  if (auto &alloc = VecIMPLCast(v)->array_allocated) PetscFunctionReturn(PETSC_SUCCESS);
  else {
    PetscCall(VecCUPMCheckMinimumPinnedMemory_Internal(v));
    {
      const auto n     = v->map->n;
      // RAII guard: while 'useit' is alive PetscMalloc1() uses the pinned host allocator
      // iff the allocation exceeds the vector's pinned-memory threshold
      const auto useit = UseCUPMHostAlloc((n * sizeof(*alloc)) > v->minimum_bytes_pinned_memory);

      // remember whether pinned memory was used so the matching free does the same
      v->pinned_memory = static_cast<decltype(v->pinned_memory)>(useit.value());
      PetscCall(PetscMalloc1(n, &alloc));
    }
    if (!VecIMPLCast(v)->array) VecIMPLCast(v)->array = alloc;
    if (v->offloadmask == PETSC_OFFLOAD_UNALLOCATED) v->offloadmask = PETSC_OFFLOAD_CPU;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
391 
// Ensure the device array exists: allocates it stream-ordered on dctx's stream and
// initializes the offload mask (CPU if a host array already holds data, GPU otherwise)
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::DeviceAllocateCheck_(PetscDeviceContext dctx, Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCUPMAllocateCheck_(v));
  // 'alloc' is a reference declared in the condition: non-null means already allocated
  if (auto &alloc = VecCUPMCast(v)->array_d) PetscFunctionReturn(PETSC_SUCCESS);
  else {
    const auto   n                 = v->map->n;
    auto        &array_allocated_d = VecCUPMCast(v)->array_allocated_d;
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscCUPMMallocAsync(&array_allocated_d, n, stream));
    // the freshly allocated (owned) array becomes the working device array
    alloc = array_allocated_d;
    if (v->offloadmask == PETSC_OFFLOAD_UNALLOCATED) {
      const auto vimp = VecIMPLCast(v);
      v->offloadmask  = (vimp && vimp->array) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
413 
// Copy host -> device (allocating the device array if needed); no-op unless the data lives
// only on the CPU. 'forceasync' is forwarded to the memcpy to force stream-async behavior.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::CopyToDevice_(PetscDeviceContext dctx, Vec v, bool forceasync) noexcept
{
  PetscFunctionBegin;
  PetscCall(DeviceAllocateCheck_(dctx, v));
  if (v->offloadmask == PETSC_OFFLOAD_CPU) {
    cupmStream_t stream;

    // mark both sides valid up front; the copy below makes it true
    v->offloadmask = PETSC_OFFLOAD_BOTH;
    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogEventBegin(VEC_CUPMCopyToGPU(), v, 0, 0, 0));
    PetscCall(PetscCUPMMemcpyAsync(VecCUPMCast(v)->array_d, VecIMPLCast(v)->array, v->map->n, cupmMemcpyHostToDevice, stream, forceasync));
    PetscCall(PetscLogEventEnd(VEC_CUPMCopyToGPU(), v, 0, 0, 0));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
430 
// Copy device -> host (allocating the host array if needed); no-op unless the data lives
// only on the GPU. Mirror image of CopyToDevice_().
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::CopyToHost_(PetscDeviceContext dctx, Vec v, bool forceasync) noexcept
{
  PetscFunctionBegin;
  PetscCall(HostAllocateCheck_(dctx, v));
  if (v->offloadmask == PETSC_OFFLOAD_GPU) {
    cupmStream_t stream;

    // mark both sides valid up front; the copy below makes it true
    v->offloadmask = PETSC_OFFLOAD_BOTH;
    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogEventBegin(VEC_CUPMCopyFromGPU(), v, 0, 0, 0));
    PetscCall(PetscCUPMMemcpyAsync(VecIMPLCast(v)->array, VecCUPMCast(v)->array_d, v->map->n, cupmMemcpyDeviceToHost, stream, forceasync));
    PetscCall(PetscLogEventEnd(VEC_CUPMCopyFromGPU(), v, 0, 0, 0));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
447 
// Tear down the device side: free the owned device array, release COO preallocation
// buffers, and free the Vec_CUPM struct itself (v->spptr). No-op if no device side exists.
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::DestroyDevice_(Vec v) noexcept
{
  PetscFunctionBegin;
  if (const auto vcu = VecCUPMCast(v)) {
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(ResetAllocatedDevicePtr_(dctx, v));
    PetscCall(ResetPreallocationCOO_CUPMBase(v, dctx));
    PetscCall(PetscFree(v->spptr));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
462 
// Tear down the host side: free the host array (with the pinned deallocator if it was
// pinned) and then call the host implementation's destroy routine
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::DestroyHost_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectSAWsViewOff(PetscObjectCast(v)));
  if (const auto vimpl = VecIMPLCast(v)) {
    if (auto &array_allocated = vimpl->array_allocated) {
      // RAII guard: routes PetscFree() through the pinned-host deallocator iff the array
      // was allocated pinned (see HostAllocateCheck_())
      const auto useit = UseCUPMHostAlloc(v->pinned_memory);

      // do this ourselves since we may want to use the cupm functions
      PetscCall(PetscFree(array_allocated));
    }
  }
  v->pinned_memory = PETSC_FALSE;
  PetscCall(VecDestroy_IMPL(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
480 
481 // ==========================================================================================
482 // Vec_CUPMBase - Public API
483 // ==========================================================================================
484 
485 template <device::cupm::DeviceType T, typename D>
486 inline typename Vec_CUPMBase<T, D>::Vec_CUPM *Vec_CUPMBase<T, D>::VecCUPMCast(Vec v) noexcept
487 {
488   return static_cast<Vec_CUPM *>(v->spptr);
489 }
490 
// This is a trick to get around the fact that in CRTP the derived class is not yet fully
// defined because Base<Derived> must necessarily be instantiated before Derived is
// complete. By using a dummy template parameter we make the type "dependent" and so will
// only be determined when the derived class is instantiated (and therefore fully defined)
template <device::cupm::DeviceType T, typename D>
template <typename U>
inline constexpr auto Vec_CUPMBase<T, D>::VecIMPLCast(Vec v) noexcept -> decltype(U::VecIMPLCast_(v))
{
  return U::VecIMPLCast_(v);
}
501 
// Forward to the derived class' host destroy routine (e.g. VecDestroy_Seq)
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecDestroy_IMPL(Vec v) noexcept
{
  return D::VecDestroy_IMPL_(v);
}
507 
// Forward to the derived class' host reset-array routine (e.g. VecResetArray_Seq)
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecResetArray_IMPL(Vec v) noexcept
{
  return D::VecResetArray_IMPL_(v);
}
513 
// Forward to the derived class' host place-array routine (e.g. VecPlaceArray_Seq)
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecPlaceArray_IMPL(Vec v, const PetscScalar *a) noexcept
{
  return D::VecPlaceArray_IMPL_(v, a);
}
519 
// Forward to the derived class' host creation routine; 'alloc_missing' reports whether the
// host array still needs to be allocated
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::VecCreate_IMPL_Private(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *host_array) noexcept
{
  return D::VecCreate_IMPL_Private_(v, alloc_missing, nghost, host_array);
}
525 
526 template <device::cupm::DeviceType T, typename D>
527 inline constexpr PetscLogEvent Vec_CUPMBase<T, D>::VEC_CUPMCopyToGPU() noexcept
528 {
529   return T == device::cupm::DeviceType::CUDA ? VEC_CUDACopyToGPU : VEC_HIPCopyToGPU;
530 }
531 
532 template <device::cupm::DeviceType T, typename D>
533 inline constexpr PetscLogEvent Vec_CUPMBase<T, D>::VEC_CUPMCopyFromGPU() noexcept
534 {
535   return T == device::cupm::DeviceType::CUDA ? VEC_CUDACopyFromGPU : VEC_HIPCopyFromGPU;
536 }
537 
538 template <device::cupm::DeviceType T, typename D>
539 inline constexpr VecType Vec_CUPMBase<T, D>::VECSEQCUPM() noexcept
540 {
541   return T == device::cupm::DeviceType::CUDA ? VECSEQCUDA : VECSEQHIP;
542 }
543 
544 template <device::cupm::DeviceType T, typename D>
545 inline constexpr VecType Vec_CUPMBase<T, D>::VECMPICUPM() noexcept
546 {
547   return T == device::cupm::DeviceType::CUDA ? VECMPICUDA : VECMPIHIP;
548 }
549 
550 template <device::cupm::DeviceType T, typename D>
551 inline constexpr VecType Vec_CUPMBase<T, D>::VECCUPM() noexcept
552 {
553   return T == device::cupm::DeviceType::CUDA ? VECCUDA : VECHIP;
554 }
555 
// Device VecType of the derived class (dependent lookup; see VecIMPLCast() note above about
// why the dummy template parameter U is needed)
template <device::cupm::DeviceType T, typename D>
template <typename U>
inline constexpr VecType Vec_CUPMBase<T, D>::VECIMPLCUPM() noexcept
{
  return U::VECIMPLCUPM_();
}
562 
// Host VecType of the derived class (dependent lookup via dummy template parameter U)
template <device::cupm::DeviceType T, typename D>
template <typename U>
inline constexpr VecType Vec_CUPMBase<T, D>::VECIMPL() noexcept
{
  return U::VECIMPL_();
}
569 
// private version that takes a PetscDeviceContext, called by the public variant.
//
// Hands out the host or device array (per mtype), migrating/allocating data as dictated by
// the access mode, zero-filling on first READ of an unallocated vector, and updating the
// offload mask on WRITE. 'force' suppresses the trailing synchronization (async path).
template <device::cupm::DeviceType T, typename D>
template <PetscMemType mtype, PetscMemoryAccessMode access, bool force>
inline PetscErrorCode Vec_CUPMBase<T, D>::GetArray(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  constexpr auto hostmem     = PetscMemTypeHost(mtype);
  const auto     oldmask     = v->offloadmask; // mask before any migration below
  auto          &mask        = v->offloadmask;
  auto           should_sync = false;

  PetscFunctionBegin;
  static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (PetscMemoryAccessRead(access)) {
    // READ or READ_WRITE
    if (((oldmask == PETSC_OFFLOAD_GPU) && hostmem) || ((oldmask == PETSC_OFFLOAD_CPU) && !hostmem)) {
      // if we move the data we should set the flag to synchronize later on
      should_sync = true;
    }
    PetscCall((hostmem ? CopyToHost_ : CopyToDevice_)(dctx, v, force));
  } else {
    // WRITE only
    PetscCall((hostmem ? HostAllocateCheck_ : DeviceAllocateCheck_)(dctx, v));
  }
  *a = hostmem ? VecIMPLCast(v)->array : VecCUPMCast(v)->array_d;
  // if unallocated previously we should zero things out if we intend to read
  if (PetscMemoryAccessRead(access) && (oldmask == PETSC_OFFLOAD_UNALLOCATED)) {
    const auto n = v->map->n;

    if (hostmem) {
      PetscCall(PetscArrayzero(*a, n));
    } else {
      cupmStream_t stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCall(PetscCUPMMemsetAsync(*a, 0, n, stream, force));
      should_sync = true;
    }
  }
  // update the offloadmask if we intend to write, since we assume immediately modified
  if (PetscMemoryAccessWrite(access)) {
    PetscCall(VecSetErrorIfLocked(v, 1));
    // REVIEW ME: this should probably also call PetscObjectStateIncrease() since we assume it
    // is immediately modified
    mask = hostmem ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
  }
  // if we are a globally blocking stream and we have MOVED data then we should synchronize,
  // since even doing async calls on the NULL stream is not synchronous
  if (!force && should_sync) PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
621 
622 // v->ops->getarray[read|write] or VecCUPMGetArray[Read|Write]()
623 template <device::cupm::DeviceType T, typename D>
624 template <PetscMemType mtype, PetscMemoryAccessMode access, bool force>
625 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArray(Vec v, PetscScalar **a) noexcept
626 {
627   PetscDeviceContext dctx;
628 
629   PetscFunctionBegin;
630   PetscCall(GetHandles_(&dctx));
631   PetscCall(D::template GetArray<mtype, access, force>(v, a, dctx));
632   PetscFunctionReturn(PETSC_SUCCESS);
633 }
634 
// private version that takes a PetscDeviceContext, called by the public variant.
//
// Releases an array obtained from GetArray(): bumps the object state and updates the
// offload mask on WRITE access, and nulls the caller's pointer.
template <device::cupm::DeviceType T, typename D>
template <PetscMemType mtype, PetscMemoryAccessMode access>
inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArray(Vec v, PetscScalar **a, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (PetscMemoryAccessWrite(access)) {
    // WRITE or READ_WRITE
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
    v->offloadmask = PetscMemTypeHost(mtype) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
  }
  if (a) {
    // sanity-check the returned pointer lives in the memory space it claims to
    PetscCall(CheckPointerMatchesMemType_(*a, mtype));
    *a = nullptr;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
654 
655 // v->ops->restorearray[read|write] or VecCUPMRestoreArray[Read|Write]()
656 template <device::cupm::DeviceType T, typename D>
657 template <PetscMemType mtype, PetscMemoryAccessMode access>
658 inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArray(Vec v, PetscScalar **a) noexcept
659 {
660   PetscDeviceContext dctx;
661 
662   PetscFunctionBegin;
663   PetscCall(GetHandles_(&dctx));
664   PetscCall(D::template RestoreArray<mtype, access>(v, a, dctx));
665   PetscFunctionReturn(PETSC_SUCCESS);
666 }
667 
// Hand out the device array and report its memory type (nvshmem if the allocation came from
// nvshmem, otherwise the backend's native device memtype)
template <device::cupm::DeviceType T, typename D>
template <PetscMemoryAccessMode access>
inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrayAndMemtype(Vec v, PetscScalar **a, PetscMemType *mtype, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(D::template GetArray<PETSC_MEMTYPE_DEVICE, access>(v, a, dctx));
  if (mtype) *mtype = (PetscDefined(HAVE_NVSHMEM) && VecCUPMCast(v)->nvshmem) ? PETSC_MEMTYPE_NVSHMEM : PETSC_MEMTYPE_CUPM();
  PetscFunctionReturn(PETSC_SUCCESS);
}
677 
678 // v->ops->getarrayandmemtype
679 template <device::cupm::DeviceType T, typename D>
680 template <PetscMemoryAccessMode access>
681 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrayAndMemtype(Vec v, PetscScalar **a, PetscMemType *mtype) noexcept
682 {
683   PetscDeviceContext dctx;
684 
685   PetscFunctionBegin;
686   PetscCall(GetHandles_(&dctx));
687   PetscCall(D::template GetArrayAndMemtype<access>(v, a, mtype, dctx));
688   PetscFunctionReturn(PETSC_SUCCESS);
689 }
690 
// Restore an array handed out by GetArrayAndMemtype(); that interface always returns the
// device array, so restore the device side
template <device::cupm::DeviceType T, typename D>
template <PetscMemoryAccessMode access>
inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArrayAndMemtype(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(D::template RestoreArray<PETSC_MEMTYPE_DEVICE, access>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
699 
// v->ops->restorearrayandmemtype
//
// Convenience overload: fetches the current device context and forwards to the derived
// class's context-aware RestoreArrayAndMemtype()
template <device::cupm::DeviceType T, typename D>
template <PetscMemoryAccessMode access>
inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArrayAndMemtype(Vec v, PetscScalar **a) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(D::template RestoreArrayAndMemtype<access>(v, a, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
712 
// v->ops->placearray or VecCUPMPlaceArray()
//
// Temporarily replace the vector's host (mtype == PETSC_MEMTYPE_HOST) or device array with
// a user-provided pointer `a`. For the device case the displaced device pointer is stashed
// in VecIMPLCast(v)->unplacedarray so that ResetArray() can later restore it; calling
// PlaceArray() twice without an intervening ResetArray() is an error
template <device::cupm::DeviceType T, typename D>
template <PetscMemType mtype>
inline PetscErrorCode Vec_CUPMBase<T, D>::PlaceArray(Vec v, const PetscScalar *a) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(CheckPointerMatchesMemType_(a, mtype));
  PetscCall(GetHandles_(&dctx));
  if (PetscMemTypeHost(mtype)) {
    // make sure the host side is current before the IMPL place-array takes over
    PetscCall(CopyToHost_(dctx, v));
    PetscCall(VecPlaceArray_IMPL(v, a));
    v->offloadmask = PETSC_OFFLOAD_CPU;
  } else {
    PetscCall(VecIMPLAllocateCheck_(v));
    {
      // note: the IMPL's unplacedarray slot is (re)used here to stash the displaced
      // *device* pointer; ResetArray() reads it back from the same slot
      auto &backup_array = VecIMPLCast(v)->unplacedarray;

      PetscCheck(!backup_array, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "VecPlaceArray() was already called on this vector, without a call to VecResetArray()");
      PetscCall(CopyToDevice_(dctx, v));
      PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
      // swap the user pointer in, saving the old device array for ResetArray()
      backup_array = util::exchange(VecCUPMCast(v)->array_d, const_cast<PetscScalar *>(a));
      // only update the offload mask if we actually assign a pointer
      if (a) v->offloadmask = PETSC_OFFLOAD_GPU;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
744 
// v->ops->replacearray or VecCUPMReplaceArray()
//
// Permanently replace the vector's host or device array with the user-provided pointer `a`,
// freeing the previously owned allocation. Unlike PlaceArray() there is no way to undo this;
// the vector takes ownership of `a`
template <device::cupm::DeviceType T, typename D>
template <PetscMemType mtype>
inline PetscErrorCode Vec_CUPMBase<T, D>::ReplaceArray(Vec v, const PetscScalar *a) noexcept
{
  const auto         aptr = const_cast<PetscScalar *>(a);
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(CheckPointerMatchesMemType_(a, mtype));
  PetscCall(GetHandles_(&dctx));
  if (PetscMemTypeHost(mtype)) {
    PetscCall(VecIMPLAllocateCheck_(v));
    {
      const auto vimpl      = VecIMPLCast(v);
      auto      &host_array = vimpl->array_allocated;

      // make sure the users array has the latest values.
      // REVIEW ME: why? we're about to free it
      if (host_array != vimpl->array) PetscCall(CopyToHost_(dctx, v));
      if (host_array) {
        // RAII guard: selects the pinned-memory deallocator iff the old array was pinned
        const auto useit = UseCUPMHostAlloc(v->pinned_memory);

        PetscCall(PetscFree(host_array));
      }
      // vector now owns `aptr` as both its working and its allocated host array
      host_array       = aptr;
      vimpl->array     = host_array;
      v->pinned_memory = PETSC_FALSE; // REVIEW ME: we can determine this
      v->offloadmask   = PETSC_OFFLOAD_CPU;
    }
  } else {
    PetscCall(VecCUPMAllocateCheck_(v));
    {
      const auto vcu = VecCUPMCast(v);

      // frees the old device allocation and installs aptr as the owned device pointer
      PetscCall(ResetAllocatedDevicePtr_(dctx, v, aptr));
      // don't update the offloadmask if placed pointer is NULL
      vcu->array_d = vcu->array_allocated_d /* = aptr */;
      if (aptr) v->offloadmask = PETSC_OFFLOAD_GPU;
    }
  }
  PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
791 
// v->ops->resetarray or VecCUPMResetArray()
//
// Undo a prior PlaceArray(): restore the original host or device array. For the device case
// the stashed device pointer is retrieved from VecIMPLCast(v)->unplacedarray (where
// PlaceArray() put it) and the offload mask is recomputed from what storage remains valid
template <device::cupm::DeviceType T, typename D>
template <PetscMemType mtype>
inline PetscErrorCode Vec_CUPMBase<T, D>::ResetArray(Vec v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(GetHandles_(&dctx));
  // REVIEW ME:
  // this is wildly inefficient but must be done if we assume that the placed array must have
  // correct values
  if (PetscMemTypeHost(mtype)) {
    PetscCall(CopyToHost_(dctx, v));
    PetscCall(VecResetArray_IMPL(v));
    v->offloadmask = PETSC_OFFLOAD_CPU;
  } else {
    PetscCall(VecIMPLAllocateCheck_(v));
    PetscCall(VecCUPMAllocateCheck_(v));
    {
      const auto vcu        = VecCUPMCast(v);
      const auto vimpl      = VecIMPLCast(v);
      // despite the name, this slot holds the *device* pointer stashed by PlaceArray()
      auto      &host_array = vimpl->unplacedarray;

      PetscCall(CheckPointerMatchesMemType_(host_array, PETSC_MEMTYPE_DEVICE));
      PetscCall(CopyToDevice_(dctx, v));
      PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
      // Need to reset the offloadmask. If we had a stashed pointer we are on the GPU,
      // otherwise check if the host has a valid pointer. If neither, then we are not
      // allocated.
      vcu->array_d = host_array;
      if (host_array) {
        host_array     = nullptr;
        v->offloadmask = PETSC_OFFLOAD_GPU;
      } else if (vimpl->array) {
        v->offloadmask = PETSC_OFFLOAD_CPU;
      } else {
        v->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
837 
// v->ops->create
//
// Creates the IMPL (host) side of the vector, then runs the common CUPM initialization with
// no user-supplied arrays; storage is allocated only if the IMPL creation reported it missing
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::Create(Vec v) noexcept
{
  PetscBool          alloc_missing;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(VecCreate_IMPL_Private(v, &alloc_missing));
  PetscCall(GetHandles_(&dctx));
  PetscCall(Initialize_CUPMBase(v, alloc_missing, nullptr, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
851 
// v->ops->destroy
//
// Tears down the device-side resources first, then the host-side ones
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::Destroy(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(DestroyDevice_(v));
  PetscCall(DestroyHost_(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
861 
862 // ================================================================================== //
863 //                      Common core between Seq and MPI                               //
864 
865 // VecCreate_CUPM()
866 template <device::cupm::DeviceType T, typename D>
867 inline PetscErrorCode Vec_CUPMBase<T, D>::Create_CUPM(Vec v) noexcept
868 {
869   PetscMPIInt size;
870 
871   PetscFunctionBegin;
872   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
873   PetscCall(VecSetType(v, size > 1 ? VECMPICUPM() : VECSEQCUPM()));
874   PetscFunctionReturn(PETSC_SUCCESS);
875 }
876 
// VecCreateCUPM()
//
// Common creation helper: makes a vector on `comm`, optionally copying the layout of
// `reference` (done before VecSetSizes() so the sizes agree with it), applies block size
// only when nonzero, and optionally sets the concrete CUPM type
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::Create_CUPMBase(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type, PetscLayout reference) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCreate(comm, v));
  if (reference) PetscCall(PetscLayoutReference(reference, &(*v)->map));
  PetscCall(VecSetSizes(*v, n, N));
  if (bs) PetscCall(VecSetBlockSize(*v, bs));
  if (call_set_type) PetscCall(VecSetType(*v, VECIMPLCUPM()));
  PetscFunctionReturn(PETSC_SUCCESS);
}
889 
// VecCreateIMPL_CUPM(), called through v->ops->create
//
// Common initialization: renames the object to the CUPM type, installs the device ops via
// BindToCPU(PETSC_FALSE), adopts any caller-supplied host/device arrays, and either
// allocates-and-zeros missing storage or derives the offload mask from which arrays were
// provided
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::Initialize_CUPMBase(Vec v, PetscBool allocate_missing, PetscScalar *host_array, PetscScalar *device_array, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: perhaps not needed
  PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
  PetscCall(PetscObjectChangeTypeName(PetscObjectCast(v), VECIMPLCUPM()));
  PetscCall(D::BindToCPU(v, PETSC_FALSE));
  if (device_array) {
    PetscCall(CheckPointerMatchesMemType_(device_array, PETSC_MEMTYPE_CUPM()));
    PetscCall(VecCUPMAllocateCheck_(v));
    VecCUPMCast(v)->array_d = device_array;
  }
  if (host_array) {
    PetscCall(CheckPointerMatchesMemType_(host_array, PETSC_MEMTYPE_HOST));
    VecIMPLCast(v)->array = host_array;
  }
  if (allocate_missing) {
    PetscCall(DeviceAllocateCheck_(dctx, v));
    PetscCall(HostAllocateCheck_(dctx, v));
    // REVIEW ME: junchao, is this needed with new calloc() branch? VecSet() will call
    // set() for reference
    // calls device-version
    PetscCall(VecSet(v, 0));
    // zero the host while device is underway
    PetscCall(PetscArrayzero(VecIMPLCast(v)->array, v->map->n));
    v->offloadmask = PETSC_OFFLOAD_BOTH;
  } else {
    // offload mask reflects exactly which arrays the caller handed us
    if (host_array) {
      v->offloadmask = device_array ? PETSC_OFFLOAD_BOTH : PETSC_OFFLOAD_CPU;
    } else {
      v->offloadmask = device_array ? PETSC_OFFLOAD_GPU : PETSC_OFFLOAD_UNALLOCATED;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
927 
928 // v->ops->duplicate
929 template <device::cupm::DeviceType T, typename D>
930 template <typename SetupFunctionT>
931 inline PetscErrorCode Vec_CUPMBase<T, D>::Duplicate_CUPMBase(Vec v, Vec *y, PetscDeviceContext dctx, SetupFunctionT &&DerivedCreateIMPLCUPM_Async) noexcept
932 {
933   // if the derived setup is the default no_op then we should call VecSetType()
934   constexpr auto call_set_type = static_cast<PetscBool>(std::is_same<SetupFunctionT, no_op>::value);
935   const auto     vobj          = PetscObjectCast(v);
936   const auto     map           = v->map;
937   PetscInt       bs;
938 
939   PetscFunctionBegin;
940   PetscCall(VecGetBlockSize(v, &bs));
941   PetscCall(Create_CUPMBase(PetscObjectComm(vobj), bs, map->n, map->N, y, call_set_type, map));
942   // Derived class can set up the remainder of the data structures here
943   PetscCall(DerivedCreateIMPLCUPM_Async(*y));
944   // If the other vector is bound to CPU then the memcpy of the ops struct will give the
945   // duplicated vector the host "getarray" function which does not lazily allocate the array
946   // (as it is assumed to always exist). So we force allocation here, before we overwrite the
947   // ops
948   if (v->boundtocpu) PetscCall(HostAllocateCheck_(dctx, *y));
949   // in case the user has done some VecSetOps() tomfoolery
950   (*y)->ops[0] = v->ops[0];
951   {
952     const auto yobj = PetscObjectCast(*y);
953 
954     PetscCall(PetscObjectListDuplicate(vobj->olist, &yobj->olist));
955     PetscCall(PetscFunctionListDuplicate(vobj->qlist, &yobj->qlist));
956   }
957   (*y)->stash.donotstash   = v->stash.donotstash;
958   (*y)->stash.ignorenegidx = v->stash.ignorenegidx;
959   (*y)->map->bs            = std::abs(v->map->bs);
960   (*y)->bstash.bs          = v->bstash.bs;
961   PetscFunctionReturn(PETSC_SUCCESS);
962 }
963 
  // Install a single vec op: selects op_host when the vector is bound to the CPU, otherwise
  // the device implementation given in __VA_ARGS__. Relies on `v` and `usehost` being in
  // scope at the expansion site (see BindToCPU_CUPMBase())
  #define VecSetOp_CUPM(op_name, op_host, ...) \
    do { \
      if (usehost) { \
        v->ops->op_name = op_host; \
      } else { \
        v->ops->op_name = __VA_ARGS__; \
      } \
    } while (0)
972 
// v->ops->bindtocpu
//
// Common bind-to-CPU implementation: records the binding, syncs data to the host when
// binding, switches the default random type, and (re)populates the ops table with either
// host or device implementations via VecSetOp_CUPM
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::BindToCPU_CUPMBase(Vec v, PetscBool usehost, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  v->boundtocpu = usehost;
  if (usehost) PetscCall(CopyToHost_(dctx, v));
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &v->defaultrandtype));

  // set the base functions that are guaranteed to be the same for both
  v->ops->duplicate = D::Duplicate;
  v->ops->create    = D::Create;
  v->ops->destroy   = D::Destroy;
  v->ops->bindtocpu = D::BindToCPU;
  // Note that setting these to NULL on host breaks convergence in certain areas. I don't know
  // why, and I don't know how, but it is IMPERATIVE these are set as such!
  v->ops->replacearray = D::template ReplaceArray<PETSC_MEMTYPE_HOST>;
  v->ops->restorearray = D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>;

  // set device-only common functions
  VecSetOp_CUPM(dotnorm2, nullptr, D::DotNorm2);
  VecSetOp_CUPM(getarray, nullptr, D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>);
  VecSetOp_CUPM(getarraywrite, nullptr, D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>);
  VecSetOp_CUPM(restorearraywrite, nullptr, D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>);

  // const-qualified signatures are adapted via lambdas that cast away the const
  VecSetOp_CUPM(getarrayread, nullptr, [](Vec v, const PetscScalar **a) { return D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });
  VecSetOp_CUPM(restorearrayread, nullptr, [](Vec v, const PetscScalar **a) { return D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });

  VecSetOp_CUPM(getarrayandmemtype, nullptr, D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_READ_WRITE>);
  VecSetOp_CUPM(restorearrayandmemtype, nullptr, D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_READ_WRITE>);

  VecSetOp_CUPM(getarraywriteandmemtype, nullptr, D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_WRITE>);
  VecSetOp_CUPM(restorearraywriteandmemtype, nullptr, [](Vec v, PetscScalar **a, PetscMemType *) { return D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_WRITE>(v, a); });

  VecSetOp_CUPM(getarrayreadandmemtype, nullptr, [](Vec v, const PetscScalar **a, PetscMemType *m) { return D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a), m); });
  VecSetOp_CUPM(restorearrayreadandmemtype, nullptr, [](Vec v, const PetscScalar **a) { return D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });

  // set the functions that are always sequential
  using VecSeq_T = VecSeq_CUPM<T>;
  VecSetOp_CUPM(scale, VecScale_Seq, VecSeq_T::Scale);
  VecSetOp_CUPM(copy, VecCopy_Seq, VecSeq_T::Copy);
  VecSetOp_CUPM(set, VecSet_Seq, VecSeq_T::Set);
  VecSetOp_CUPM(swap, VecSwap_Seq, VecSeq_T::Swap);
  VecSetOp_CUPM(axpy, VecAXPY_Seq, VecSeq_T::AXPY);
  VecSetOp_CUPM(axpby, VecAXPBY_Seq, VecSeq_T::AXPBY);
  VecSetOp_CUPM(maxpy, VecMAXPY_Seq, VecSeq_T::MAXPY);
  VecSetOp_CUPM(aypx, VecAYPX_Seq, VecSeq_T::AYPX);
  VecSetOp_CUPM(waxpy, VecWAXPY_Seq, VecSeq_T::WAXPY);
  VecSetOp_CUPM(axpbypcz, VecAXPBYPCZ_Seq, VecSeq_T::AXPBYPCZ);
  VecSetOp_CUPM(pointwisemult, VecPointwiseMult_Seq, VecSeq_T::PointwiseMult);
  VecSetOp_CUPM(pointwisedivide, VecPointwiseDivide_Seq, VecSeq_T::PointwiseDivide);
  VecSetOp_CUPM(setrandom, VecSetRandom_Seq, VecSeq_T::SetRandom);
  VecSetOp_CUPM(dot_local, VecDot_Seq, VecSeq_T::Dot);
  VecSetOp_CUPM(tdot_local, VecTDot_Seq, VecSeq_T::TDot);
  VecSetOp_CUPM(norm_local, VecNorm_Seq, VecSeq_T::Norm);
  VecSetOp_CUPM(mdot_local, VecMDot_Seq, VecSeq_T::MDot);
  VecSetOp_CUPM(reciprocal, VecReciprocal_Default, VecSeq_T::Reciprocal);
  VecSetOp_CUPM(shift, nullptr, VecSeq_T::Shift);
  VecSetOp_CUPM(getlocalvector, nullptr, VecSeq_T::template GetLocalVector<PETSC_MEMORY_ACCESS_READ_WRITE>);
  VecSetOp_CUPM(restorelocalvector, nullptr, VecSeq_T::template RestoreLocalVector<PETSC_MEMORY_ACCESS_READ_WRITE>);
  VecSetOp_CUPM(getlocalvectorread, nullptr, VecSeq_T::template GetLocalVector<PETSC_MEMORY_ACCESS_READ>);
  VecSetOp_CUPM(restorelocalvectorread, nullptr, VecSeq_T::template RestoreLocalVector<PETSC_MEMORY_ACCESS_READ>);
  VecSetOp_CUPM(sum, nullptr, VecSeq_T::Sum);
  PetscFunctionReturn(PETSC_SUCCESS);
}
1038 
// Called from VecGetSubVector()
//
// Hand out raw host and/or device pointers (allocating each side on demand) together with
// the current offload mask. Purely observational otherwise: the offload mask itself is not
// modified here
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrays_CUPMBase(Vec v, const PetscScalar **host_array, const PetscScalar **device_array, PetscOffloadMask *mask, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (host_array) {
    PetscCall(HostAllocateCheck_(dctx, v));
    *host_array = VecIMPLCast(v)->array;
  }
  if (device_array) {
    PetscCall(DeviceAllocateCheck_(dctx, v));
    *device_array = VecCUPMCast(v)->array_d;
  }
  if (mask) *mask = v->offloadmask;
  PetscFunctionReturn(PETSC_SUCCESS);
}
1056 
// Free all device-side COO preallocation buffers (maps, permutations, and send/recv
// buffers) asynchronously on the vector's stream. No-op when the device side of the vector
// was never instantiated
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::ResetPreallocationCOO_CUPMBase(Vec v, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (const auto vcu = VecCUPMCast(v)) {
    cupmStream_t stream;
    // clang-format off
    const auto   cntptrs = util::make_array(
      std::ref(vcu->jmap1_d),
      std::ref(vcu->perm1_d),
      std::ref(vcu->imap2_d),
      std::ref(vcu->jmap2_d),
      std::ref(vcu->perm2_d),
      std::ref(vcu->Cperm_d)
    );
    // clang-format on

    PetscCall(GetHandlesFrom_(dctx, &stream));
    // free the count/permutation arrays, then the scalar send/recv buffers
    for (auto &&ptr : cntptrs) PetscCallCUPM(cupmFreeAsync(ptr.get(), stream));
    for (auto &&ptr : util::make_array(std::ref(vcu->sendbuf_d), std::ref(vcu->recvbuf_d))) PetscCallCUPM(cupmFreeAsync(ptr.get(), stream));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1080 
// Common COO preallocation: frees any previous device COO state, then allocates device
// mirrors for the base jmap1/perm1 maps plus any derived-class extras (extra_cntptrs) and
// scalar buffers (bufptrs), and copies the host data over on the given stream
template <device::cupm::DeviceType T, typename D>
template <std::size_t NCount, std::size_t NScal>
inline PetscErrorCode Vec_CUPMBase<T, D>::SetPreallocationCOO_CUPMBase(Vec v, PetscCount, const PetscInt[], PetscDeviceContext dctx, const std::array<CooPair<PetscCount>, NCount> &extra_cntptrs, const std::array<CooPair<PetscScalar>, NScal> &bufptrs) noexcept
{
  PetscFunctionBegin;
  PetscCall(ResetPreallocationCOO_CUPMBase(v, dctx));
  // need to instantiate the private pointer if not already
  PetscCall(VecCUPMAllocateCheck_(v));
  {
    const auto vimpl = VecIMPLCast(v);
    const auto vcu   = VecCUPMCast(v);
    // clang-format off
    const auto cntptrs = util::concat_array(
      util::make_array(
        make_coo_pair(vcu->jmap1_d, vimpl->jmap1, v->map->n + 1),
        make_coo_pair(vcu->perm1_d, vimpl->perm1, vimpl->tot1)
      ),
      extra_cntptrs
    );
    // clang-format on
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    // allocate
    for (auto &elem : cntptrs) PetscCall(PetscCUPMMallocAsync(&elem.device, elem.size, stream));
    for (auto &elem : bufptrs) PetscCall(PetscCUPMMallocAsync(&elem.device, elem.size, stream));
    // copy
    for (const auto &elem : cntptrs) PetscCall(PetscCUPMMemcpyAsync(elem.device, elem.host, elem.size, cupmMemcpyHostToDevice, stream, true));
    for (const auto &elem : bufptrs) PetscCall(PetscCUPMMemcpyAsync(elem.device, elem.host, elem.size, cupmMemcpyHostToDevice, stream, true));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
1113 
// Convert a plain IMPL (host) vector in place to its CUPM flavor. If the user requested
// pinned memory (and the array is large enough to warrant it), the existing host array is
// migrated to a freshly pinned allocation first; the (possibly migrated) host array is then
// handed to the common initializer
template <device::cupm::DeviceType T, typename D>
inline PetscErrorCode Vec_CUPMBase<T, D>::Convert_IMPL_IMPLCUPM(Vec v) noexcept
{
  const auto         n        = v->map->n;
  const auto         vimpl    = VecIMPLCast(v);
  auto              &impl_arr = vimpl->array;
  PetscBool          set      = PETSC_FALSE;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  // If users do not explicitly require pinned memory, we prefer keeping the vector's regular
  // host array
  PetscCall(VecCUPMCheckMinimumPinnedMemory_Internal(v, &set));
  if (set && impl_arr && ((n * sizeof(*impl_arr)) > v->minimum_bytes_pinned_memory)) {
    auto        &impl_alloc = vimpl->array_allocated;
    PetscScalar *new_arr;

    // users require pinned memory
    {
      // Allocate pinned memory and copy over the old array
      const auto useit = UseCUPMHostAlloc(PETSC_TRUE);

      PetscCall(PetscMalloc1(n, &new_arr));
      PetscCall(PetscArraycpy(new_arr, impl_arr, n));
    }
    // free the old (unpinned) allocation and install the pinned one as both the working and
    // the owned array
    PetscCall(PetscFree(impl_alloc));
    impl_arr         = new_arr;
    impl_alloc       = new_arr;
    v->offloadmask   = PETSC_OFFLOAD_CPU;
    v->pinned_memory = PETSC_TRUE;
  }
  PetscCall(GetHandles_(&dctx));
  // no device array is supplied; allocate_missing = PETSC_FALSE keeps device storage lazy
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, impl_arr, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1149 
  // Boilerplate for classes deriving from Vec_CUPMBase: aliases the base as `name`,
  // befriends it, and re-exports the inherited introspection, utility, and base-function
  // members into the derived class's scope
  #define PETSC_VEC_CUPM_BASE_CLASS_HEADER(name, Tp, ...) \
    PETSC_CUPMOBJECT_HEADER(Tp); \
    using name = ::Petsc::vec::cupm::impl::Vec_CUPMBase<Tp, __VA_ARGS__>; \
    friend name; \
    /* introspection */ \
    using name::VecCUPMCast; \
    using name::VecIMPLCast; \
    using name::VECIMPLCUPM; \
    using name::VECIMPL; \
    using name::VECSEQCUPM; \
    using name::VECMPICUPM; \
    using name::VECCUPM; \
    using name::VecView_Debug; \
    /* utility */ \
    using typename name::Vec_CUPM; \
    using name::VecCUPMAllocateCheck_; \
    using name::VecIMPLAllocateCheck_; \
    using name::HostAllocateCheck_; \
    using name::DeviceAllocateCheck_; \
    using name::CopyToDevice_; \
    using name::CopyToHost_; \
    using name::Create; \
    using name::Destroy; \
    using name::GetArray; \
    using name::RestoreArray; \
    using name::GetArrayAndMemtype; \
    using name::RestoreArrayAndMemtype; \
    using name::PlaceArray; \
    using name::ReplaceArray; \
    using name::ResetArray; \
    /* base functions */ \
    using name::Create_CUPMBase; \
    using name::Initialize_CUPMBase; \
    using name::Duplicate_CUPMBase; \
    using name::BindToCPU_CUPMBase; \
    using name::Create_CUPM; \
    using name::DeviceArrayRead; \
    using name::DeviceArrayWrite; \
    using name::DeviceArrayReadWrite; \
    using name::HostArrayRead; \
    using name::HostArrayWrite; \
    using name::HostArrayReadWrite; \
    using name::ResetPreallocationCOO_CUPMBase; \
    using name::SetPreallocationCOO_CUPMBase; \
    using name::Convert_IMPL_IMPLCUPM;
1195 
1196 } // namespace impl
1197 
1198 } // namespace cupm
1199 
1200 } // namespace vec
1201 
1202 } // namespace Petsc
1203 
1204 #endif // __cplusplus && PetscDefined(HAVE_DEVICE)
1205 
1206 #endif // PETSCVECCUPMIMPL_H
1207