xref: /petsc/src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp (revision 9dd11ecf0918283bb567d8b33a92f53ac4ea7840)
1 #pragma once
2 
3 #include <petsc/private/veccupmimpl.h> /*I <petscvec.h> I*/
4 #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp>
5 #include <../src/vec/vec/impls/mpi/pvecimpl.h>
6 
7 namespace Petsc
8 {
9 
10 namespace vec
11 {
12 
13 namespace cupm
14 {
15 
16 namespace impl
17 {
18 
19 template <device::cupm::DeviceType T>
20 class VecMPI_CUPM : public Vec_CUPMBase<T, VecMPI_CUPM<T>> {
21 public:
22   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecMPI_CUPM<T>);
23   using VecSeq_T = VecSeq_CUPM<T>;
24 
25 private:
26   PETSC_NODISCARD static Vec_MPI          *VecIMPLCast_(Vec) noexcept;
27   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
28   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
29 
30   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
31   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
32   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
33   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
34 
35   static PetscErrorCode CreateMPICUPM_(Vec, PetscDeviceContext, PetscBool /*allocate_missing*/ = PETSC_TRUE, PetscInt /*nghost*/ = 0, PetscScalar * /*host_array*/ = nullptr, PetscScalar * /*device_array*/ = nullptr) noexcept;
36 
37 public:
38   // callable directly via a bespoke function
39   static PetscErrorCode CreateMPICUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
40   static PetscErrorCode CreateMPICUPMWithArrays(MPI_Comm, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
41 
42   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
43   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
44   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
45   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
46   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
47   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
48   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
49   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
50   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
51   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
52   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
53   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
54 };
55 
56 } // namespace impl
57 
58 template <device::cupm::DeviceType T>
VecCreateMPICUPMAsync(MPI_Comm comm,PetscInt n,PetscInt N,Vec * v)59 inline PetscErrorCode VecCreateMPICUPMAsync(MPI_Comm comm, PetscInt n, PetscInt N, Vec *v) noexcept
60 {
61   PetscFunctionBegin;
62   PetscAssertPointer(v, 4);
63   PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPM(comm, 0, n, N, v, PETSC_TRUE));
64   PetscFunctionReturn(PETSC_SUCCESS);
65 }
66 
67 template <device::cupm::DeviceType T>
VecCreateMPICUPMWithArrays(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,const PetscScalar cpuarray[],const PetscScalar gpuarray[],Vec * v)68 inline PetscErrorCode VecCreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v)
69 {
70   PetscFunctionBegin;
71   if (n && cpuarray) PetscAssertPointer(cpuarray, 5);
72   PetscAssertPointer(v, 7);
73   PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPMWithArrays(comm, bs, n, N, cpuarray, gpuarray, v));
74   PetscFunctionReturn(PETSC_SUCCESS);
75 }
76 
77 template <device::cupm::DeviceType T>
VecCreateMPICUPMWithArray(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,const PetscScalar gpuarray[],Vec * v)78 inline PetscErrorCode VecCreateMPICUPMWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar gpuarray[], Vec *v)
79 {
80   PetscFunctionBegin;
81   PetscCall(VecCreateMPICUPMWithArrays<T>(comm, bs, n, N, nullptr, gpuarray, v));
82   PetscFunctionReturn(PETSC_SUCCESS);
83 }
84 
85 } // namespace cupm
86 
87 } // namespace vec
88 
89 } // namespace Petsc
90 
// Explicit instantiation declarations for each enabled device backend. They suppress implicit
// instantiation of VecMPI_CUPM<CUDA>/<HIP> in translation units that include this header. The
// matching explicit instantiation definitions must exist in another translation unit (not
// visible here). NOTE(review): PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL is presumably a
// symbol-visibility attribute macro; confirm its definition.
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif
98