1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 7 #include <array> 8 9 namespace Petsc 10 { 11 12 namespace Device 13 { 14 15 namespace CUPM 16 { 17 18 namespace Impl 19 { 20 21 namespace detail 22 { 23 24 // for tag-based dispatch of handle retrieval 25 template <typename T> struct HandleTag { using type = T; }; 26 27 } // namespace detail 28 29 // Forward declare 30 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext; 31 32 template <DeviceType T> 33 class DeviceContext : Impl::BlasInterface<T> 34 { 35 public: 36 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T); 37 38 private: 39 template <typename H> using HandleTag = typename detail::HandleTag<H>; 40 using stream_tag = HandleTag<cupmStream_t>; 41 using blas_tag = HandleTag<cupmBlasHandle_t>; 42 using solver_tag = HandleTag<cupmSolverHandle_t>; 43 44 public: 45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46 // header, but since we are using the power of templates it must be declared part of 47 // this class to have easy access the same typedefs. Technically one can make a 48 // templated struct outside the class but it's more code for the same result. 49 struct PetscDeviceContext_IMPLS 50 { 51 cupmStream_t stream; 52 cupmEvent_t event; 53 cupmEvent_t begin; // timer-only 54 cupmEvent_t end; // timer-only 55 #if PetscDefined(USE_DEBUG) 56 PetscBool timerInUse; 57 #endif 58 cupmBlasHandle_t blas; 59 cupmSolverHandle_t solver; 60 61 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } 62 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } 63 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } 64 }; 65 66 private: 67 static bool initialized_; 68 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_; 69 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_; 70 71 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr)) 72 { 73 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 74 } 75 76 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle)) 77 { 78 PetscFunctionBegin; 79 if (handle) PetscFunctionReturn(0); 80 for (auto i = 0; i < 3; ++i) { 81 auto cberr = cupmBlasCreate(&handle); 82 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 83 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) CHKERRCUPMBLAS(cberr); 84 if (i != 2) { 85 auto ierr = PetscSleep(3);CHKERRQ(ierr); 86 continue; 87 } 88 PetscAssertFalse(cberr != CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName()); 89 } 90 PetscFunctionReturn(0); 91 } 92 93 PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream)) 94 { 95 cupmStream_t cupmStream; 96 cupmBlasError_t cberr; 97 98 PetscFunctionBegin; 99 cberr = cupmBlasGetStream(handle,&cupmStream);CHKERRCUPMBLAS(cberr); 100 if (cupmStream != stream) {cberr = cupmBlasSetStream(handle,stream);CHKERRCUPMBLAS(cberr);} 101 PetscFunctionReturn(0); 102 } 103 104 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) 105 { 106 PetscFunctionBegin; 107 for (auto&& handle : blashandles_) { 108 if (handle) { 109 auto cberr = cupmBlasDestroy(handle);CHKERRCUPMBLAS(cberr); 110 handle = nullptr; 111 } 112 } 113 for (auto&& handle : solverhandles_) { 114 if (handle) { 115 auto ierr = cupmBlasInterface_t::DestroyHandle(handle);CHKERRQ(ierr); 116 handle = nullptr; 117 } 118 } 119 initialized_ = false; 120 PetscFunctionReturn(0); 121 } 122 123 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci)) 124 { 125 PetscErrorCode ierr; 126 127 PetscFunctionBegin; 128 ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr); 129 if (!initialized_) { 130 initialized_ = true; 131 ierr = PetscRegisterFinalize(finalize_);CHKERRQ(ierr); 132 } 133 // use the blashandle as a canary 134 if (!blashandles_[id]) { 135 ierr = initialize_handle_(blashandles_[id]);CHKERRQ(ierr); 136 ierr = cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);CHKERRQ(ierr); 137 } 138 ierr = set_handle_stream_(blashandles_[id],dci->stream);CHKERRQ(ierr); 139 ierr = cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);CHKERRQ(ierr); 140 dci->blas = blashandles_[id]; 141 dci->solver = solverhandles_[id]; 142 PetscFunctionReturn(0); 143 } 144 145 public: 146 const struct _DeviceContextOps ops = { 147 destroy, 148 changeStreamType, 149 setUp, 150 query, 151 waitForContext, 152 synchronize, 153 getHandle<blas_tag>, 154 getHandle<solver_tag>, 155 getHandle<stream_tag>, 156 beginTimer, 157 endTimer, 158 }; 159 160 // All of these functions MUST be static in order to be callable from C, otherwise they 161 // get the implicit 'this' pointer tacked on 162 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 163 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType)); 164 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 165 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*)); 166 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext)); 167 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 168 template <typename Handle_t> 169 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*)); 170 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 171 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*)); 172 }; 173 174 template <DeviceType T> 175 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 176 { 177 cupmError_t cerr; 178 PetscErrorCode ierr; 179 auto dci = impls_cast_(dctx); 180 181 PetscFunctionBegin; 182 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 183 if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr); } 184 if (dci->begin) {cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr); } 185 if (dci->end) {cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr); } 186 ierr = PetscFree(dctx->data);CHKERRQ(ierr); 187 PetscFunctionReturn(0); 188 } 189 190 template <DeviceType T> 191 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 192 { 193 auto dci = impls_cast_(dctx); 194 195 PetscFunctionBegin; 196 if (dci->stream) { 197 auto cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 198 dci->stream = nullptr; 199 } 200 // set these to null so they aren't usable until setup is called again 201 dci->blas = nullptr; 202 dci->solver = nullptr; 203 PetscFunctionReturn(0); 204 } 205 206 template <DeviceType T> 207 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 208 { 209 PetscErrorCode ierr; 210 cupmError_t cerr; 211 auto dci = impls_cast_(dctx); 212 213 PetscFunctionBegin; 214 if (dci->stream) { 215 cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 216 dci->stream = nullptr; 217 } 218 switch (dctx->streamType) { 219 case PETSC_STREAM_GLOBAL_BLOCKING: 220 // don't create a stream for global blocking 221 break; 222 case PETSC_STREAM_DEFAULT_BLOCKING: 223 cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); 224 break; 225 case PETSC_STREAM_GLOBAL_NONBLOCKING: 226 cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); 227 break; 228 default: 229 SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]); 230 break; 231 } 232 if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);} 233 #if PetscDefined(USE_DEBUG) 234 dci->timerInUse = PETSC_FALSE; 235 #endif 236 ierr = initialize_(dctx->device->deviceId,dci);CHKERRQ(ierr); 237 PetscFunctionReturn(0); 238 } 239 240 template <DeviceType T> 241 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 242 { 243 cupmError_t cerr; 244 245 PetscFunctionBegin; 246 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 247 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 248 else { 249 // somethings gone wrong 250 if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr); 251 *idle = PETSC_FALSE; 252 } 253 PetscFunctionReturn(0); 254 } 255 256 template <DeviceType T> 257 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 258 { 259 cupmError_t cerr; 260 auto dcib = impls_cast_(dctxb); 261 262 PetscFunctionBegin; 263 cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); 264 cerr = cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);CHKERRCUPM(cerr); 265 PetscFunctionReturn(0); 266 } 267 268 template <DeviceType T> 269 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 270 { 271 cupmError_t cerr; 272 auto dci = impls_cast_(dctx); 273 274 PetscFunctionBegin; 275 // in case anything was queued on the event 276 cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); 277 cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); 278 PetscFunctionReturn(0); 279 } 280 281 template <DeviceType T> 282 template <typename handle_t> 283 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 284 { 285 PetscFunctionBegin; 286 *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t()); 287 PetscFunctionReturn(0); 288 } 289 290 template <DeviceType T> 291 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 292 { 293 auto dci = impls_cast_(dctx); 294 cupmError_t cerr; 295 296 PetscFunctionBegin; 297 #if PetscDefined(USE_DEBUG) 298 PetscAssertFalse(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 299 dci->timerInUse = PETSC_TRUE; 300 #endif 301 if (!dci->begin) { 302 cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr); 303 cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr); 304 } 305 cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr); 306 PetscFunctionReturn(0); 307 } 308 309 template <DeviceType T> 310 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 311 { 312 cupmError_t cerr; 313 float gtime; 314 auto dci = impls_cast_(dctx); 315 316 PetscFunctionBegin; 317 #if PetscDefined(USE_DEBUG) 318 PetscAssertFalse(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 319 dci->timerInUse = PETSC_FALSE; 320 #endif 321 cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr); 322 cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr); 323 cerr = cupmEventElapsedTime(>ime,dci->begin,dci->end);CHKERRCUPM(cerr); 324 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 325 PetscFunctionReturn(0); 326 } 327 328 // initialize the static member variables 329 template <DeviceType T> bool DeviceContext<T>::initialized_ = false; 330 331 template <DeviceType T> 332 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 333 334 template <DeviceType T> 335 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 336 337 } // namespace Impl 338 339 // shorten this one up a bit (and instantiate the templates) 340 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 341 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 342 343 // shorthand for what is an EXTREMELY long name 344 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 345 346 } // namespace CUPM 347 348 } // namespace Device 349 350 } // namespace Petsc 351 352 #endif // PETSCDEVICECONTEXTCUDA_HPP 353