1 #ifndef PETSCMATMPIDENSECUPM_HPP 2 #define PETSCMATMPIDENSECUPM_HPP 3 4 #include <petsc/private/matdensecupmimpl.h> /*I <petscmat.h> I*/ 5 #include <../src/mat/impls/dense/mpi/mpidense.h> 6 7 #ifdef __cplusplus 8 #include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp> 9 #include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp> 10 11 namespace Petsc 12 { 13 14 namespace mat 15 { 16 17 namespace cupm 18 { 19 20 namespace impl 21 { 22 23 template <device::cupm::DeviceType T> 24 class MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> { 25 public: 26 MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>); 27 28 private: 29 PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept; 30 PETSC_NODISCARD static constexpr MatType MATIMPLCUPM_() noexcept; 31 32 static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept; 33 34 template <bool to_host> 35 static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept; 36 37 public: 38 PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept; 39 40 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept; 41 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept; 42 43 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept; 44 PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept; 45 46 static PetscErrorCode Create(Mat) noexcept; 47 48 static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept; 49 static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept; 50 static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept; 51 52 template <PetscMemType, PetscMemoryAccessMode> 53 static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 54 template <PetscMemType, PetscMemoryAccessMode> 55 static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept; 56 57 private: 58 template <PetscMemType mtype, PetscMemoryAccessMode mode> 59 static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept 60 { 61 return GetArray<mtype, mode>(m, p); 62 } 63 64 template <PetscMemType mtype, PetscMemoryAccessMode mode> 65 static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept 66 { 67 return RestoreArray<mtype, mode>(m, p); 68 } 69 70 public: 71 template <PetscMemoryAccessMode> 72 static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept; 73 template <PetscMemoryAccessMode> 74 static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept; 75 76 static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept; 77 static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept; 78 static PetscErrorCode ResetArray(Mat) noexcept; 79 80 static PetscErrorCode Shift(Mat, PetscScalar) noexcept; 81 }; 82 83 } // namespace impl 84 85 namespace 86 { 87 88 // Declare this here so that the functions below can make use of it 89 template <device::cupm::DeviceType T> 90 inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept 91 { 92 PetscFunctionBegin; 93 PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate)); 94 PetscFunctionReturn(PETSC_SUCCESS); 95 } 96 97 } // anonymous namespace 98 99 namespace impl 100 { 101 102 // ========================================================================================== 103 // MatDense_MPI_CUPM -- Private API 104 // ========================================================================================== 105 106 template <device::cupm::DeviceType T> 107 inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept 108 { 109 return static_cast<Mat_MPIDense *>(m->data); 110 } 111 112 template <device::cupm::DeviceType T> 113 inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept 114 { 115 return MATMPIDENSECUPM(); 116 } 117 118 // ========================================================================================== 119 120 template <device::cupm::DeviceType T> 121 inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept 122 { 123 PetscFunctionBegin; 124 if (auto &mimplA = MatIMPLCast(A)->A) { 125 PetscCall(MatSetType(mimplA, MATSEQDENSECUPM())); 126 PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array)); 127 } else { 128 PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx)); 129 } 130 PetscFunctionReturn(PETSC_SUCCESS); 131 } 132 133 template <device::cupm::DeviceType T> 134 template <bool to_host> 135 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept 136 { 137 PetscFunctionBegin; 138 if (reuse == MAT_INITIAL_MATRIX) { 139 PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat)); 140 } else if (reuse == MAT_REUSE_MATRIX) { 141 PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN)); 142 } 143 { 144 const auto B = *newmat; 145 const auto pobj = PetscObjectCast(B); 146 147 if (to_host) { 148 PetscCall(BindToCPU(B, PETSC_TRUE)); 149 } else { 150 PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM())); 151 } 152 153 PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype)); 154 PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM())); 155 156 // ============================================================ 157 // Composed Ops 158 // ============================================================ 159 MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense); 160 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 161 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense); 162 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 163 MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ); 164 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 165 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 166 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 167 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>); 168 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>); 169 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>); 170 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray); 171 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray); 172 MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray); 173 174 if (to_host) { 175 if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A)); 176 B->offloadmask = PETSC_OFFLOAD_CPU; 177 } else { 178 if (auto &m_A = MatIMPLCast(B)->A) { 179 PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A)); 180 B->offloadmask = PETSC_OFFLOAD_BOTH; 181 } else { 182 B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 183 } 184 PetscCall(BindToCPU(B, PETSC_FALSE)); 185 } 186 187 // ============================================================ 188 // Function Pointer Ops 189 // ============================================================ 190 MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU); 191 } 192 PetscFunctionReturn(PETSC_SUCCESS); 193 } 194 195 // ========================================================================================== 196 // MatDense_MPI_CUPM -- Public API 197 // ========================================================================================== 198 199 template <device::cupm::DeviceType T> 200 inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept 201 { 202 return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C"; 203 } 204 205 template <device::cupm::DeviceType T> 206 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept 207 { 208 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C"; 209 } 210 211 template <device::cupm::DeviceType T> 212 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept 213 { 214 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C"; 215 } 216 217 template <device::cupm::DeviceType T> 218 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept 219 { 220 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C"; 221 } 222 223 template <device::cupm::DeviceType T> 224 inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept 225 { 226 return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C"; 227 } 228 229 // ========================================================================================== 230 231 template <device::cupm::DeviceType T> 232 inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept 233 { 234 PetscFunctionBegin; 235 PetscCall(MatCreate_MPIDense(A)); 236 PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A)); 237 PetscFunctionReturn(PETSC_SUCCESS); 238 } 239 240 // ========================================================================================== 241 242 template <device::cupm::DeviceType T> 243 inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept 244 { 245 const auto mimpl = MatIMPLCast(A); 246 const auto pobj = PetscObjectCast(A); 247 248 PetscFunctionBegin; 249 PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 250 PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 251 if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost)); 252 A->boundtocpu = usehost; 253 PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype)); 254 if (!usehost) { 255 PetscBool iscupm; 256 257 PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm)); 258 if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec)); 259 PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm)); 260 if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat)); 261 } 262 263 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 264 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>); 265 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>); 266 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>); 267 MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 268 MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>); 269 270 MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift); 271 272 if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost)); 273 PetscFunctionReturn(PETSC_SUCCESS); 274 } 275 276 template <device::cupm::DeviceType T> 277 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 278 { 279 PetscFunctionBegin; 280 PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat)); 281 PetscFunctionReturn(PETSC_SUCCESS); 282 } 283 284 template <device::cupm::DeviceType T> 285 inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept 286 { 287 PetscFunctionBegin; 288 PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat)); 289 PetscFunctionReturn(PETSC_SUCCESS); 290 } 291 292 // ========================================================================================== 293 294 template <device::cupm::DeviceType T> 295 template <PetscMemType, PetscMemoryAccessMode access> 296 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 297 { 298 PetscFunctionBegin; 299 PetscCall(MatDenseCUPMGetArray_Private<T, access>(MatIMPLCast(A)->A, array)); 300 PetscFunctionReturn(PETSC_SUCCESS); 301 } 302 303 template <device::cupm::DeviceType T> 304 template <PetscMemType, PetscMemoryAccessMode access> 305 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept 306 { 307 PetscFunctionBegin; 308 PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array)); 309 PetscFunctionReturn(PETSC_SUCCESS); 310 } 311 312 // ========================================================================================== 313 314 template <device::cupm::DeviceType T> 315 template <PetscMemoryAccessMode access> 316 inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept 317 { 318 using namespace vec::cupm; 319 320 const auto mimpl = MatIMPLCast(A); 321 const auto mimpl_A = mimpl->A; 322 const auto pobj = PetscObjectCast(A); 323 auto &cvec = mimpl->cvec; 324 PetscInt lda; 325 326 PetscFunctionBegin; 327 PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 328 PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 329 mimpl->vecinuse = col + 1; 330 331 if (!cvec) PetscCall(VecCreateMPICUPMWithArray<T>(PetscObjectComm(pobj), A->rmap->bs, A->rmap->n, A->rmap->N, nullptr, &cvec)); 332 333 PetscCall(MatDenseGetLDA(mimpl_A, &lda)); 334 PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 335 PetscCall(VecCUPMPlaceArrayAsync<T>(cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda))); 336 337 if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(cvec)); 338 *v = cvec; 339 PetscFunctionReturn(PETSC_SUCCESS); 340 } 341 342 template <device::cupm::DeviceType T> 343 template <PetscMemoryAccessMode access> 344 inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept 345 { 346 using namespace vec::cupm; 347 348 const auto mimpl = MatIMPLCast(A); 349 const auto cvec = mimpl->cvec; 350 351 PetscFunctionBegin; 352 PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first"); 353 PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector"); 354 mimpl->vecinuse = 0; 355 356 PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse))); 357 if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec)); 358 PetscCall(VecCUPMResetArrayAsync<T>(cvec)); 359 360 if (v) *v = nullptr; 361 PetscFunctionReturn(PETSC_SUCCESS); 362 } 363 364 // ========================================================================================== 365 366 template <device::cupm::DeviceType T> 367 inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept 368 { 369 const auto mimpl = MatIMPLCast(A); 370 371 PetscFunctionBegin; 372 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 373 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 374 PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array)); 375 PetscFunctionReturn(PETSC_SUCCESS); 376 } 377 378 template <device::cupm::DeviceType T> 379 inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept 380 { 381 const auto mimpl = MatIMPLCast(A); 382 383 PetscFunctionBegin; 384 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 385 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 386 PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array)); 387 PetscFunctionReturn(PETSC_SUCCESS); 388 } 389 390 template <device::cupm::DeviceType T> 391 inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept 392 { 393 const auto mimpl = MatIMPLCast(A); 394 395 PetscFunctionBegin; 396 PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first"); 397 PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first"); 398 PetscCall(MatDenseCUPMResetArray<T>(mimpl->A)); 399 PetscFunctionReturn(PETSC_SUCCESS); 400 } 401 402 // ========================================================================================== 403 404 template <device::cupm::DeviceType T> 405 inline PetscErrorCode MatDense_MPI_CUPM<T>::Shift(Mat A, PetscScalar alpha) noexcept 406 { 407 PetscDeviceContext dctx; 408 409 PetscFunctionBegin; 410 PetscCall(GetHandles_(&dctx)); 411 PetscCall(PetscInfo(A, "Performing Shift on backend\n")); 412 PetscCall(DiagonalUnaryTransform(A, A->rmap->rstart, A->rmap->rend, A->cmap->N, dctx, device::cupm::functors::make_plus_equals(alpha))); 413 PetscFunctionReturn(PETSC_SUCCESS); 414 } 415 416 } // namespace impl 417 418 namespace 419 { 420 421 template <device::cupm::DeviceType T> 422 inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt n, PetscInt m, PetscInt N, PetscInt M, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept 423 { 424 PetscMPIInt size; 425 426 PetscFunctionBegin; 427 PetscValidPointer(A, 7); 428 PetscCallMPI(MPI_Comm_size(comm, &size)); 429 if (size > 1) { 430 PetscCall(MatCreateMPIDenseCUPM<T>(comm, n, m, N, M, data, A, dctx)); 431 } else { 432 if (n == PETSC_DECIDE) n = N; 433 if (m == PETSC_DECIDE) m = M; 434 // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down 435 // the line 436 PetscCall(MatCreateSeqDenseCUPM<T>(comm, n, m, data, A, dctx)); 437 } 438 PetscFunctionReturn(PETSC_SUCCESS); 439 } 440 441 } // anonymous namespace 442 443 } // namespace cupm 444 445 } // namespace mat 446 447 } // namespace Petsc 448 449 #endif // __cplusplus 450 451 #endif // PETSCMATMPIDENSECUPM_HPP 452