xref: /petsc/include/petsc/private/veccupmimpl.h (revision 390d3996fed1e5c2d89175ebb68cf53cdee176f4) !
1 #pragma once
2 
3 #include <petsc/private/vecimpl.h>
4 #include <../src/vec/vec/impls/dvecimpl.h> // for Vec_Seq
5 
6 #if PetscDefined(HAVE_NVSHMEM)
7 PETSC_INTERN PetscErrorCode PetscNvshmemInitializeCheck(void);
8 PETSC_INTERN PetscErrorCode PetscNvshmemMalloc(size_t, void **);
9 PETSC_INTERN PetscErrorCode PetscNvshmemCalloc(size_t, void **);
10 PETSC_INTERN PetscErrorCode PetscNvshmemFree_Private(void *);
11   #define PetscNvshmemFree(ptr) ((PetscErrorCode)((ptr) && (PetscNvshmemFree_Private(ptr) || ((ptr) = PETSC_NULLPTR, PETSC_SUCCESS))))
12 PETSC_INTERN PetscErrorCode PetscNvshmemSum(PetscInt, PetscScalar *, const PetscScalar *);
13 PETSC_INTERN PetscErrorCode PetscNvshmemMax(PetscInt, PetscReal *, const PetscReal *);
14 PETSC_INTERN PetscErrorCode VecAllocateNVSHMEM_SeqCUDA(Vec);
15 #else
16   #define PetscNvshmemFree(ptr) PETSC_SUCCESS
17 #endif
18 
19 #if defined(__cplusplus) && PetscDefined(HAVE_DEVICE)
20   #include <petsc/private/deviceimpl.h>
21   #include <petsc/private/cupmobject.hpp>
22   #include <petsc/private/cupmblasinterface.hpp>
23 
24   #include <petsc/private/cpp/functional.hpp>
25 
26   #include <limits> // std::numeric_limits
27 
28 namespace Petsc
29 {
30 
31 namespace vec
32 {
33 
34 namespace cupm
35 {
36 
37 namespace impl
38 {
39 
40 namespace
41 {
42 
43 struct no_op {
44   template <typename... T>
operatorno_op45   constexpr PetscErrorCode operator()(T &&...) const noexcept
46   {
47     return PETSC_SUCCESS;
48   }
49 };
50 
51 template <typename T>
52 struct CooPair {
53   using value_type = T;
54   using size_type  = PetscCount;
55 
56   value_type *&device;
57   value_type *&host;
58   size_type    size;
59 };
60 
61 template <typename U>
make_coo_pair(U * & device,U * & host,PetscCount size)62 static constexpr CooPair<U> make_coo_pair(U *&device, U *&host, PetscCount size) noexcept
63 {
64   return {device, host, size};
65 }
66 
67 } // anonymous namespace
68 
69 // forward declarations
70 template <device::cupm::DeviceType>
71 class VecSeq_CUPM;
72 template <device::cupm::DeviceType>
73 class VecMPI_CUPM;
74 
75 // ==========================================================================================
76 // Vec_CUPMBase
77 //
78 // Base class for the VecSeq and VecMPI CUPM implementations. On top of the usual DeviceType
79 // template parameter it also uses CRTP to be able to use values/calls specific to either
80 // VecSeq or VecMPI. This is in effect "inside-out" polymorphism.
81 // ==========================================================================================
82 template <device::cupm::DeviceType T, typename Derived>
83 class Vec_CUPMBase : protected device::cupm::impl::CUPMObject<T> {
84 public:
85   PETSC_CUPMOBJECT_HEADER(T);
86 
87   // ==========================================================================================
88   // Vec_CUPMBase::VectorArray
89   //
90   // RAII versions of the get/restore array routines. Determines constness of the pointer type,
91   // holds the pointer itself provides the implicit conversion operator
92   // ==========================================================================================
93   template <PetscMemType, PetscMemoryAccessMode>
94   class VectorArray;
95 
96 protected:
97   static PetscErrorCode VecView_Debug(Vec v, const char *message = "") noexcept
98   {
99     const auto   pobj  = PetscObjectCast(v);
100     const auto   vimpl = VecIMPLCast(v);
101     const auto   vcu   = VecCUPMCast(v);
102     PetscMemType mtype;
103     MPI_Comm     comm;
104 
105     PetscFunctionBegin;
106     PetscAssertPointer(vimpl, 1);
107     PetscAssertPointer(vcu, 1);
108     PetscCall(PetscObjectGetComm(pobj, &comm));
109     PetscCall(PetscPrintf(comm, "---------- %s ----------\n", message));
110     PetscCall(PetscObjectPrintClassNamePrefixType(pobj, PETSC_VIEWER_STDOUT_(comm)));
111     PetscCall(PetscPrintf(comm, "Address:             %p\n", v));
112     PetscCall(PetscPrintf(comm, "Size:                %" PetscInt_FMT "\n", v->map->n));
113     PetscCall(PetscPrintf(comm, "Offload mask:        %s\n", PetscOffloadMaskToString(v->offloadmask)));
114     PetscCall(PetscPrintf(comm, "Host ptr:            %p\n", vimpl->array));
115     PetscCall(PetscPrintf(comm, "Device ptr:          %p\n", vcu->array_d));
116     PetscCall(PetscPrintf(comm, "Device alloced ptr:  %p\n", vcu->array_allocated_d));
117     PetscCall(PetscCUPMGetMemType(vcu->array_d, &mtype));
118     PetscCall(PetscPrintf(comm, "dptr is device mem?  %s\n", PetscBools[static_cast<PetscBool>(PetscMemTypeDevice(mtype))]));
119     PetscFunctionReturn(PETSC_SUCCESS);
120   }
121 
122   // Delete the allocated device array if required and replace it with the given array
123   static PetscErrorCode ResetAllocatedDevicePtr_(PetscDeviceContext, Vec, PetscScalar * = nullptr) noexcept;
124   // Check either the host or device impl pointer is allocated and allocate it if
125   // isn't. CastFunctionType casts the Vec to the required type and returns the pointer
126   template <typename CastFunctionType>
127   static PetscErrorCode VecAllocateCheck_(Vec, void *&, CastFunctionType &&) noexcept;
128   // Check the CUPM part (v->spptr) is allocated, otherwise allocate it
129   static PetscErrorCode VecCUPMAllocateCheck_(Vec) noexcept;
130   // Check the Host part (v->data) is allocated, otherwise allocate it
131   static PetscErrorCode VecIMPLAllocateCheck_(Vec) noexcept;
132   // Check the Host array is allocated, otherwise allocate it
133   static PetscErrorCode HostAllocateCheck_(PetscDeviceContext, Vec) noexcept;
134   // Check the CUPM array is allocated, otherwise allocate it
135   static PetscErrorCode DeviceAllocateCheck_(PetscDeviceContext, Vec) noexcept;
136   // Copy HTOD, allocating device if necessary
137   static PetscErrorCode CopyToDevice_(PetscDeviceContext, Vec, bool = false) noexcept;
138   // Copy DTOH, allocating host if necessary
139   static PetscErrorCode CopyToHost_(PetscDeviceContext, Vec, bool = false) noexcept;
140   static PetscErrorCode DestroyDevice_(Vec) noexcept;
141   static PetscErrorCode DestroyHost_(Vec) noexcept;
142 
143 public:
144   struct Vec_CUPM {
145     PetscScalar *array_d;           // gpu data
146     PetscScalar *array_allocated_d; // does PETSc own the array ptr?
147     PetscBool    nvshmem;           // is array allocated in nvshmem? It is used to allocate
148                                     // Mvctx->lvec in nvshmem
149 
150     // COO stuff
151     PetscCount *jmap1_d; // [m+1]: i-th entry of the vector has jmap1[i+1]-jmap1[i] repeats
152                          // in COO arrays
153     PetscCount *perm1_d; // [tot1]: permutation array for local entries
154     PetscCount *imap2_d; // [nnz2]: i-th unique entry in recvbuf is imap2[i]-th entry in
155                          // the vector
156     PetscCount *jmap2_d; // [nnz2+1]
157     PetscCount *perm2_d; // [recvlen]
158     PetscCount *Cperm_d; // [sendlen]: permutation array to fill sendbuf[]. 'C' for
159                          // communication
160 
161     // Buffers for remote values in VecSetValuesCOO()
162     PetscScalar *sendbuf_d;
163     PetscScalar *recvbuf_d;
164   };
165 
166   // Cast the Vec to its Vec_CUPM struct, i.e. return the result of (Vec_CUPM *)v->spptr
167   PETSC_NODISCARD static Vec_CUPM *VecCUPMCast(Vec) noexcept;
168   // Cast the Vec to its host struct, i.e. return the result of (Vec_Seq *)v->data
169   template <typename U = Derived>
170   PETSC_NODISCARD static constexpr auto VecIMPLCast(Vec v) noexcept -> decltype(U::VecIMPLCast_(v));
171   // Get the PetscLogEvents for HTOD and DTOH
172   PETSC_NODISCARD static constexpr PetscLogEvent VEC_CUPMCopyToGPU() noexcept;
173   PETSC_NODISCARD static constexpr PetscLogEvent VEC_CUPMCopyFromGPU() noexcept;
174   // Get the VecTypes
175   PETSC_NODISCARD static constexpr VecType VECSEQCUPM() noexcept;
176   PETSC_NODISCARD static constexpr VecType VECMPICUPM() noexcept;
177   PETSC_NODISCARD static constexpr VecType VECCUPM() noexcept;
178 
179   // Get the device VecType of the calling vector
180   template <typename U = Derived>
181   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM() noexcept;
182   // Get the host VecType of the calling vector
183   template <typename U = Derived>
184   PETSC_NODISCARD static constexpr VecType VECIMPL() noexcept;
185 
186   // Call the host destroy function, i.e. VecDestroy_Seq()
187   static PetscErrorCode VecDestroy_IMPL(Vec) noexcept;
188   // Call the host reset function, i.e. VecResetArray_Seq()
189   static PetscErrorCode VecResetArray_IMPL(Vec) noexcept;
190   // ... you get the idea
191   static PetscErrorCode VecPlaceArray_IMPL(Vec, const PetscScalar *) noexcept;
192   // Call the host creation function, i.e. VecCreate_Seq(), and also initialize the CUPM part
193   // along with it if needed
194   static PetscErrorCode VecCreate_IMPL_Private(Vec, PetscBool *, PetscInt = 0, PetscScalar * = nullptr) noexcept;
195 
196   // Shorthand for creating VectorArray's. Need functions to create them, otherwise using them
197   // as an unnamed temporary leads to most vexing parse
198   PETSC_NODISCARD static auto DeviceArrayRead(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>{dctx, v});
199   PETSC_NODISCARD static auto DeviceArrayWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>{dctx, v});
200   PETSC_NODISCARD static auto DeviceArrayReadWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, v});
201   PETSC_NODISCARD static auto HostArrayRead(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>{dctx, v});
202   PETSC_NODISCARD static auto HostArrayWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>{dctx, v});
203   PETSC_NODISCARD static auto HostArrayReadWrite(PetscDeviceContext dctx, Vec v) noexcept PETSC_DECLTYPE_AUTO_RETURNS(VectorArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>{dctx, v});
204 
205   // ops-table functions
206   static PetscErrorCode Create(Vec) noexcept;
207   static PetscErrorCode Destroy(Vec) noexcept;
208   template <PetscMemType, PetscMemoryAccessMode, bool = false>
209   static PetscErrorCode GetArray(Vec, PetscScalar **, PetscDeviceContext) noexcept;
210   template <PetscMemType, PetscMemoryAccessMode, bool = false>
211   static PetscErrorCode GetArray(Vec, PetscScalar **) noexcept;
212   template <PetscMemType, PetscMemoryAccessMode>
213   static PetscErrorCode RestoreArray(Vec, PetscScalar **, PetscDeviceContext) noexcept;
214   template <PetscMemType, PetscMemoryAccessMode>
215   static PetscErrorCode RestoreArray(Vec, PetscScalar **) noexcept;
216   template <PetscMemoryAccessMode>
217   static PetscErrorCode GetArrayAndMemtype(Vec, PetscScalar **, PetscMemType *, PetscDeviceContext) noexcept;
218   template <PetscMemoryAccessMode>
219   static PetscErrorCode GetArrayAndMemtype(Vec, PetscScalar **, PetscMemType *) noexcept;
220   template <PetscMemoryAccessMode>
221   static PetscErrorCode RestoreArrayAndMemtype(Vec, PetscScalar **, PetscDeviceContext) noexcept;
222   template <PetscMemoryAccessMode>
223   static PetscErrorCode RestoreArrayAndMemtype(Vec, PetscScalar **) noexcept;
224   template <PetscMemType>
225   static PetscErrorCode ReplaceArray(Vec, const PetscScalar *) noexcept;
226   template <PetscMemType>
227   static PetscErrorCode ResetArray(Vec) noexcept;
228   template <PetscMemType>
229   static PetscErrorCode PlaceArray(Vec, const PetscScalar *) noexcept;
230 
231   // common ops shared between Seq and MPI
232   static PetscErrorCode Create_CUPM(Vec) noexcept;
233   static PetscErrorCode Create_CUPMBase(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool, PetscLayout /*reference*/ = nullptr) noexcept;
234   static PetscErrorCode Initialize_CUPMBase(Vec, PetscBool, PetscScalar *, PetscScalar *, PetscDeviceContext) noexcept;
235   template <typename SetupFunctionT = no_op>
236   static PetscErrorCode Duplicate_CUPMBase(Vec, Vec *, PetscDeviceContext, SetupFunctionT && = SetupFunctionT{}) noexcept;
237   static PetscErrorCode BindToCPU_CUPMBase(Vec, PetscBool, PetscDeviceContext) noexcept;
238   static PetscErrorCode GetArrays_CUPMBase(Vec, const PetscScalar **, const PetscScalar **, PetscOffloadMask *, PetscDeviceContext) noexcept;
239   static PetscErrorCode ResetPreallocationCOO_CUPMBase(Vec, PetscDeviceContext) noexcept;
240   template <std::size_t NCount = 0, std::size_t NScal = 0>
241   static PetscErrorCode SetPreallocationCOO_CUPMBase(Vec, PetscCount, const PetscInt[], PetscDeviceContext, const std::array<CooPair<PetscCount>, NCount> & = {}, const std::array<CooPair<PetscScalar>, NScal> & = {}) noexcept;
242 
243   static PetscErrorCode Convert_IMPL_IMPLCUPM(Vec) noexcept;
244 };
245 
246 // ==========================================================================================
247 // Vec_CUPMBase::VectorArray
248 //
249 // RAII versions of the get/restore array routines. Determines constness of the pointer type,
250 // holds the pointer itself and provides the implicit conversion operator.
251 //
252 // On construction this calls the moral equivalent of Vec[CUPM]GetArray[Read|Write]()
253 // (depending on PetscMemoryAccessMode) and on destruction automatically restores the array
254 // for you
255 // ==========================================================================================
256 template <device::cupm::DeviceType T, typename D>
257 template <PetscMemType MT, PetscMemoryAccessMode MA>
258 class Vec_CUPMBase<T, D>::VectorArray : public device::cupm::impl::RestoreableArray<T, MT, MA> {
259   using base_type = device::cupm::impl::RestoreableArray<T, MT, MA>;
260 
261 public:
262   VectorArray(PetscDeviceContext, Vec) noexcept;
263   ~VectorArray() noexcept;
264 
265 private:
266   Vec v_ = nullptr;
267 };
268 
269 // ==========================================================================================
270 // Vec_CUPMBase::VectorArray - Public API
271 // ==========================================================================================
272 
273 template <device::cupm::DeviceType T, typename D>
274 template <PetscMemType MT, PetscMemoryAccessMode MA>
VectorArray(PetscDeviceContext dctx,Vec v)275 inline Vec_CUPMBase<T, D>::VectorArray<MT, MA>::VectorArray(PetscDeviceContext dctx, Vec v) noexcept : base_type{dctx}, v_{v}
276 {
277   PetscFunctionBegin;
278   PetscCallAbort(PETSC_COMM_SELF, Vec_CUPMBase<T, D>::template GetArray<MT, MA, true>(v, &this->ptr_, dctx));
279   PetscFunctionReturnVoid();
280 }
281 
282 template <device::cupm::DeviceType T, typename D>
283 template <PetscMemType MT, PetscMemoryAccessMode MA>
~VectorArray()284 inline Vec_CUPMBase<T, D>::VectorArray<MT, MA>::~VectorArray() noexcept
285 {
286   PetscFunctionBegin;
287   PetscCallAbort(PETSC_COMM_SELF, Vec_CUPMBase<T, D>::template RestoreArray<MT, MA>(v_, &this->ptr_, this->dctx_));
288   PetscFunctionReturnVoid();
289 }
290 
291 // ==========================================================================================
292 // Vec_CUPMBase - Protected API
293 // ==========================================================================================
294 
295 template <device::cupm::DeviceType T, typename D>
ResetAllocatedDevicePtr_(PetscDeviceContext dctx,Vec v,PetscScalar * new_value)296 inline PetscErrorCode Vec_CUPMBase<T, D>::ResetAllocatedDevicePtr_(PetscDeviceContext dctx, Vec v, PetscScalar *new_value) noexcept
297 {
298   auto &device_array = VecCUPMCast(v)->array_allocated_d;
299 
300   PetscFunctionBegin;
301   if (device_array) {
302     if (PetscDefined(HAVE_NVSHMEM) && VecCUPMCast(v)->nvshmem) {
303       PetscCall(PetscNvshmemFree(device_array));
304     } else {
305       cupmStream_t stream;
306 
307       PetscCall(GetHandlesFrom_(dctx, &stream));
308       PetscCallCUPM(cupmFreeAsync(device_array, stream));
309     }
310   }
311   device_array = new_value;
312   PetscFunctionReturn(PETSC_SUCCESS);
313 }
314 
315 namespace
316 {
317 
318 inline PetscErrorCode VecCUPMCheckMinimumPinnedMemory_Internal(Vec v, PetscBool *set = nullptr) noexcept
319 {
320   auto      mem = static_cast<PetscInt>(v->minimum_bytes_pinned_memory);
321   PetscBool flg;
322 
323   PetscFunctionBegin;
324   PetscObjectOptionsBegin(PetscObjectCast(v));
325   PetscCall(PetscOptionsRangeInt("-vec_pinned_memory_min", "Minimum size (in bytes) for an allocation to use pinned memory on host", "VecSetPinnedMemoryMin", mem, &mem, &flg, 0, std::numeric_limits<decltype(mem)>::max()));
326   if (flg) v->minimum_bytes_pinned_memory = mem;
327   PetscOptionsEnd();
328   if (set) *set = flg;
329   PetscFunctionReturn(PETSC_SUCCESS);
330 }
331 
332 } // anonymous namespace
333 
334 template <device::cupm::DeviceType T, typename D>
335 template <typename CastFunctionType>
VecAllocateCheck_(Vec v,void * & dest,CastFunctionType && cast)336 inline PetscErrorCode Vec_CUPMBase<T, D>::VecAllocateCheck_(Vec v, void *&dest, CastFunctionType &&cast) noexcept
337 {
338   PetscFunctionBegin;
339   if (PetscLikely(dest)) PetscFunctionReturn(PETSC_SUCCESS);
340   // do the check here so we don't have to do it in every function
341   PetscCall(checkCupmBlasIntCast(v->map->n));
342   {
343     auto impl = cast(v);
344 
345     PetscCall(PetscNew(&impl));
346     dest = impl;
347   }
348   PetscFunctionReturn(PETSC_SUCCESS);
349 }
350 
351 template <device::cupm::DeviceType T, typename D>
VecIMPLAllocateCheck_(Vec v)352 inline PetscErrorCode Vec_CUPMBase<T, D>::VecIMPLAllocateCheck_(Vec v) noexcept
353 {
354   PetscFunctionBegin;
355   PetscCall(VecAllocateCheck_(v, v->data, VecIMPLCast<D>));
356   PetscFunctionReturn(PETSC_SUCCESS);
357 }
358 
359 // allocate the Vec_CUPM struct. this is normally done through DeviceAllocateCheck_(), but in
360 // certain circumstances (such as when the user places the device array) we do not want to do
361 // the full DeviceAllocateCheck_() as it also allocates the array
362 template <device::cupm::DeviceType T, typename D>
VecCUPMAllocateCheck_(Vec v)363 inline PetscErrorCode Vec_CUPMBase<T, D>::VecCUPMAllocateCheck_(Vec v) noexcept
364 {
365   PetscFunctionBegin;
366   PetscCall(VecAllocateCheck_(v, v->spptr, VecCUPMCast));
367   PetscFunctionReturn(PETSC_SUCCESS);
368 }
369 
370 template <device::cupm::DeviceType T, typename D>
HostAllocateCheck_(PetscDeviceContext,Vec v)371 inline PetscErrorCode Vec_CUPMBase<T, D>::HostAllocateCheck_(PetscDeviceContext, Vec v) noexcept
372 {
373   PetscFunctionBegin;
374   PetscCall(VecIMPLAllocateCheck_(v));
375   if (auto &alloc = VecIMPLCast(v)->array_allocated) PetscFunctionReturn(PETSC_SUCCESS);
376   else {
377     PetscCall(VecCUPMCheckMinimumPinnedMemory_Internal(v));
378     {
379       const auto n     = v->map->n;
380       const auto useit = UseCUPMHostAlloc((n * sizeof(*alloc)) > v->minimum_bytes_pinned_memory);
381 
382       v->pinned_memory = static_cast<decltype(v->pinned_memory)>(useit.value());
383       PetscCall(PetscMalloc1(n, &alloc));
384     }
385     if (!VecIMPLCast(v)->array) VecIMPLCast(v)->array = alloc;
386     if (v->offloadmask == PETSC_OFFLOAD_UNALLOCATED) v->offloadmask = PETSC_OFFLOAD_CPU;
387   }
388   PetscFunctionReturn(PETSC_SUCCESS);
389 }
390 
391 template <device::cupm::DeviceType T, typename D>
DeviceAllocateCheck_(PetscDeviceContext dctx,Vec v)392 inline PetscErrorCode Vec_CUPMBase<T, D>::DeviceAllocateCheck_(PetscDeviceContext dctx, Vec v) noexcept
393 {
394   PetscFunctionBegin;
395   PetscCall(VecCUPMAllocateCheck_(v));
396   if (auto &alloc = VecCUPMCast(v)->array_d) PetscFunctionReturn(PETSC_SUCCESS);
397   else {
398     const auto   n                 = v->map->n;
399     auto        &array_allocated_d = VecCUPMCast(v)->array_allocated_d;
400     cupmStream_t stream;
401 
402     PetscCall(GetHandlesFrom_(dctx, &stream));
403     PetscCall(PetscCUPMMallocAsync(&array_allocated_d, n, stream));
404     alloc = array_allocated_d;
405     if (v->offloadmask == PETSC_OFFLOAD_UNALLOCATED) {
406       const auto vimp = VecIMPLCast(v);
407       v->offloadmask  = (vimp && vimp->array) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
408     }
409   }
410   PetscFunctionReturn(PETSC_SUCCESS);
411 }
412 
413 template <device::cupm::DeviceType T, typename D>
CopyToDevice_(PetscDeviceContext dctx,Vec v,bool forceasync)414 inline PetscErrorCode Vec_CUPMBase<T, D>::CopyToDevice_(PetscDeviceContext dctx, Vec v, bool forceasync) noexcept
415 {
416   PetscFunctionBegin;
417   PetscCall(DeviceAllocateCheck_(dctx, v));
418   if (v->offloadmask == PETSC_OFFLOAD_CPU) {
419     cupmStream_t stream;
420 
421     v->offloadmask = PETSC_OFFLOAD_BOTH;
422     PetscCall(GetHandlesFrom_(dctx, &stream));
423     PetscCall(PetscLogEventBegin(VEC_CUPMCopyToGPU(), v, 0, 0, 0));
424     PetscCall(PetscCUPMMemcpyAsync(VecCUPMCast(v)->array_d, VecIMPLCast(v)->array, v->map->n, cupmMemcpyHostToDevice, stream, forceasync));
425     PetscCall(PetscLogEventEnd(VEC_CUPMCopyToGPU(), v, 0, 0, 0));
426   }
427   PetscFunctionReturn(PETSC_SUCCESS);
428 }
429 
430 template <device::cupm::DeviceType T, typename D>
CopyToHost_(PetscDeviceContext dctx,Vec v,bool forceasync)431 inline PetscErrorCode Vec_CUPMBase<T, D>::CopyToHost_(PetscDeviceContext dctx, Vec v, bool forceasync) noexcept
432 {
433   PetscFunctionBegin;
434   PetscCall(HostAllocateCheck_(dctx, v));
435   if (v->offloadmask == PETSC_OFFLOAD_GPU) {
436     cupmStream_t stream;
437 
438     v->offloadmask = PETSC_OFFLOAD_BOTH;
439     PetscCall(GetHandlesFrom_(dctx, &stream));
440     PetscCall(PetscLogEventBegin(VEC_CUPMCopyFromGPU(), v, 0, 0, 0));
441     PetscCall(PetscCUPMMemcpyAsync(VecIMPLCast(v)->array, VecCUPMCast(v)->array_d, v->map->n, cupmMemcpyDeviceToHost, stream, forceasync));
442     PetscCall(PetscLogEventEnd(VEC_CUPMCopyFromGPU(), v, 0, 0, 0));
443   }
444   PetscFunctionReturn(PETSC_SUCCESS);
445 }
446 
447 template <device::cupm::DeviceType T, typename D>
DestroyDevice_(Vec v)448 inline PetscErrorCode Vec_CUPMBase<T, D>::DestroyDevice_(Vec v) noexcept
449 {
450   PetscFunctionBegin;
451   if (const auto vcu = VecCUPMCast(v)) {
452     PetscDeviceContext dctx;
453 
454     PetscCall(GetHandles_(&dctx));
455     PetscCall(ResetAllocatedDevicePtr_(dctx, v));
456     PetscCall(ResetPreallocationCOO_CUPMBase(v, dctx));
457     PetscCall(PetscFree(v->spptr));
458   }
459   PetscFunctionReturn(PETSC_SUCCESS);
460 }
461 
462 template <device::cupm::DeviceType T, typename D>
DestroyHost_(Vec v)463 inline PetscErrorCode Vec_CUPMBase<T, D>::DestroyHost_(Vec v) noexcept
464 {
465   PetscFunctionBegin;
466   PetscCall(PetscObjectSAWsViewOff(PetscObjectCast(v)));
467   if (const auto vimpl = VecIMPLCast(v)) {
468     if (auto &array_allocated = vimpl->array_allocated) {
469       const auto useit = UseCUPMHostAlloc(v->pinned_memory);
470 
471       // do this ourselves since we may want to use the cupm functions
472       PetscCall(PetscFree(array_allocated));
473     }
474   }
475   v->pinned_memory = PETSC_FALSE;
476   PetscCall(VecDestroy_IMPL(v));
477   PetscFunctionReturn(PETSC_SUCCESS);
478 }
479 
480 // ==========================================================================================
481 // Vec_CUPMBase - Public API
482 // ==========================================================================================
483 
484 template <device::cupm::DeviceType T, typename D>
VecCUPMCast(Vec v)485 inline typename Vec_CUPMBase<T, D>::Vec_CUPM *Vec_CUPMBase<T, D>::VecCUPMCast(Vec v) noexcept
486 {
487   return static_cast<Vec_CUPM *>(v->spptr);
488 }
489 
490 // This is a trick to get around the fact that in CRTP the derived class is not yet fully
491 // defined because Base<Derived> must necessarily be instantiated before Derived is
492 // complete. By using a dummy template parameter we make the type "dependent" and so will
493 // only be determined when the derived class is instantiated (and therefore fully defined)
494 template <device::cupm::DeviceType T, typename D>
495 template <typename U>
496 inline constexpr auto Vec_CUPMBase<T, D>::VecIMPLCast(Vec v) noexcept -> decltype(U::VecIMPLCast_(v))
497 {
498   return U::VecIMPLCast_(v);
499 }
500 
501 template <device::cupm::DeviceType T, typename D>
VecDestroy_IMPL(Vec v)502 inline PetscErrorCode Vec_CUPMBase<T, D>::VecDestroy_IMPL(Vec v) noexcept
503 {
504   return D::VecDestroy_IMPL_(v);
505 }
506 
507 template <device::cupm::DeviceType T, typename D>
VecResetArray_IMPL(Vec v)508 inline PetscErrorCode Vec_CUPMBase<T, D>::VecResetArray_IMPL(Vec v) noexcept
509 {
510   return D::VecResetArray_IMPL_(v);
511 }
512 
513 template <device::cupm::DeviceType T, typename D>
VecPlaceArray_IMPL(Vec v,const PetscScalar * a)514 inline PetscErrorCode Vec_CUPMBase<T, D>::VecPlaceArray_IMPL(Vec v, const PetscScalar *a) noexcept
515 {
516   return D::VecPlaceArray_IMPL_(v, a);
517 }
518 
519 template <device::cupm::DeviceType T, typename D>
VecCreate_IMPL_Private(Vec v,PetscBool * alloc_missing,PetscInt nghost,PetscScalar * host_array)520 inline PetscErrorCode Vec_CUPMBase<T, D>::VecCreate_IMPL_Private(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *host_array) noexcept
521 {
522   return D::VecCreate_IMPL_Private_(v, alloc_missing, nghost, host_array);
523 }
524 
525 template <device::cupm::DeviceType T, typename D>
VEC_CUPMCopyToGPU()526 inline constexpr PetscLogEvent Vec_CUPMBase<T, D>::VEC_CUPMCopyToGPU() noexcept
527 {
528   return T == device::cupm::DeviceType::CUDA ? VEC_CUDACopyToGPU : VEC_HIPCopyToGPU;
529 }
530 
531 template <device::cupm::DeviceType T, typename D>
VEC_CUPMCopyFromGPU()532 inline constexpr PetscLogEvent Vec_CUPMBase<T, D>::VEC_CUPMCopyFromGPU() noexcept
533 {
534   return T == device::cupm::DeviceType::CUDA ? VEC_CUDACopyFromGPU : VEC_HIPCopyFromGPU;
535 }
536 
537 template <device::cupm::DeviceType T, typename D>
VECSEQCUPM()538 inline constexpr VecType Vec_CUPMBase<T, D>::VECSEQCUPM() noexcept
539 {
540   return T == device::cupm::DeviceType::CUDA ? VECSEQCUDA : VECSEQHIP;
541 }
542 
543 template <device::cupm::DeviceType T, typename D>
VECMPICUPM()544 inline constexpr VecType Vec_CUPMBase<T, D>::VECMPICUPM() noexcept
545 {
546   return T == device::cupm::DeviceType::CUDA ? VECMPICUDA : VECMPIHIP;
547 }
548 
549 template <device::cupm::DeviceType T, typename D>
VECCUPM()550 inline constexpr VecType Vec_CUPMBase<T, D>::VECCUPM() noexcept
551 {
552   return T == device::cupm::DeviceType::CUDA ? VECCUDA : VECHIP;
553 }
554 
555 template <device::cupm::DeviceType T, typename D>
556 template <typename U>
VECIMPLCUPM()557 inline constexpr VecType Vec_CUPMBase<T, D>::VECIMPLCUPM() noexcept
558 {
559   return U::VECIMPLCUPM_();
560 }
561 
562 template <device::cupm::DeviceType T, typename D>
563 template <typename U>
VECIMPL()564 inline constexpr VecType Vec_CUPMBase<T, D>::VECIMPL() noexcept
565 {
566   return U::VECIMPL_();
567 }
568 
569 // private version that takes a PetscDeviceContext, called by the public variant
570 template <device::cupm::DeviceType T, typename D>
571 template <PetscMemType mtype, PetscMemoryAccessMode access, bool force>
GetArray(Vec v,PetscScalar ** a,PetscDeviceContext dctx)572 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArray(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
573 {
574   constexpr auto hostmem     = PetscMemTypeHost(mtype);
575   const auto     oldmask     = v->offloadmask;
576   auto          &mask        = v->offloadmask;
577   auto           should_sync = false;
578 
579   PetscFunctionBegin;
580   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
581   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
582   if (PetscMemoryAccessRead(access)) {
583     // READ or READ_WRITE
584     if (((oldmask == PETSC_OFFLOAD_GPU) && hostmem) || ((oldmask == PETSC_OFFLOAD_CPU) && !hostmem)) {
585       // if we move the data we should set the flag to synchronize later on
586       should_sync = true;
587     }
588     PetscCall((hostmem ? CopyToHost_ : CopyToDevice_)(dctx, v, force));
589   } else {
590     // WRITE only
591     PetscCall((hostmem ? HostAllocateCheck_ : DeviceAllocateCheck_)(dctx, v));
592   }
593   *a = hostmem ? VecIMPLCast(v)->array : VecCUPMCast(v)->array_d;
594   // if unallocated previously we should zero things out if we intend to read
595   if (PetscMemoryAccessRead(access) && (oldmask == PETSC_OFFLOAD_UNALLOCATED)) {
596     const auto n = v->map->n;
597 
598     if (hostmem) {
599       PetscCall(PetscArrayzero(*a, n));
600     } else {
601       cupmStream_t stream;
602 
603       PetscCall(GetHandlesFrom_(dctx, &stream));
604       PetscCall(PetscCUPMMemsetAsync(*a, 0, n, stream, force));
605       should_sync = true;
606     }
607   }
608   // update the offloadmask if we intend to write, since we assume immediately modified
609   if (PetscMemoryAccessWrite(access)) {
610     PetscCall(VecSetErrorIfLocked(v, 1));
611     // REVIEW ME: this should probably also call PetscObjectStateIncrease() since we assume it
612     // is immediately modified
613     mask = hostmem ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
614   }
615   // if we are a globally blocking stream and we have MOVED data then we should synchronize,
616   // since even doing async calls on the NULL stream is not synchronous
617   if (!force && should_sync) PetscCall(PetscDeviceContextSynchronize(dctx));
618   PetscFunctionReturn(PETSC_SUCCESS);
619 }
620 
621 // v->ops->getarray[read|write] or VecCUPMGetArray[Read|Write]()
622 template <device::cupm::DeviceType T, typename D>
623 template <PetscMemType mtype, PetscMemoryAccessMode access, bool force>
GetArray(Vec v,PetscScalar ** a)624 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArray(Vec v, PetscScalar **a) noexcept
625 {
626   PetscDeviceContext dctx;
627 
628   PetscFunctionBegin;
629   PetscCall(GetHandles_(&dctx));
630   PetscCall(D::template GetArray<mtype, access, force>(v, a, dctx));
631   PetscFunctionReturn(PETSC_SUCCESS);
632 }
633 
634 // private version that takes a PetscDeviceContext, called by the public variant
635 template <device::cupm::DeviceType T, typename D>
636 template <PetscMemType mtype, PetscMemoryAccessMode access>
RestoreArray(Vec v,PetscScalar ** a,PetscDeviceContext)637 inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArray(Vec v, PetscScalar **a, PetscDeviceContext) noexcept
638 {
639   PetscFunctionBegin;
640   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
641   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
642   if (PetscMemoryAccessWrite(access)) {
643     // WRITE or READ_WRITE
644     PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
645     v->offloadmask = PetscMemTypeHost(mtype) ? PETSC_OFFLOAD_CPU : PETSC_OFFLOAD_GPU;
646   }
647   if (a) {
648     PetscCall(CheckPointerMatchesMemType_(*a, mtype));
649     *a = nullptr;
650   }
651   PetscFunctionReturn(PETSC_SUCCESS);
652 }
653 
654 // v->ops->restorearray[read|write] or VecCUPMRestoreArray[Read|Write]()
655 template <device::cupm::DeviceType T, typename D>
656 template <PetscMemType mtype, PetscMemoryAccessMode access>
RestoreArray(Vec v,PetscScalar ** a)657 inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArray(Vec v, PetscScalar **a) noexcept
658 {
659   PetscDeviceContext dctx;
660 
661   PetscFunctionBegin;
662   PetscCall(GetHandles_(&dctx));
663   PetscCall(D::template RestoreArray<mtype, access>(v, a, dctx));
664   PetscFunctionReturn(PETSC_SUCCESS);
665 }
666 
667 template <device::cupm::DeviceType T, typename D>
668 template <PetscMemoryAccessMode access>
GetArrayAndMemtype(Vec v,PetscScalar ** a,PetscMemType * mtype,PetscDeviceContext dctx)669 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrayAndMemtype(Vec v, PetscScalar **a, PetscMemType *mtype, PetscDeviceContext dctx) noexcept
670 {
671   PetscFunctionBegin;
672   if (a) PetscCall(D::template GetArray<PETSC_MEMTYPE_DEVICE, access>(v, a, dctx));
673   if (mtype) *mtype = (PetscDefined(HAVE_NVSHMEM) && VecCUPMCast(v)->nvshmem) ? PETSC_MEMTYPE_NVSHMEM : PETSC_MEMTYPE_CUPM();
674   PetscFunctionReturn(PETSC_SUCCESS);
675 }
676 
677 // v->ops->getarrayandmemtype
678 template <device::cupm::DeviceType T, typename D>
679 template <PetscMemoryAccessMode access>
GetArrayAndMemtype(Vec v,PetscScalar ** a,PetscMemType * mtype)680 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrayAndMemtype(Vec v, PetscScalar **a, PetscMemType *mtype) noexcept
681 {
682   PetscDeviceContext dctx;
683 
684   PetscFunctionBegin;
685   PetscCall(GetHandles_(&dctx));
686   PetscCall(D::template GetArrayAndMemtype<access>(v, a, mtype, dctx));
687   PetscFunctionReturn(PETSC_SUCCESS);
688 }
689 
690 template <device::cupm::DeviceType T, typename D>
691 template <PetscMemoryAccessMode access>
RestoreArrayAndMemtype(Vec v,PetscScalar ** a,PetscDeviceContext dctx)692 inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArrayAndMemtype(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
693 {
694   PetscFunctionBegin;
695   PetscCall(D::template RestoreArray<PETSC_MEMTYPE_DEVICE, access>(v, a, dctx));
696   PetscFunctionReturn(PETSC_SUCCESS);
697 }
698 
699 // v->ops->restorearrayandmemtype
700 template <device::cupm::DeviceType T, typename D>
701 template <PetscMemoryAccessMode access>
RestoreArrayAndMemtype(Vec v,PetscScalar ** a)702 inline PetscErrorCode Vec_CUPMBase<T, D>::RestoreArrayAndMemtype(Vec v, PetscScalar **a) noexcept
703 {
704   PetscDeviceContext dctx;
705 
706   PetscFunctionBegin;
707   PetscCall(GetHandles_(&dctx));
708   PetscCall(D::template RestoreArrayAndMemtype<access>(v, a, dctx));
709   PetscFunctionReturn(PETSC_SUCCESS);
710 }
711 
712 // v->ops->placearray or VecCUPMPlaceArray()
713 template <device::cupm::DeviceType T, typename D>
714 template <PetscMemType mtype>
PlaceArray(Vec v,const PetscScalar * a)715 inline PetscErrorCode Vec_CUPMBase<T, D>::PlaceArray(Vec v, const PetscScalar *a) noexcept
716 {
717   PetscDeviceContext dctx;
718 
719   PetscFunctionBegin;
720   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
721   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
722   PetscCall(CheckPointerMatchesMemType_(a, mtype));
723   PetscCall(GetHandles_(&dctx));
724   if (PetscMemTypeHost(mtype)) {
725     PetscCall(CopyToHost_(dctx, v));
726     PetscCall(VecPlaceArray_IMPL(v, a));
727     v->offloadmask = PETSC_OFFLOAD_CPU;
728   } else {
729     PetscCall(VecIMPLAllocateCheck_(v));
730     {
731       auto &backup_array = VecIMPLCast(v)->unplacedarray;
732 
733       PetscCheck(!backup_array, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "VecPlaceArray() was already called on this vector, without a call to VecResetArray()");
734       PetscCall(CopyToDevice_(dctx, v));
735       PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
736       backup_array = util::exchange(VecCUPMCast(v)->array_d, const_cast<PetscScalar *>(a));
737       // only update the offload mask if we actually assign a pointer
738       if (a) v->offloadmask = PETSC_OFFLOAD_GPU;
739     }
740   }
741   PetscFunctionReturn(PETSC_SUCCESS);
742 }
743 
744 // v->ops->replacearray or VecCUPMReplaceArray()
745 template <device::cupm::DeviceType T, typename D>
746 template <PetscMemType mtype>
ReplaceArray(Vec v,const PetscScalar * a)747 inline PetscErrorCode Vec_CUPMBase<T, D>::ReplaceArray(Vec v, const PetscScalar *a) noexcept
748 {
749   const auto         aptr = const_cast<PetscScalar *>(a);
750   PetscDeviceContext dctx;
751 
752   PetscFunctionBegin;
753   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
754   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
755   PetscCall(CheckPointerMatchesMemType_(a, mtype));
756   PetscCall(GetHandles_(&dctx));
757   if (PetscMemTypeHost(mtype)) {
758     PetscCall(VecIMPLAllocateCheck_(v));
759     {
760       const auto vimpl      = VecIMPLCast(v);
761       auto      &host_array = vimpl->array_allocated;
762 
763       // make sure the users array has the latest values.
764       // REVIEW ME: why? we're about to free it
765       if (host_array != vimpl->array) PetscCall(CopyToHost_(dctx, v));
766       if (host_array) {
767         const auto useit = UseCUPMHostAlloc(v->pinned_memory);
768 
769         PetscCall(PetscFree(host_array));
770       }
771       host_array       = aptr;
772       vimpl->array     = host_array;
773       v->pinned_memory = PETSC_FALSE; // REVIEW ME: we can determine this
774       v->offloadmask   = PETSC_OFFLOAD_CPU;
775     }
776   } else {
777     PetscCall(VecCUPMAllocateCheck_(v));
778     {
779       const auto vcu = VecCUPMCast(v);
780 
781       PetscCall(ResetAllocatedDevicePtr_(dctx, v, aptr));
782       // don't update the offloadmask if placed pointer is NULL
783       vcu->array_d = vcu->array_allocated_d /* = aptr */;
784       if (aptr) v->offloadmask = PETSC_OFFLOAD_GPU;
785     }
786   }
787   PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
788   PetscFunctionReturn(PETSC_SUCCESS);
789 }
790 
791 // v->ops->resetarray or VecCUPMResetArray()
792 template <device::cupm::DeviceType T, typename D>
793 template <PetscMemType mtype>
ResetArray(Vec v)794 inline PetscErrorCode Vec_CUPMBase<T, D>::ResetArray(Vec v) noexcept
795 {
796   PetscDeviceContext dctx;
797 
798   PetscFunctionBegin;
799   static_assert((mtype == PETSC_MEMTYPE_HOST) || (mtype == PETSC_MEMTYPE_DEVICE), "");
800   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
801   PetscCall(GetHandles_(&dctx));
802   // REVIEW ME:
803   // this is wildly inefficient but must be done if we assume that the placed array must have
804   // correct values
805   if (PetscMemTypeHost(mtype)) {
806     PetscCall(CopyToHost_(dctx, v));
807     PetscCall(VecResetArray_IMPL(v));
808     v->offloadmask = PETSC_OFFLOAD_CPU;
809   } else {
810     PetscCall(VecIMPLAllocateCheck_(v));
811     PetscCall(VecCUPMAllocateCheck_(v));
812     {
813       const auto vcu        = VecCUPMCast(v);
814       const auto vimpl      = VecIMPLCast(v);
815       auto      &host_array = vimpl->unplacedarray;
816 
817       PetscCall(CheckPointerMatchesMemType_(host_array, PETSC_MEMTYPE_DEVICE));
818       if (v->offloadmask == PETSC_OFFLOAD_CPU) {
819         PetscCall(CopyToDevice_(dctx, v));
820         PetscCall(PetscDeviceContextSynchronize(dctx)); // Above H2D might be async, so we must sync dctx, otherwise if later user writes v's host array, it could ruin the H2D
821       }
822       PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
823       // Need to reset the offloadmask. If we had a stashed pointer we are on the GPU,
824       // otherwise check if the host has a valid pointer. If neither, then we are not
825       // allocated.
826       vcu->array_d = host_array;
827       if (host_array) {
828         host_array     = nullptr;
829         v->offloadmask = PETSC_OFFLOAD_GPU;
830       } else if (vimpl->array) {
831         v->offloadmask = PETSC_OFFLOAD_CPU;
832       } else {
833         v->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
834       }
835     }
836   }
837   PetscFunctionReturn(PETSC_SUCCESS);
838 }
839 
840 // v->ops->create
841 template <device::cupm::DeviceType T, typename D>
Create(Vec v)842 inline PetscErrorCode Vec_CUPMBase<T, D>::Create(Vec v) noexcept
843 {
844   PetscBool          alloc_missing;
845   PetscDeviceContext dctx;
846 
847   PetscFunctionBegin;
848   PetscCall(VecCreate_IMPL_Private(v, &alloc_missing));
849   PetscCall(GetHandles_(&dctx));
850   PetscCall(Initialize_CUPMBase(v, alloc_missing, nullptr, nullptr, dctx));
851   PetscFunctionReturn(PETSC_SUCCESS);
852 }
853 
854 // v->ops->destroy
855 template <device::cupm::DeviceType T, typename D>
Destroy(Vec v)856 inline PetscErrorCode Vec_CUPMBase<T, D>::Destroy(Vec v) noexcept
857 {
858   PetscFunctionBegin;
859   PetscCall(DestroyDevice_(v));
860   PetscCall(DestroyHost_(v));
861   PetscFunctionReturn(PETSC_SUCCESS);
862 }
863 
864 // ================================================================================== //
865 //                      Common core between Seq and MPI                               //
866 
867 // VecCreate_CUPM()
868 template <device::cupm::DeviceType T, typename D>
Create_CUPM(Vec v)869 inline PetscErrorCode Vec_CUPMBase<T, D>::Create_CUPM(Vec v) noexcept
870 {
871   PetscMPIInt size;
872 
873   PetscFunctionBegin;
874   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
875   PetscCall(VecSetType(v, size > 1 ? VECMPICUPM() : VECSEQCUPM()));
876   PetscFunctionReturn(PETSC_SUCCESS);
877 }
878 
879 // VecCreateCUPM()
880 template <device::cupm::DeviceType T, typename D>
Create_CUPMBase(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,Vec * v,PetscBool call_set_type,PetscLayout reference)881 inline PetscErrorCode Vec_CUPMBase<T, D>::Create_CUPMBase(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type, PetscLayout reference) noexcept
882 {
883   PetscFunctionBegin;
884   PetscCall(VecCreate(comm, v));
885   if (reference) PetscCall(PetscLayoutReference(reference, &(*v)->map));
886   PetscCall(VecSetSizes(*v, n, N));
887   if (bs) PetscCall(VecSetBlockSize(*v, bs));
888   if (call_set_type) PetscCall(VecSetType(*v, VECIMPLCUPM()));
889   PetscFunctionReturn(PETSC_SUCCESS);
890 }
891 
892 // VecCreateIMPL_CUPM(), called through v->ops->create
893 template <device::cupm::DeviceType T, typename D>
Initialize_CUPMBase(Vec v,PetscBool allocate_missing,PetscScalar * host_array,PetscScalar * device_array,PetscDeviceContext dctx)894 inline PetscErrorCode Vec_CUPMBase<T, D>::Initialize_CUPMBase(Vec v, PetscBool allocate_missing, PetscScalar *host_array, PetscScalar *device_array, PetscDeviceContext dctx) noexcept
895 {
896   PetscFunctionBegin;
897   // REVIEW ME: perhaps not needed
898   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
899   PetscCall(PetscObjectChangeTypeName(PetscObjectCast(v), VECIMPLCUPM()));
900   PetscCall(D::BindToCPU(v, PETSC_FALSE));
901   if (device_array) {
902     PetscCall(CheckPointerMatchesMemType_(device_array, PETSC_MEMTYPE_CUPM()));
903     PetscCall(VecCUPMAllocateCheck_(v));
904     VecCUPMCast(v)->array_d = device_array;
905   }
906   if (host_array) {
907     PetscCall(CheckPointerMatchesMemType_(host_array, PETSC_MEMTYPE_HOST));
908     VecIMPLCast(v)->array = host_array;
909   }
910   if (allocate_missing) {
911     PetscCall(DeviceAllocateCheck_(dctx, v));
912     PetscCall(HostAllocateCheck_(dctx, v));
913     // REVIEW ME: junchao, is this needed with new calloc() branch? VecSet() will call
914     // set() for reference
915     // calls device-version
916     PetscCall(VecSet(v, 0));
917     // zero the host while device is underway
918     PetscCall(PetscArrayzero(VecIMPLCast(v)->array, v->map->n));
919     v->offloadmask = PETSC_OFFLOAD_BOTH;
920   } else {
921     if (host_array) {
922       v->offloadmask = device_array ? PETSC_OFFLOAD_BOTH : PETSC_OFFLOAD_CPU;
923     } else {
924       v->offloadmask = device_array ? PETSC_OFFLOAD_GPU : PETSC_OFFLOAD_UNALLOCATED;
925     }
926   }
927   PetscFunctionReturn(PETSC_SUCCESS);
928 }
929 
930 // v->ops->duplicate
931 template <device::cupm::DeviceType T, typename D>
932 template <typename SetupFunctionT>
Duplicate_CUPMBase(Vec v,Vec * y,PetscDeviceContext dctx,SetupFunctionT && DerivedCreateIMPLCUPM_Async)933 inline PetscErrorCode Vec_CUPMBase<T, D>::Duplicate_CUPMBase(Vec v, Vec *y, PetscDeviceContext dctx, SetupFunctionT &&DerivedCreateIMPLCUPM_Async) noexcept
934 {
935   // if the derived setup is the default no_op then we should call VecSetType()
936   constexpr auto call_set_type = static_cast<PetscBool>(std::is_same<SetupFunctionT, no_op>::value);
937   const auto     vobj          = PetscObjectCast(v);
938   const auto     map           = v->map;
939   PetscInt       bs;
940 
941   PetscFunctionBegin;
942   PetscCall(VecGetBlockSize(v, &bs));
943   PetscCall(Create_CUPMBase(PetscObjectComm(vobj), bs, map->n, map->N, y, call_set_type, map));
944   // Derived class can set up the remainder of the data structures here
945   PetscCall(DerivedCreateIMPLCUPM_Async(*y));
946   // If the other vector is bound to CPU then the memcpy of the ops struct will give the
947   // duplicated vector the host "getarray" function which does not lazily allocate the array
948   // (as it is assumed to always exist). So we force allocation here, before we overwrite the
949   // ops
950   if (v->boundtocpu) PetscCall(HostAllocateCheck_(dctx, *y));
951   // in case the user has done some VecSetOps() tomfoolery
952   (*y)->ops[0] = v->ops[0];
953   {
954     const auto yobj = PetscObjectCast(*y);
955 
956     PetscCall(PetscObjectListDuplicate(vobj->olist, &yobj->olist));
957     PetscCall(PetscFunctionListDuplicate(vobj->qlist, &yobj->qlist));
958   }
959   (*y)->stash.donotstash   = v->stash.donotstash;
960   (*y)->stash.ignorenegidx = v->stash.ignorenegidx;
961   (*y)->map->bs            = std::abs(v->map->bs);
962   (*y)->bstash.bs          = v->bstash.bs;
963   PetscFunctionReturn(PETSC_SUCCESS);
964 }
965 
// Assigns v->ops->op_name from `op_host` when `usehost` is true, and from the CUPM
// implementation otherwise. The CUPM implementation is the variadic tail so that it may contain
// commas (e.g. template argument lists). Expects `v` and `usehost` to be in scope at the
// expansion site. An if/else is used instead of ?: because the two operands can have different
// types (e.g. nullptr vs. a lambda).
  #define VecSetOp_CUPM(op_name, op_host, ...) \
    do { \
      if (!usehost) { \
        v->ops->op_name = __VA_ARGS__; \
      } else { \
        v->ops->op_name = op_host; \
      } \
    } while (0)
974 
975 // v->ops->bindtocpu
976 template <device::cupm::DeviceType T, typename D>
BindToCPU_CUPMBase(Vec v,PetscBool usehost,PetscDeviceContext dctx)977 inline PetscErrorCode Vec_CUPMBase<T, D>::BindToCPU_CUPMBase(Vec v, PetscBool usehost, PetscDeviceContext dctx) noexcept
978 {
979   PetscFunctionBegin;
980   v->boundtocpu = usehost;
981   if (usehost) PetscCall(CopyToHost_(dctx, v));
982   PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &v->defaultrandtype));
983 
984   // set the base functions that are guaranteed to be the same for both
985   v->ops->duplicate = D::Duplicate;
986   v->ops->create    = D::Create;
987   v->ops->destroy   = D::Destroy;
988   v->ops->bindtocpu = D::BindToCPU;
989   // Note that setting these to NULL on host breaks convergence in certain areas. I don't know
990   // why, and I don't know how, but it is IMPERATIVE these are set as such!
991   v->ops->replacearray = D::template ReplaceArray<PETSC_MEMTYPE_HOST>;
992   v->ops->restorearray = D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>;
993 
994   // set device-only common functions
995   VecSetOp_CUPM(getarray, nullptr, D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ_WRITE>);
996   VecSetOp_CUPM(getarraywrite, nullptr, D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>);
997   VecSetOp_CUPM(restorearraywrite, nullptr, D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_WRITE>);
998 
999   VecSetOp_CUPM(getarrayread, nullptr, [](Vec v, const PetscScalar **a) { return D::template GetArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });
1000   VecSetOp_CUPM(restorearrayread, nullptr, [](Vec v, const PetscScalar **a) { return D::template RestoreArray<PETSC_MEMTYPE_HOST, PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });
1001 
1002   VecSetOp_CUPM(getarrayandmemtype, nullptr, D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_READ_WRITE>);
1003   VecSetOp_CUPM(restorearrayandmemtype, nullptr, D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_READ_WRITE>);
1004 
1005   VecSetOp_CUPM(getarraywriteandmemtype, nullptr, D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_WRITE>);
1006   VecSetOp_CUPM(restorearraywriteandmemtype, nullptr, [](Vec v, PetscScalar **a, PetscMemType *) { return D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_WRITE>(v, a); });
1007 
1008   VecSetOp_CUPM(getarrayreadandmemtype, nullptr, [](Vec v, const PetscScalar **a, PetscMemType *m) { return D::template GetArrayAndMemtype<PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a), m); });
1009   VecSetOp_CUPM(restorearrayreadandmemtype, nullptr, [](Vec v, const PetscScalar **a) { return D::template RestoreArrayAndMemtype<PETSC_MEMORY_ACCESS_READ>(v, const_cast<PetscScalar **>(a)); });
1010 
1011   // set the functions that are always sequential
1012   using VecSeq_T = VecSeq_CUPM<T>;
1013   VecSetOp_CUPM(scale, VecScale_Seq, VecSeq_T::Scale);
1014   VecSetOp_CUPM(copy, VecCopy_Seq, VecSeq_T::Copy);
1015   VecSetOp_CUPM(set, VecSet_Seq, VecSeq_T::Set);
1016   VecSetOp_CUPM(swap, VecSwap_Seq, VecSeq_T::Swap);
1017   VecSetOp_CUPM(axpy, VecAXPY_Seq, VecSeq_T::AXPY);
1018   VecSetOp_CUPM(axpby, VecAXPBY_Seq, VecSeq_T::AXPBY);
1019   VecSetOp_CUPM(maxpy, VecMAXPY_Seq, VecSeq_T::MAXPY);
1020   VecSetOp_CUPM(aypx, VecAYPX_Seq, VecSeq_T::AYPX);
1021   VecSetOp_CUPM(waxpy, VecWAXPY_Seq, VecSeq_T::WAXPY);
1022   VecSetOp_CUPM(axpbypcz, VecAXPBYPCZ_Seq, VecSeq_T::AXPBYPCZ);
1023   VecSetOp_CUPM(pointwisemult, VecPointwiseMult_Seq, VecSeq_T::PointwiseMult);
1024   VecSetOp_CUPM(pointwisedivide, VecPointwiseDivide_Seq, VecSeq_T::PointwiseDivide);
1025   VecSetOp_CUPM(pointwisemax, VecPointwiseMax_Seq, VecSeq_T::PointwiseMax);
1026   VecSetOp_CUPM(pointwisemaxabs, VecPointwiseMaxAbs_Seq, VecSeq_T::PointwiseMaxAbs);
1027   VecSetOp_CUPM(pointwisemin, VecPointwiseMin_Seq, VecSeq_T::PointwiseMin);
1028   VecSetOp_CUPM(setrandom, VecSetRandom_Seq, VecSeq_T::SetRandom);
1029   VecSetOp_CUPM(dot_local, VecDot_Seq, VecSeq_T::Dot);
1030   VecSetOp_CUPM(tdot_local, VecTDot_Seq, VecSeq_T::TDot);
1031   VecSetOp_CUPM(norm_local, VecNorm_Seq, VecSeq_T::Norm);
1032   VecSetOp_CUPM(mdot_local, VecMDot_Seq, VecSeq_T::MDot);
1033   VecSetOp_CUPM(reciprocal, VecReciprocal_Default, VecSeq_T::Reciprocal);
1034   VecSetOp_CUPM(conjugate, VecConjugate_Seq, VecSeq_T::Conjugate);
1035   VecSetOp_CUPM(abs, nullptr, VecSeq_T::Abs);
1036   VecSetOp_CUPM(sqrt, nullptr, VecSeq_T::SqrtAbs);
1037   VecSetOp_CUPM(exp, nullptr, VecSeq_T::Exp);
1038   VecSetOp_CUPM(log, nullptr, VecSeq_T::Log);
1039   VecSetOp_CUPM(shift, nullptr, VecSeq_T::Shift);
1040   VecSetOp_CUPM(dotnorm2, nullptr, D::DotNorm2);
1041   VecSetOp_CUPM(getlocalvector, nullptr, VecSeq_T::template GetLocalVector<PETSC_MEMORY_ACCESS_READ_WRITE>);
1042   VecSetOp_CUPM(restorelocalvector, nullptr, VecSeq_T::template RestoreLocalVector<PETSC_MEMORY_ACCESS_READ_WRITE>);
1043   VecSetOp_CUPM(getlocalvectorread, nullptr, VecSeq_T::template GetLocalVector<PETSC_MEMORY_ACCESS_READ>);
1044   VecSetOp_CUPM(restorelocalvectorread, nullptr, VecSeq_T::template RestoreLocalVector<PETSC_MEMORY_ACCESS_READ>);
1045   VecSetOp_CUPM(sum, nullptr, VecSeq_T::Sum);
1046   VecSetOp_CUPM(errorwnorm, nullptr, D::ErrorWnorm);
1047   VecSetOp_CUPM(duplicatevecs, VecDuplicateVecs_Default, VecDuplicateVecs_Default);
1048   PetscFunctionReturn(PETSC_SUCCESS);
1049 }
1050 
1051 // Called from VecGetSubVector()
1052 template <device::cupm::DeviceType T, typename D>
GetArrays_CUPMBase(Vec v,const PetscScalar ** host_array,const PetscScalar ** device_array,PetscOffloadMask * mask,PetscDeviceContext dctx)1053 inline PetscErrorCode Vec_CUPMBase<T, D>::GetArrays_CUPMBase(Vec v, const PetscScalar **host_array, const PetscScalar **device_array, PetscOffloadMask *mask, PetscDeviceContext dctx) noexcept
1054 {
1055   PetscFunctionBegin;
1056   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
1057   if (host_array) {
1058     PetscCall(HostAllocateCheck_(dctx, v));
1059     *host_array = VecIMPLCast(v)->array;
1060   }
1061   if (device_array) {
1062     PetscCall(DeviceAllocateCheck_(dctx, v));
1063     *device_array = VecCUPMCast(v)->array_d;
1064   }
1065   if (mask) *mask = v->offloadmask;
1066   PetscFunctionReturn(PETSC_SUCCESS);
1067 }
1068 
1069 template <device::cupm::DeviceType T, typename D>
ResetPreallocationCOO_CUPMBase(Vec v,PetscDeviceContext dctx)1070 inline PetscErrorCode Vec_CUPMBase<T, D>::ResetPreallocationCOO_CUPMBase(Vec v, PetscDeviceContext dctx) noexcept
1071 {
1072   PetscFunctionBegin;
1073   if (const auto vcu = VecCUPMCast(v)) {
1074     cupmStream_t stream;
1075     // clang-format off
1076     const auto   cntptrs = util::make_array(
1077       std::ref(vcu->jmap1_d),
1078       std::ref(vcu->perm1_d),
1079       std::ref(vcu->imap2_d),
1080       std::ref(vcu->jmap2_d),
1081       std::ref(vcu->perm2_d),
1082       std::ref(vcu->Cperm_d)
1083     );
1084     // clang-format on
1085 
1086     PetscCall(GetHandlesFrom_(dctx, &stream));
1087     for (auto &&ptr : cntptrs) PetscCallCUPM(cupmFreeAsync(ptr.get(), stream));
1088     for (auto &&ptr : util::make_array(std::ref(vcu->sendbuf_d), std::ref(vcu->recvbuf_d))) PetscCallCUPM(cupmFreeAsync(ptr.get(), stream));
1089   }
1090   PetscFunctionReturn(PETSC_SUCCESS);
1091 }
1092 
1093 template <device::cupm::DeviceType T, typename D>
1094 template <std::size_t NCount, std::size_t NScal>
SetPreallocationCOO_CUPMBase(Vec v,PetscCount,const PetscInt[],PetscDeviceContext dctx,const std::array<CooPair<PetscCount>,NCount> & extra_cntptrs,const std::array<CooPair<PetscScalar>,NScal> & bufptrs)1095 inline PetscErrorCode Vec_CUPMBase<T, D>::SetPreallocationCOO_CUPMBase(Vec v, PetscCount, const PetscInt[], PetscDeviceContext dctx, const std::array<CooPair<PetscCount>, NCount> &extra_cntptrs, const std::array<CooPair<PetscScalar>, NScal> &bufptrs) noexcept
1096 {
1097   PetscFunctionBegin;
1098   PetscCall(ResetPreallocationCOO_CUPMBase(v, dctx));
1099   // need to instantiate the private pointer if not already
1100   PetscCall(VecCUPMAllocateCheck_(v));
1101   {
1102     const auto vimpl = VecIMPLCast(v);
1103     const auto vcu   = VecCUPMCast(v);
1104     // clang-format off
1105     const auto cntptrs = util::concat_array(
1106       util::make_array(
1107         make_coo_pair(vcu->jmap1_d, vimpl->jmap1, v->map->n + 1),
1108         make_coo_pair(vcu->perm1_d, vimpl->perm1, vimpl->tot1)
1109       ),
1110       extra_cntptrs
1111     );
1112     // clang-format on
1113     cupmStream_t stream;
1114 
1115     PetscCall(GetHandlesFrom_(dctx, &stream));
1116     // allocate
1117     for (auto &elem : cntptrs) PetscCall(PetscCUPMMallocAsync(&elem.device, elem.size, stream));
1118     for (auto &elem : bufptrs) PetscCall(PetscCUPMMallocAsync(&elem.device, elem.size, stream));
1119     // copy
1120     for (const auto &elem : cntptrs) PetscCall(PetscCUPMMemcpyAsync(elem.device, elem.host, elem.size, cupmMemcpyHostToDevice, stream, true));
1121     for (const auto &elem : bufptrs) PetscCall(PetscCUPMMemcpyAsync(elem.device, elem.host, elem.size, cupmMemcpyHostToDevice, stream, true));
1122   }
1123   PetscFunctionReturn(PETSC_SUCCESS);
1124 }
1125 
1126 template <device::cupm::DeviceType T, typename D>
Convert_IMPL_IMPLCUPM(Vec v)1127 inline PetscErrorCode Vec_CUPMBase<T, D>::Convert_IMPL_IMPLCUPM(Vec v) noexcept
1128 {
1129   const auto         n        = v->map->n;
1130   const auto         vimpl    = VecIMPLCast(v);
1131   auto              &impl_arr = vimpl->array;
1132   PetscBool          set      = PETSC_FALSE;
1133   PetscDeviceContext dctx;
1134 
1135   PetscFunctionBegin;
1136   // If users do not explicitly require pinned memory, we prefer keeping the vector's regular
1137   // host array
1138   PetscCall(VecCUPMCheckMinimumPinnedMemory_Internal(v, &set));
1139   if (set && impl_arr && ((n * sizeof(*impl_arr)) > v->minimum_bytes_pinned_memory)) {
1140     auto        &impl_alloc = vimpl->array_allocated;
1141     PetscScalar *new_arr;
1142 
1143     // users require pinned memory
1144     {
1145       // Allocate pinned memory and copy over the old array
1146       const auto useit = UseCUPMHostAlloc(PETSC_TRUE);
1147 
1148       PetscCall(PetscMalloc1(n, &new_arr));
1149       PetscCall(PetscArraycpy(new_arr, impl_arr, n));
1150     }
1151     PetscCall(PetscFree(impl_alloc));
1152     impl_arr         = new_arr;
1153     impl_alloc       = new_arr;
1154     v->offloadmask   = PETSC_OFFLOAD_CPU;
1155     v->pinned_memory = PETSC_TRUE;
1156   }
1157   PetscCall(GetHandles_(&dctx));
1158   PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, impl_arr, nullptr, dctx));
1159   PetscFunctionReturn(PETSC_SUCCESS);
1160 }
1161 
// Boilerplate for classes deriving from Vec_CUPMBase. The using-declarations only make sense in
// class scope, so this is meant to be expanded inside the derived class body. It produces:
//  - the CUPM object header for Tp;
//  - an alias `name` for the base class Vec_CUPMBase<Tp, __VA_ARGS__>;
//  - a friend declaration letting the base access the derived class's private members;
//  - using-declarations re-exporting the base's helpers into the derived class.
  #define PETSC_VEC_CUPM_BASE_CLASS_HEADER(name, Tp, ...) \
    PETSC_CUPMOBJECT_HEADER(Tp); \
    using name = ::Petsc::vec::cupm::impl::Vec_CUPMBase<Tp, __VA_ARGS__>; \
    friend name; \
    /* introspection: type names, casts, debugging */ \
    using name::VECCUPM; \
    using name::VECIMPL; \
    using name::VECIMPLCUPM; \
    using name::VECMPICUPM; \
    using name::VECSEQCUPM; \
    using name::VecCUPMCast; \
    using name::VecIMPLCast; \
    using name::VecView_Debug; \
    /* allocation and host <-> device transfer */ \
    using typename name::Vec_CUPM; \
    using name::VecCUPMAllocateCheck_; \
    using name::VecIMPLAllocateCheck_; \
    using name::DeviceAllocateCheck_; \
    using name::HostAllocateCheck_; \
    using name::CopyToDevice_; \
    using name::CopyToHost_; \
    /* operation-table entries */ \
    using name::Create; \
    using name::Destroy; \
    using name::GetArray; \
    using name::RestoreArray; \
    using name::GetArrayAndMemtype; \
    using name::RestoreArrayAndMemtype; \
    using name::PlaceArray; \
    using name::ReplaceArray; \
    using name::ResetArray; \
    /* building blocks shared by Seq and MPI */ \
    using name::Create_CUPM; \
    using name::Create_CUPMBase; \
    using name::Initialize_CUPMBase; \
    using name::Duplicate_CUPMBase; \
    using name::BindToCPU_CUPMBase; \
    using name::Convert_IMPL_IMPLCUPM; \
    using name::ResetPreallocationCOO_CUPMBase; \
    using name::SetPreallocationCOO_CUPMBase; \
    /* typed array accessors */ \
    using name::DeviceArrayRead; \
    using name::DeviceArrayWrite; \
    using name::DeviceArrayReadWrite; \
    using name::HostArrayRead; \
    using name::HostArrayWrite; \
    using name::HostArrayReadWrite;
1207 
1208 } // namespace impl
1209 
1210 } // namespace cupm
1211 
1212 } // namespace vec
1213 
1214 } // namespace Petsc
1215 
1216 #endif // __cplusplus && PetscDefined(HAVE_DEVICE)
1217