1 #ifndef PETSCMATMPIDENSECUPM_HPP 2 #define PETSCMATMPIDENSECUPM_HPP 3 4 #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 5 #include <../src/mat/impls/dense/mpi/mpidense.h> 6 7 #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp> 8 #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp> 9 10 namespace Petsc 11 { 12 13 namespace mat 14 { 15 16 namespace cupm 17 { 18 19 namespace impl 20 { 21 22 template <device::cupm::DeviceType T> 23 class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> { 24 public: 25 MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>); 26 27 private: 28 PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept; 29 PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 30 31 static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 32 33 template <bool to_host> 34 static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 35 36 public: 37 PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept; 38 39 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept; 40 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept; 41 42 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept; 43 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept; 44 45 static PetscErrorCode Create(Mat) noexcept; 46 47 static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 48 static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept; 49 static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 50 51 template <PetscMemType, PetscMemoryAccessMode> 52 static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 53 template <PetscMemType, PetscMemoryAccessMode> 54 static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 55 56 private: 57 template <PetscMemType mtype, PetscMemoryAccessMode mode> 58 static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 59 { 60 return GetArray<mtype, mode>(m, p); 61 } 62 63 template <PetscMemType mtype, PetscMemoryAccessMode mode> 64 static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 65 { 66 return RestoreArray<mtype, mode>(m, p); 67 } 68 69 public: 70 template <PetscMemoryAccessMode> 71 static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 72 template <PetscMemoryAccessMode> 73 static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 74 75 static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 76 static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 77 static PetscErrorCode ResetArray(Mat) noexcept; 78 79 static PetscErrorCode Shift(Mat, PetscScalar) noexcept; 80 }; 81 82 } // namespace impl 83 84 namespace 85 { 86 87 // Declare this here so that the functions below can make use of it 88 template <device::cupm::DeviceType T> 89 inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 90 { 91 PetscFunctionBegin; 92 PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate)); 93 PetscFunctionReturn(PETSC_SUCCESS); 94 } 95 96 } // anonymous namespace 97 98 namespace impl 99 { 100 101 // ========================================================================================== 102 // MatDense_MPI_CUPM -- Private API 103 // ========================================================================================== 104 105 template <device::cupm::DeviceType T> 106 inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept 107 { 108 return static_cast<Mat_MPIDense *>(m->data); 109 } 110 111 template <device::cupm::DeviceType T> 112 inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept 113 { 114 return MATMPIDENSECUPM(); 115 } 116 117 // ========================================================================================== 118 119 template <device::cupm::DeviceType T> 120 inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept 121 { 122 PetscFunctionBegin; 123 if (auto &mimplA = MatIMPLCast(A)->A) { 124 PetscCall(MatSetType(mimplA, MATSEQDENSECUPM())); 125 PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array)); 126 } else { 127 PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx)); 128 } 129 PetscFunctionReturn(PETSC_SUCCESS); 130 } 131 132 template <device::cupm::DeviceType T> 133 template <bool to_host> 134 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept 135 { 136 PetscFunctionBegin; 137 if (reuse == MAT_INITIAL_MATRIX) { 138 PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat)); 139 } else if (reuse == MAT_REUSE_MATRIX) { 140 PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN)); 141 } 142 { 143 const auto B = *newmat; 144 const auto pobj = PetscObjectCast(B); 145 146 if (to_host) { 147 PetscCall(BindToCPU(B, PETSC_TRUE)); 148 } else { 149 PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 150 } 151 152 PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype)); 153 PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM())); 154 155 // ============================================================ 156 // Composed Ops 157 // ============================================================ 158 MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense); 159 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 160 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 161 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 162 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 163 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 164 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 165 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 166 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 167 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 168 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 169 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 170 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 171 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 172 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation); 173 174 if (to_host) { 175 if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A)); 176 B->offloadmask = PETSC_OFFLOAD_CPU; 177 } else { 178 if (auto &m_A = MatIMPLCast(B)->A) { 179 PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A)); 180 B->offloadmask = PETSC_OFFLOAD_BOTH; 181 } else { 182 B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 183 } 184 PetscCall(BindToCPU(B, PETSC_FALSE)); 185 } 186 187 // ============================================================ 188 // Function Pointer Ops 189 // ============================================================ 190 MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 191 } 192 PetscFunctionReturn(PETSC_SUCCESS); 193 } 194 195 // ========================================================================================== 196 // MatDense_MPI_CUPM -- Public API 197 // ========================================================================================== 198 199 template <device::cupm::DeviceType T> 200 inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept 201 { 202 return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C"; 203 } 204 205 template <device::cupm::DeviceType T> 206 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept 207 { 208 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C"; 209 } 210 211 template <device::cupm::DeviceType T> 212 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept 213 { 214 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C"; 215 } 216 217 template <device::cupm::DeviceType T> 218 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept 219 { 220 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C"; 221 } 222 223 template <device::cupm::DeviceType T> 224 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept 225 { 226 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C"; 227 } 228 229 // ========================================================================================== 230 231 template <device::cupm::DeviceType T> 232 inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept 233 { 234 PetscFunctionBegin; 235 PetscCall(MatCreate_MPIDense(A)); 236 PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 237 PetscFunctionReturn(PETSC_SUCCESS); 238 } 239 240 // ========================================================================================== 241 242 template <device::cupm::DeviceType T> 243 inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept 244 { 245 const auto mimpl = MatIMPLCast(A); 246 const auto pobj = PetscObjectCast(A); 247 248 PetscFunctionBegin; 249 PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 250 PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 251 if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost)); 252 A->boundtocpu = usehost; 253 PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 254 if (!usehost) { 255 PetscBool iscupm; 256 257 PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm)); 258 if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec)); 259 PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm)); 260 if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat)); 261 } 262 263 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 264 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 265 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 266 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 267 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 268 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 269 270 MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift); 271 272 if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost)); 273 PetscFunctionReturn(PETSC_SUCCESS); 274 } 275 276 template <device::cupm::DeviceType T> 277 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 278 { 279 PetscFunctionBegin; 280 PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat)); 281 PetscFunctionReturn(PETSC_SUCCESS); 282 } 283 284 template <device::cupm::DeviceType T> 285 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 286 { 287 PetscFunctionBegin; 288 PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat)); 289 PetscFunctionReturn(PETSC_SUCCESS); 290 } 291 292 // ========================================================================================== 293 294 template <device::cupm::DeviceType T> 295 template <PetscMemType, PetscMemoryAccessMode access> 296 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 297 { 298 PetscFunctionBegin; 299 PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array)); 300 PetscFunctionReturn(PETSC_SUCCESS); 301 } 302 303 template <device::cupm::DeviceType T> 304 template <PetscMemType, PetscMemoryAccessMode access> 305 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 306 { 307 PetscFunctionBegin; 308 PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array)); 309 PetscFunctionReturn(PETSC_SUCCESS); 310 } 311 312 // ========================================================================================== 313 314 template <device::cupm::DeviceType T> 315 template <PetscMemoryAccessMode access> 316 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 317 { 318 using namespace vec::cupm; 319 320 const auto mimpl = MatIMPLCast(A); 321 const auto mimpl_A = mimpl->A; 322 const auto pobj = PetscObjectCast(A); 323 PetscInt lda; 324 325 PetscFunctionBegin; 326 PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 327 PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 328 mimpl->vecinuse = col + 1; 329 330 if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec)); 331 332 PetscCall(MatDenseGetLDA(mimpl_A, &lda)); 333 PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 334 PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda))); 335 336 if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec)); 337 *v = mimpl->cvec; 338 PetscFunctionReturn(PETSC_SUCCESS); 339 } 340 341 template <device::cupm::DeviceType T> 342 template <PetscMemoryAccessMode access> 343 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 344 { 345 using namespace vec::cupm; 346 347 const auto mimpl = MatIMPLCast(A); 348 const auto cvec = mimpl->cvec; 349 350 PetscFunctionBegin; 351 PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 352 PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 353 mimpl->vecinuse = 0; 354 355 PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 356 if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 357 PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 358 359 if (v) *v = nullptr; 360 PetscFunctionReturn(PETSC_SUCCESS); 361 } 362 363 // ========================================================================================== 364 365 template <device::cupm::DeviceType T> 366 inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 367 { 368 const auto mimpl = MatIMPLCast(A); 369 370 PetscFunctionBegin; 371 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 372 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 373 PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array)); 374 PetscFunctionReturn(PETSC_SUCCESS); 375 } 376 377 template <device::cupm::DeviceType T> 378 inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 379 { 380 const auto mimpl = MatIMPLCast(A); 381 382 PetscFunctionBegin; 383 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 384 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 385 PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array)); 386 PetscFunctionReturn(PETSC_SUCCESS); 387 } 388 389 template <device::cupm::DeviceType T> 390 inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept 391 { 392 const auto mimpl = MatIMPLCast(A); 393 394 PetscFunctionBegin; 395 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 396 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 397 PetscCall(MatDenseCUPMResetArray<T>(mimpl->A)); 398 PetscFunctionReturn(PETSC_SUCCESS); 399 } 400 401 // ========================================================================================== 402 403 template <device::cupm::DeviceType T> 404 inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept 405 { 406 PetscDeviceContext dctx; 407 408 PetscFunctionBegin; 409 PetscCall(GetHandles_(&dctx)); 410 PetscCall(PetscInfo(A, "Performing Shift on backend\n")); 411 PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha))); 412 PetscFunctionReturn(PETSC_SUCCESS); 413 } 414 415 } // namespace impl 416 417 namespace 418 { 419 420 template <device::cupm::DeviceType T> 421 inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept 422 { 423 PetscMPIInt size; 424 425 PetscFunctionBegin; 426 PetscValidPointer(A, 7); 427 PetscCallMPI(MPI_Comm_size(comm, &size)); 428 if (size > 1) { 429 PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx)); 430 } else { 431 if (n == PETSC_DECIDE) n = N; 432 if (m == PETSC_DECIDE) m = M; 433 // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down 434 // the line 435 PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx)); 436 } 437 PetscFunctionReturn(PETSC_SUCCESS); 438 } 439 440 } // anonymous namespace 441 442 } // namespace cupm 443 444 } // namespace mat 445 446 } // namespace Petsc 447 448 #endif // PETSCMATMPIDENSECUPM_HPP 449