1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 7 #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11) 8 # error PetscDeviceContext backends for CUDA and HIP requires C++11 9 #endif 10 11 #include <array> 12 13 namespace Petsc 14 { 15 16 namespace Device 17 { 18 19 namespace CUPM 20 { 21 22 namespace Impl 23 { 24 25 namespace detail 26 { 27 28 // for tag-based dispatch of handle retrieval 29 template <typename T> struct HandleTag { using type = T; }; 30 31 } // namespace detail 32 33 // Forward declare 34 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext; 35 36 template <DeviceType T> 37 class DeviceContext : Impl::BlasInterface<T> 38 { 39 public: 40 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T); 41 42 private: 43 template <typename H> using HandleTag = typename detail::HandleTag<H>; 44 using stream_tag = HandleTag<cupmStream_t>; 45 using blas_tag = HandleTag<cupmBlasHandle_t>; 46 using solver_tag = HandleTag<cupmSolverHandle_t>; 47 48 public: 49 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 50 // header, but since we are using the power of templates it must be declared part of 51 // this class to have easy access the same typedefs. Technically one can make a 52 // templated struct outside the class but it's more code for the same result. 53 struct PetscDeviceContext_IMPLS 54 { 55 cupmStream_t stream; 56 cupmEvent_t event; 57 cupmEvent_t begin; // timer-only 58 cupmEvent_t end; // timer-only 59 #if PetscDefined(USE_DEBUG) 60 PetscBool timerInUse; 61 #endif 62 cupmBlasHandle_t blas; 63 cupmSolverHandle_t solver; 64 65 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } 66 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } 67 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } 68 }; 69 70 private: 71 static bool initialized_; 72 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_; 73 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_; 74 75 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr)) 76 { 77 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 78 } 79 80 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle)) 81 { 82 PetscFunctionBegin; 83 if (handle) PetscFunctionReturn(0); 84 for (auto i = 0; i < 3; ++i) { 85 auto cberr = cupmBlasCreate(&handle); 86 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 87 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) CHKERRCUPMBLAS(cberr); 88 if (i != 2) { 89 auto ierr = PetscSleep(3);CHKERRQ(ierr); 90 continue; 91 } 92 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_SUCCESS)) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName()); 93 } 94 PetscFunctionReturn(0); 95 } 96 97 PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream)) 98 { 99 cupmStream_t cupmStream; 100 cupmBlasError_t cberr; 101 102 PetscFunctionBegin; 103 cberr = cupmBlasGetStream(handle,&cupmStream);CHKERRCUPMBLAS(cberr); 104 if (cupmStream != stream) {cberr = cupmBlasSetStream(handle,stream);CHKERRCUPMBLAS(cberr);} 105 PetscFunctionReturn(0); 106 } 107 108 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) 109 { 110 PetscFunctionBegin; 111 for (auto&& handle : blashandles_) { 112 if (handle) { 113 auto cberr = cupmBlasDestroy(handle);CHKERRCUPMBLAS(cberr); 114 handle = nullptr; 115 } 116 } 117 for (auto&& handle : solverhandles_) { 118 if (handle) { 119 auto ierr = cupmBlasInterface_t::DestroyHandle(handle);CHKERRQ(ierr); 120 handle = nullptr; 121 } 122 } 123 initialized_ = false; 124 PetscFunctionReturn(0); 125 } 126 127 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci)) 128 { 129 PetscErrorCode ierr; 130 131 PetscFunctionBegin; 132 ierr = PetscDeviceCheckDeviceCount_Internal(id);CHKERRQ(ierr); 133 if (!initialized_) { 134 initialized_ = true; 135 ierr = PetscRegisterFinalize(finalize_);CHKERRQ(ierr); 136 } 137 // use the blashandle as a canary 138 if (!blashandles_[id]) { 139 ierr = initialize_handle_(blashandles_[id]);CHKERRQ(ierr); 140 ierr = cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);CHKERRQ(ierr); 141 } 142 ierr = set_handle_stream_(blashandles_[id],dci->stream);CHKERRQ(ierr); 143 ierr = cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);CHKERRQ(ierr); 144 dci->blas = blashandles_[id]; 145 dci->solver = solverhandles_[id]; 146 PetscFunctionReturn(0); 147 } 148 149 public: 150 const struct _DeviceContextOps ops = { 151 destroy, 152 changeStreamType, 153 setUp, 154 query, 155 waitForContext, 156 synchronize, 157 getHandle<blas_tag>, 158 getHandle<solver_tag>, 159 getHandle<stream_tag>, 160 beginTimer, 161 endTimer, 162 }; 163 164 // All of these functions MUST be static in order to be callable from C, otherwise they 165 // get the implicit 'this' pointer tacked on 166 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 167 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType)); 168 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 169 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*)); 170 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext)); 171 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 172 template <typename Handle_t> 173 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*)); 174 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 175 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*)); 176 }; 177 178 template <DeviceType T> 179 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 180 { 181 cupmError_t cerr; 182 PetscErrorCode ierr; 183 auto dci = impls_cast_(dctx); 184 185 PetscFunctionBegin; 186 if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);} 187 if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr); } 188 if (dci->begin) {cerr = cupmEventDestroy(dci->begin);CHKERRCUPM(cerr); } 189 if (dci->end) {cerr = cupmEventDestroy(dci->end);CHKERRCUPM(cerr); } 190 ierr = PetscFree(dctx->data);CHKERRQ(ierr); 191 PetscFunctionReturn(0); 192 } 193 194 template <DeviceType T> 195 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 196 { 197 auto dci = impls_cast_(dctx); 198 199 PetscFunctionBegin; 200 if (dci->stream) { 201 auto cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 202 dci->stream = nullptr; 203 } 204 // set these to null so they aren't usable until setup is called again 205 dci->blas = nullptr; 206 dci->solver = nullptr; 207 PetscFunctionReturn(0); 208 } 209 210 template <DeviceType T> 211 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 212 { 213 PetscErrorCode ierr; 214 cupmError_t cerr; 215 auto dci = impls_cast_(dctx); 216 217 PetscFunctionBegin; 218 if (dci->stream) { 219 cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr); 220 dci->stream = nullptr; 221 } 222 switch (dctx->streamType) { 223 case PETSC_STREAM_GLOBAL_BLOCKING: 224 // don't create a stream for global blocking 225 break; 226 case PETSC_STREAM_DEFAULT_BLOCKING: 227 cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr); 228 break; 229 case PETSC_STREAM_GLOBAL_NONBLOCKING: 230 cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr); 231 break; 232 default: 233 SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]); 234 break; 235 } 236 if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);} 237 #if PetscDefined(USE_DEBUG) 238 dci->timerInUse = PETSC_FALSE; 239 #endif 240 ierr = initialize_(dctx->device->deviceId,dci);CHKERRQ(ierr); 241 PetscFunctionReturn(0); 242 } 243 244 template <DeviceType T> 245 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 246 { 247 cupmError_t cerr; 248 249 PetscFunctionBegin; 250 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 251 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 252 else { 253 // somethings gone wrong 254 if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr); 255 *idle = PETSC_FALSE; 256 } 257 PetscFunctionReturn(0); 258 } 259 260 template <DeviceType T> 261 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 262 { 263 cupmError_t cerr; 264 auto dcib = impls_cast_(dctxb); 265 266 PetscFunctionBegin; 267 cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr); 268 cerr = cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);CHKERRCUPM(cerr); 269 PetscFunctionReturn(0); 270 } 271 272 template <DeviceType T> 273 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 274 { 275 cupmError_t cerr; 276 auto dci = impls_cast_(dctx); 277 278 PetscFunctionBegin; 279 // in case anything was queued on the event 280 cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr); 281 cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr); 282 PetscFunctionReturn(0); 283 } 284 285 template <DeviceType T> 286 template <typename handle_t> 287 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 288 { 289 PetscFunctionBegin; 290 *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t()); 291 PetscFunctionReturn(0); 292 } 293 294 template <DeviceType T> 295 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 296 { 297 auto dci = impls_cast_(dctx); 298 cupmError_t cerr; 299 300 PetscFunctionBegin; 301 #if PetscDefined(USE_DEBUG) 302 if (PetscUnlikely(dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 303 dci->timerInUse = PETSC_TRUE; 304 #endif 305 if (!dci->begin) { 306 cerr = cupmEventCreate(&dci->begin);CHKERRCUPM(cerr); 307 cerr = cupmEventCreate(&dci->end);CHKERRCUPM(cerr); 308 } 309 cerr = cupmEventRecord(dci->begin,dci->stream);CHKERRCUPM(cerr); 310 PetscFunctionReturn(0); 311 } 312 313 template <DeviceType T> 314 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 315 { 316 cupmError_t cerr; 317 float gtime; 318 auto dci = impls_cast_(dctx); 319 320 PetscFunctionBegin; 321 #if PetscDefined(USE_DEBUG) 322 if (PetscUnlikely(!dci->timerInUse)) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 323 dci->timerInUse = PETSC_FALSE; 324 #endif 325 cerr = cupmEventRecord(dci->end,dci->stream);CHKERRCUPM(cerr); 326 cerr = cupmEventSynchronize(dci->end);CHKERRCUPM(cerr); 327 cerr = cupmEventElapsedTime(>ime,dci->begin,dci->end);CHKERRCUPM(cerr); 328 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 329 PetscFunctionReturn(0); 330 } 331 332 // initialize the static member variables 333 template <DeviceType T> bool DeviceContext<T>::initialized_ = false; 334 335 template <DeviceType T> 336 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 337 338 template <DeviceType T> 339 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 340 341 } // namespace Impl 342 343 // shorten this one up a bit (and instantiate the templates) 344 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 345 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 346 347 // shorthand for what is an EXTREMELY long name 348 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 349 350 } // namespace CUPM 351 352 } // namespace Device 353 354 } // namespace Petsc 355 356 #endif // PETSCDEVICECONTEXTCUDA_HPP 357