1 #pragma once
2
3 #include <petsc/private/veccupmimpl.h> /*I <petscvec.h> I*/
4 #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp>
5 #include <../src/vec/vec/impls/mpi/pvecimpl.h>
6
7 namespace Petsc
8 {
9
10 namespace vec
11 {
12
13 namespace cupm
14 {
15
16 namespace impl
17 {
18
19 template <device::cupm::DeviceType T>
20 class VecMPI_CUPM : public Vec_CUPMBase<T, VecMPI_CUPM<T>> {
21 public:
22 PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecMPI_CUPM<T>);
23 using VecSeq_T = VecSeq_CUPM<T>;
24
25 private:
26 PETSC_NODISCARD static Vec_MPI *VecIMPLCast_(Vec) noexcept;
27 PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
28 PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
29
30 static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
31 static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
32 static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
33 static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
34
35 static PetscErrorCode CreateMPICUPM_(Vec, PetscDeviceContext, PetscBool /*allocate_missing*/ = PETSC_TRUE, PetscInt /*nghost*/ = 0, PetscScalar * /*host_array*/ = nullptr, PetscScalar * /*device_array*/ = nullptr) noexcept;
36
37 public:
38 // callable directly via a bespoke function
39 static PetscErrorCode CreateMPICUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
40 static PetscErrorCode CreateMPICUPMWithArrays(MPI_Comm, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
41
42 static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
43 static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
44 static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
45 static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
46 static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
47 static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
48 static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
49 static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
50 static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
51 static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
52 static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
53 static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
54 };
55
56 } // namespace impl
57
58 template <device::cupm::DeviceType T>
VecCreateMPICUPMAsync(MPI_Comm comm,PetscInt n,PetscInt N,Vec * v)59 inline PetscErrorCode VecCreateMPICUPMAsync(MPI_Comm comm, PetscInt n, PetscInt N, Vec *v) noexcept
60 {
61 PetscFunctionBegin;
62 PetscAssertPointer(v, 4);
63 PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPM(comm, 0, n, N, v, PETSC_TRUE));
64 PetscFunctionReturn(PETSC_SUCCESS);
65 }
66
67 template <device::cupm::DeviceType T>
VecCreateMPICUPMWithArrays(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,const PetscScalar cpuarray[],const PetscScalar gpuarray[],Vec * v)68 inline PetscErrorCode VecCreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v)
69 {
70 PetscFunctionBegin;
71 if (n && cpuarray) PetscAssertPointer(cpuarray, 5);
72 PetscAssertPointer(v, 7);
73 PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPMWithArrays(comm, bs, n, N, cpuarray, gpuarray, v));
74 PetscFunctionReturn(PETSC_SUCCESS);
75 }
76
77 template <device::cupm::DeviceType T>
VecCreateMPICUPMWithArray(MPI_Comm comm,PetscInt bs,PetscInt n,PetscInt N,const PetscScalar gpuarray[],Vec * v)78 inline PetscErrorCode VecCreateMPICUPMWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar gpuarray[], Vec *v)
79 {
80 PetscFunctionBegin;
81 PetscCall(VecCreateMPICUPMWithArrays<T>(comm, bs, n, N, nullptr, gpuarray, v));
82 PetscFunctionReturn(PETSC_SUCCESS);
83 }
84
85 } // namespace cupm
86
87 } // namespace vec
88
89 } // namespace Petsc
90
91 #if PetscDefined(HAVE_CUDA)
92 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
93 #endif
94
95 #if PetscDefined(HAVE_HIP)
96 extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
97 #endif
98