1a4963045SJacob Faibussowitsch #pragma once
24742e46bSJacob Faibussowitsch
34742e46bSJacob Faibussowitsch #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/
44742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/mpi/mpidense.h>
54742e46bSJacob Faibussowitsch
64742e46bSJacob Faibussowitsch #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
74742e46bSJacob Faibussowitsch #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>
84742e46bSJacob Faibussowitsch
94742e46bSJacob Faibussowitsch namespace Petsc
104742e46bSJacob Faibussowitsch {
114742e46bSJacob Faibussowitsch
124742e46bSJacob Faibussowitsch namespace mat
134742e46bSJacob Faibussowitsch {
144742e46bSJacob Faibussowitsch
154742e46bSJacob Faibussowitsch namespace cupm
164742e46bSJacob Faibussowitsch {
174742e46bSJacob Faibussowitsch
184742e46bSJacob Faibussowitsch namespace impl
194742e46bSJacob Faibussowitsch {
204742e46bSJacob Faibussowitsch
214742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
22*85f25e71SJed Brown class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
234742e46bSJacob Faibussowitsch public:
244742e46bSJacob Faibussowitsch MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);
254742e46bSJacob Faibussowitsch
264742e46bSJacob Faibussowitsch private:
274742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
284742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept;
294742e46bSJacob Faibussowitsch
304742e46bSJacob Faibussowitsch static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;
314742e46bSJacob Faibussowitsch
324742e46bSJacob Faibussowitsch template <bool to_host>
334742e46bSJacob Faibussowitsch static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;
344742e46bSJacob Faibussowitsch
354742e46bSJacob Faibussowitsch public:
364742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;
374742e46bSJacob Faibussowitsch
384742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
394742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;
404742e46bSJacob Faibussowitsch
414742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
424742e46bSJacob Faibussowitsch PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;
434742e46bSJacob Faibussowitsch
444742e46bSJacob Faibussowitsch static PetscErrorCode Create(Mat) noexcept;
454742e46bSJacob Faibussowitsch
464742e46bSJacob Faibussowitsch static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
474742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
484742e46bSJacob Faibussowitsch static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;
494742e46bSJacob Faibussowitsch
504742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode>
514742e46bSJacob Faibussowitsch static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
524742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode>
534742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
544742e46bSJacob Faibussowitsch
554742e46bSJacob Faibussowitsch private:
564742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode>
GetArrayC_(Mat m,PetscScalar ** p)574742e46bSJacob Faibussowitsch static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
584742e46bSJacob Faibussowitsch {
594742e46bSJacob Faibussowitsch return GetArray<mtype, mode>(m, p);
604742e46bSJacob Faibussowitsch }
614742e46bSJacob Faibussowitsch
624742e46bSJacob Faibussowitsch template <PetscMemType mtype, PetscMemoryAccessMode mode>
RestoreArrayC_(Mat m,PetscScalar ** p)634742e46bSJacob Faibussowitsch static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
644742e46bSJacob Faibussowitsch {
654742e46bSJacob Faibussowitsch return RestoreArray<mtype, mode>(m, p);
664742e46bSJacob Faibussowitsch }
674742e46bSJacob Faibussowitsch
684742e46bSJacob Faibussowitsch public:
694742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode>
704742e46bSJacob Faibussowitsch static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
714742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode>
724742e46bSJacob Faibussowitsch static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;
734742e46bSJacob Faibussowitsch
744742e46bSJacob Faibussowitsch static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
754742e46bSJacob Faibussowitsch static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
764742e46bSJacob Faibussowitsch static PetscErrorCode ResetArray(Mat) noexcept;
774742e46bSJacob Faibussowitsch };
784742e46bSJacob Faibussowitsch
794742e46bSJacob Faibussowitsch } // namespace impl
804742e46bSJacob Faibussowitsch
814742e46bSJacob Faibussowitsch namespace
824742e46bSJacob Faibussowitsch {
834742e46bSJacob Faibussowitsch
844742e46bSJacob Faibussowitsch // Declare this here so that the functions below can make use of it
854742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatCreateMPIDenseCUPM(MPI_Comm comm,PetscInt m,PetscInt n,PetscInt M,PetscInt N,PetscScalar * data,Mat * A,PetscDeviceContext dctx=nullptr,bool preallocate=true)864742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
874742e46bSJacob Faibussowitsch {
884742e46bSJacob Faibussowitsch PetscFunctionBegin;
894742e46bSJacob Faibussowitsch PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
904742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
914742e46bSJacob Faibussowitsch }
924742e46bSJacob Faibussowitsch
934742e46bSJacob Faibussowitsch } // anonymous namespace
944742e46bSJacob Faibussowitsch
954742e46bSJacob Faibussowitsch namespace impl
964742e46bSJacob Faibussowitsch {
974742e46bSJacob Faibussowitsch
984742e46bSJacob Faibussowitsch // ==========================================================================================
994742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Private API
1004742e46bSJacob Faibussowitsch // ==========================================================================================
1014742e46bSJacob Faibussowitsch
1024742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatIMPLCast_(Mat m)1034742e46bSJacob Faibussowitsch inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
1044742e46bSJacob Faibussowitsch {
1054742e46bSJacob Faibussowitsch return static_cast<Mat_MPIDense *>(m->data);
1064742e46bSJacob Faibussowitsch }
1074742e46bSJacob Faibussowitsch
1084742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MATIMPLCUPM_()1094742e46bSJacob Faibussowitsch inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
1104742e46bSJacob Faibussowitsch {
1114742e46bSJacob Faibussowitsch return MATMPIDENSECUPM();
1124742e46bSJacob Faibussowitsch }
1134742e46bSJacob Faibussowitsch
1144742e46bSJacob Faibussowitsch // ==========================================================================================
1154742e46bSJacob Faibussowitsch
1164742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
SetPreallocation_(Mat A,PetscDeviceContext dctx,PetscScalar * device_array)1174742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
1184742e46bSJacob Faibussowitsch {
1194742e46bSJacob Faibussowitsch PetscFunctionBegin;
1204742e46bSJacob Faibussowitsch if (auto &mimplA = MatIMPLCast(A)->A) {
1214742e46bSJacob Faibussowitsch PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
1224742e46bSJacob Faibussowitsch PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
1234742e46bSJacob Faibussowitsch } else {
1244742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
1254742e46bSJacob Faibussowitsch }
1264742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1274742e46bSJacob Faibussowitsch }
1284742e46bSJacob Faibussowitsch
1294742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
1304742e46bSJacob Faibussowitsch template <bool to_host>
Convert_Dispatch_(Mat M,MatType,MatReuse reuse,Mat * newmat)1314742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
1324742e46bSJacob Faibussowitsch {
1334742e46bSJacob Faibussowitsch PetscFunctionBegin;
1344742e46bSJacob Faibussowitsch if (reuse == MAT_INITIAL_MATRIX) {
1354742e46bSJacob Faibussowitsch PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
1364742e46bSJacob Faibussowitsch } else if (reuse == MAT_REUSE_MATRIX) {
1374742e46bSJacob Faibussowitsch PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
1384742e46bSJacob Faibussowitsch }
1394742e46bSJacob Faibussowitsch {
1404742e46bSJacob Faibussowitsch const auto B = *newmat;
1414742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(B);
1424742e46bSJacob Faibussowitsch
1434742e46bSJacob Faibussowitsch if (to_host) {
1444742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_TRUE));
1454742e46bSJacob Faibussowitsch } else {
1464742e46bSJacob Faibussowitsch PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
1474742e46bSJacob Faibussowitsch }
1484742e46bSJacob Faibussowitsch
1494742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
1504742e46bSJacob Faibussowitsch PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));
1514742e46bSJacob Faibussowitsch
1524742e46bSJacob Faibussowitsch // ============================================================
1534742e46bSJacob Faibussowitsch // Composed Ops
1544742e46bSJacob Faibussowitsch // ============================================================
1554742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
1564742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
1574742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
1584742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
1594742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
1604742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
1614742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
1624742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
1634742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
1644742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
1654742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
1664742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
1674742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
1684742e46bSJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
1693d9668e3SJacob Faibussowitsch MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);
1704742e46bSJacob Faibussowitsch
1714742e46bSJacob Faibussowitsch if (to_host) {
1724742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
1734742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_CPU;
1744742e46bSJacob Faibussowitsch } else {
1754742e46bSJacob Faibussowitsch if (auto &m_A = MatIMPLCast(B)->A) {
1764742e46bSJacob Faibussowitsch PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
1774742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_BOTH;
1784742e46bSJacob Faibussowitsch } else {
1794742e46bSJacob Faibussowitsch B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
1804742e46bSJacob Faibussowitsch }
1814742e46bSJacob Faibussowitsch PetscCall(BindToCPU(B, PETSC_FALSE));
1824742e46bSJacob Faibussowitsch }
1834742e46bSJacob Faibussowitsch
1844742e46bSJacob Faibussowitsch // ============================================================
1854742e46bSJacob Faibussowitsch // Function Pointer Ops
1864742e46bSJacob Faibussowitsch // ============================================================
18714277c92SJacob Faibussowitsch MatSetOp_CUPM(to_host, B, getdiagonal, MatGetDiagonal_MPIDense, GetDiagonal);
1884742e46bSJacob Faibussowitsch MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
1894742e46bSJacob Faibussowitsch }
1904742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
1914742e46bSJacob Faibussowitsch }
1924742e46bSJacob Faibussowitsch
1934742e46bSJacob Faibussowitsch // ==========================================================================================
1944742e46bSJacob Faibussowitsch // MatDense_MPI_CUPM -- Public API
1954742e46bSJacob Faibussowitsch // ==========================================================================================
1964742e46bSJacob Faibussowitsch
1974742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatConvert_mpidensecupm_mpidense_C()1984742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
1994742e46bSJacob Faibussowitsch {
2004742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
2014742e46bSJacob Faibussowitsch }
2024742e46bSJacob Faibussowitsch
2034742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatProductSetFromOptions_mpiaij_mpidensecupm_C()2044742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
2054742e46bSJacob Faibussowitsch {
2064742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
2074742e46bSJacob Faibussowitsch }
2084742e46bSJacob Faibussowitsch
2094742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatProductSetFromOptions_mpidensecupm_mpiaij_C()2104742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
2114742e46bSJacob Faibussowitsch {
2124742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
2134742e46bSJacob Faibussowitsch }
2144742e46bSJacob Faibussowitsch
2154742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C()2164742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
2174742e46bSJacob Faibussowitsch {
2184742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
2194742e46bSJacob Faibussowitsch }
2204742e46bSJacob Faibussowitsch
2214742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C()2224742e46bSJacob Faibussowitsch inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
2234742e46bSJacob Faibussowitsch {
2244742e46bSJacob Faibussowitsch return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
2254742e46bSJacob Faibussowitsch }
2264742e46bSJacob Faibussowitsch
2274742e46bSJacob Faibussowitsch // ==========================================================================================
2284742e46bSJacob Faibussowitsch
2294742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
Create(Mat A)2304742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
2314742e46bSJacob Faibussowitsch {
2324742e46bSJacob Faibussowitsch PetscFunctionBegin;
2334742e46bSJacob Faibussowitsch PetscCall(MatCreate_MPIDense(A));
2344742e46bSJacob Faibussowitsch PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
2354742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
2364742e46bSJacob Faibussowitsch }
2374742e46bSJacob Faibussowitsch
2384742e46bSJacob Faibussowitsch // ==========================================================================================
2394742e46bSJacob Faibussowitsch
2404742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
BindToCPU(Mat A,PetscBool usehost)2414742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
2424742e46bSJacob Faibussowitsch {
2434742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
2444742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A);
2454742e46bSJacob Faibussowitsch
2464742e46bSJacob Faibussowitsch PetscFunctionBegin;
2474742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
2484742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
2494742e46bSJacob Faibussowitsch if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
2504742e46bSJacob Faibussowitsch A->boundtocpu = usehost;
2514742e46bSJacob Faibussowitsch PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
2524742e46bSJacob Faibussowitsch if (!usehost) {
2534742e46bSJacob Faibussowitsch PetscBool iscupm;
2544742e46bSJacob Faibussowitsch
2554742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
2564742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
2574742e46bSJacob Faibussowitsch PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
2584742e46bSJacob Faibussowitsch if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
2594742e46bSJacob Faibussowitsch }
2604742e46bSJacob Faibussowitsch
2614742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
2624742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
2634742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
2644742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
2654742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
2664742e46bSJacob Faibussowitsch MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
2674742e46bSJacob Faibussowitsch
2684742e46bSJacob Faibussowitsch MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);
2694742e46bSJacob Faibussowitsch
2704742e46bSJacob Faibussowitsch if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
2714742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
2724742e46bSJacob Faibussowitsch }
2734742e46bSJacob Faibussowitsch
2744742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
Convert_MPIDenseCUPM_MPIDense(Mat M,MatType mtype,MatReuse reuse,Mat * newmat)2754742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
2764742e46bSJacob Faibussowitsch {
2774742e46bSJacob Faibussowitsch PetscFunctionBegin;
2784742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
2794742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
2804742e46bSJacob Faibussowitsch }
2814742e46bSJacob Faibussowitsch
2824742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
Convert_MPIDense_MPIDenseCUPM(Mat M,MatType mtype,MatReuse reuse,Mat * newmat)2834742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
2844742e46bSJacob Faibussowitsch {
2854742e46bSJacob Faibussowitsch PetscFunctionBegin;
2864742e46bSJacob Faibussowitsch PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
2874742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
2884742e46bSJacob Faibussowitsch }
2894742e46bSJacob Faibussowitsch
2904742e46bSJacob Faibussowitsch // ==========================================================================================
2914742e46bSJacob Faibussowitsch
2924742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
2934742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access>
GetArray(Mat A,PetscScalar ** array,PetscDeviceContext dctx)29414277c92SJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext dctx) noexcept
2954742e46bSJacob Faibussowitsch {
29614277c92SJacob Faibussowitsch auto &mimplA = MatIMPLCast(A)->A;
29714277c92SJacob Faibussowitsch
2984742e46bSJacob Faibussowitsch PetscFunctionBegin;
29914277c92SJacob Faibussowitsch if (!mimplA) PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, nullptr, &mimplA, dctx));
30014277c92SJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimplA, array));
3014742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3024742e46bSJacob Faibussowitsch }
3034742e46bSJacob Faibussowitsch
3044742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
3054742e46bSJacob Faibussowitsch template <PetscMemType, PetscMemoryAccessMode access>
RestoreArray(Mat A,PetscScalar ** array,PetscDeviceContext)3064742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
3074742e46bSJacob Faibussowitsch {
3084742e46bSJacob Faibussowitsch PetscFunctionBegin;
3094742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
3104742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3114742e46bSJacob Faibussowitsch }
3124742e46bSJacob Faibussowitsch
3134742e46bSJacob Faibussowitsch // ==========================================================================================
3144742e46bSJacob Faibussowitsch
3154742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
3164742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
GetColumnVec(Mat A,PetscInt col,Vec * v)3174742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
3184742e46bSJacob Faibussowitsch {
3194742e46bSJacob Faibussowitsch using namespace vec::cupm;
3204742e46bSJacob Faibussowitsch
3214742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
3224742e46bSJacob Faibussowitsch const auto mimpl_A = mimpl->A;
3234742e46bSJacob Faibussowitsch const auto pobj = PetscObjectCast(A);
3244742e46bSJacob Faibussowitsch PetscInt lda;
3254742e46bSJacob Faibussowitsch
3264742e46bSJacob Faibussowitsch PetscFunctionBegin;
3274742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
3284742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
3294742e46bSJacob Faibussowitsch mimpl->vecinuse = col + 1;
3304742e46bSJacob Faibussowitsch
331d16ceb75SStefano Zampini if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));
3324742e46bSJacob Faibussowitsch
3334742e46bSJacob Faibussowitsch PetscCall(MatDenseGetLDA(mimpl_A, &lda));
3344742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
335d16ceb75SStefano Zampini PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));
3364742e46bSJacob Faibussowitsch
337d16ceb75SStefano Zampini if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
338d16ceb75SStefano Zampini *v = mimpl->cvec;
3394742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3404742e46bSJacob Faibussowitsch }
3414742e46bSJacob Faibussowitsch
3424742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
3434742e46bSJacob Faibussowitsch template <PetscMemoryAccessMode access>
RestoreColumnVec(Mat A,PetscInt,Vec * v)3444742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
3454742e46bSJacob Faibussowitsch {
3464742e46bSJacob Faibussowitsch using namespace vec::cupm;
3474742e46bSJacob Faibussowitsch
3484742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
3494742e46bSJacob Faibussowitsch const auto cvec = mimpl->cvec;
3504742e46bSJacob Faibussowitsch
3514742e46bSJacob Faibussowitsch PetscFunctionBegin;
3524742e46bSJacob Faibussowitsch PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
3534742e46bSJacob Faibussowitsch PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
3544742e46bSJacob Faibussowitsch mimpl->vecinuse = 0;
3554742e46bSJacob Faibussowitsch
3564742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
3574742e46bSJacob Faibussowitsch if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
3584742e46bSJacob Faibussowitsch PetscCall(VecCUPMResetArrayAsync<T>(cvec));
3594742e46bSJacob Faibussowitsch
3604742e46bSJacob Faibussowitsch if (v) *v = nullptr;
3614742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3624742e46bSJacob Faibussowitsch }
3634742e46bSJacob Faibussowitsch
3644742e46bSJacob Faibussowitsch // ==========================================================================================
3654742e46bSJacob Faibussowitsch
3664742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
PlaceArray(Mat A,const PetscScalar * array)3674742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
3684742e46bSJacob Faibussowitsch {
3694742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
3704742e46bSJacob Faibussowitsch
3714742e46bSJacob Faibussowitsch PetscFunctionBegin;
3724742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
3734742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
3744742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
3754742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3764742e46bSJacob Faibussowitsch }
3774742e46bSJacob Faibussowitsch
3784742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
ReplaceArray(Mat A,const PetscScalar * array)3794742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
3804742e46bSJacob Faibussowitsch {
3814742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
3824742e46bSJacob Faibussowitsch
3834742e46bSJacob Faibussowitsch PetscFunctionBegin;
3844742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
3854742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
3864742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
3874742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
3884742e46bSJacob Faibussowitsch }
3894742e46bSJacob Faibussowitsch
3904742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
ResetArray(Mat A)3914742e46bSJacob Faibussowitsch inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
3924742e46bSJacob Faibussowitsch {
3934742e46bSJacob Faibussowitsch const auto mimpl = MatIMPLCast(A);
3944742e46bSJacob Faibussowitsch
3954742e46bSJacob Faibussowitsch PetscFunctionBegin;
3964742e46bSJacob Faibussowitsch PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
3974742e46bSJacob Faibussowitsch PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
3984742e46bSJacob Faibussowitsch PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
3994742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
4004742e46bSJacob Faibussowitsch }
4014742e46bSJacob Faibussowitsch
4024742e46bSJacob Faibussowitsch } // namespace impl
4034742e46bSJacob Faibussowitsch
4044742e46bSJacob Faibussowitsch namespace
4054742e46bSJacob Faibussowitsch {
4064742e46bSJacob Faibussowitsch
4074742e46bSJacob Faibussowitsch template <device::cupm::DeviceType T>
MatCreateDenseCUPM(MPI_Comm comm,PetscInt n,PetscInt m,PetscInt N,PetscInt M,PetscScalar * data,Mat * A,PetscDeviceContext dctx=nullptr)4084742e46bSJacob Faibussowitsch inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
4094742e46bSJacob Faibussowitsch {
4104742e46bSJacob Faibussowitsch PetscMPIInt size;
4114742e46bSJacob Faibussowitsch
4124742e46bSJacob Faibussowitsch PetscFunctionBegin;
4134f572ea9SToby Isaac PetscAssertPointer(A, 7);
4144742e46bSJacob Faibussowitsch PetscCallMPI(MPI_Comm_size(comm, &size));
4154742e46bSJacob Faibussowitsch if (size > 1) {
4164742e46bSJacob Faibussowitsch PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx));
4174742e46bSJacob Faibussowitsch } else {
4184742e46bSJacob Faibussowitsch if (n == PETSC_DECIDE) n = N;
4194742e46bSJacob Faibussowitsch if (m == PETSC_DECIDE) m = M;
4204742e46bSJacob Faibussowitsch // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
4214742e46bSJacob Faibussowitsch // the line
4224742e46bSJacob Faibussowitsch PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx));
4234742e46bSJacob Faibussowitsch }
4244742e46bSJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS);
4254742e46bSJacob Faibussowitsch }
4264742e46bSJacob Faibussowitsch
4274742e46bSJacob Faibussowitsch } // anonymous namespace
4284742e46bSJacob Faibussowitsch
4294742e46bSJacob Faibussowitsch } // namespace cupm
4304742e46bSJacob Faibussowitsch
4314742e46bSJacob Faibussowitsch } // namespace mat
4324742e46bSJacob Faibussowitsch
4334742e46bSJacob Faibussowitsch } // namespace Petsc
434