1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 7 #include <array> 8 9 namespace Petsc 10 { 11 12 namespace Device 13 { 14 15 namespace CUPM 16 { 17 18 namespace Impl 19 { 20 21 namespace detail 22 { 23 24 // for tag-based dispatch of handle retrieval 25 template <typename T> struct HandleTag { using type = T; }; 26 27 } // namespace detail 28 29 // Forward declare 30 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext; 31 32 template <DeviceType T> 33 class DeviceContext : Impl::BlasInterface<T> 34 { 35 public: 36 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T); 37 38 private: 39 template <typename H> using HandleTag = typename detail::HandleTag<H>; 40 using stream_tag = HandleTag<cupmStream_t>; 41 using blas_tag = HandleTag<cupmBlasHandle_t>; 42 using solver_tag = HandleTag<cupmSolverHandle_t>; 43 44 public: 45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46 // header, but since we are using the power of templates it must be declared part of 47 // this class to have easy access the same typedefs. Technically one can make a 48 // templated struct outside the class but it's more code for the same result. 49 struct PetscDeviceContext_IMPLS 50 { 51 cupmStream_t stream; 52 cupmEvent_t event; 53 cupmEvent_t begin; // timer-only 54 cupmEvent_t end; // timer-only 55 #if PetscDefined(USE_DEBUG) 56 PetscBool timerInUse; 57 #endif 58 cupmBlasHandle_t blas; 59 cupmSolverHandle_t solver; 60 61 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } 62 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } 63 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } 64 }; 65 66 private: 67 static bool initialized_; 68 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_; 69 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_; 70 71 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr)) 72 { 73 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 74 } 75 76 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle)) 77 { 78 PetscFunctionBegin; 79 if (handle) PetscFunctionReturn(0); 80 for (auto i = 0; i < 3; ++i) { 81 auto cberr = cupmBlasCreate(&handle); 82 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 83 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) CHKERRCUPMBLAS(cberr); 84 if (i != 2) { 85 CHKERRQ(PetscSleep(3)); 86 continue; 87 } 88 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName()); 89 } 90 PetscFunctionReturn(0); 91 } 92 93 PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream)) 94 { 95 cupmStream_t cupmStream; 96 97 PetscFunctionBegin; 98 CHKERRCUPMBLAS(cupmBlasGetStream(handle,&cupmStream)); 99 if (cupmStream != stream) CHKERRCUPMBLAS(cupmBlasSetStream(handle,stream)); 100 PetscFunctionReturn(0); 101 } 102 103 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) 104 { 105 PetscFunctionBegin; 106 for (auto&& handle : blashandles_) { 107 if (handle) { 108 CHKERRCUPMBLAS(cupmBlasDestroy(handle)); 109 handle = nullptr; 110 } 111 } 112 for (auto&& handle : solverhandles_) { 113 if (handle) { 114 CHKERRQ(cupmBlasInterface_t::DestroyHandle(handle)); 115 handle = nullptr; 116 } 117 } 118 initialized_ = false; 119 PetscFunctionReturn(0); 120 } 121 122 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci)) 123 { 124 125 PetscFunctionBegin; 126 CHKERRQ(PetscDeviceCheckDeviceCount_Internal(id)); 127 if (!initialized_) { 128 initialized_ = true; 129 CHKERRQ(PetscRegisterFinalize(finalize_)); 130 } 131 // use the blashandle as a canary 132 if (!blashandles_[id]) { 133 CHKERRQ(initialize_handle_(blashandles_[id])); 134 CHKERRQ(cupmBlasInterface_t::InitializeHandle(solverhandles_[id])); 135 } 136 CHKERRQ(set_handle_stream_(blashandles_[id],dci->stream)); 137 CHKERRQ(cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream)); 138 dci->blas = blashandles_[id]; 139 dci->solver = solverhandles_[id]; 140 PetscFunctionReturn(0); 141 } 142 143 public: 144 const struct _DeviceContextOps ops = { 145 destroy, 146 changeStreamType, 147 setUp, 148 query, 149 waitForContext, 150 synchronize, 151 getHandle<blas_tag>, 152 getHandle<solver_tag>, 153 getHandle<stream_tag>, 154 beginTimer, 155 endTimer, 156 }; 157 158 // All of these functions MUST be static in order to be callable from C, otherwise they 159 // get the implicit 'this' pointer tacked on 160 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 161 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType)); 162 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 163 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*)); 164 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext)); 165 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 166 template <typename Handle_t> 167 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*)); 168 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 169 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*)); 170 }; 171 172 template <DeviceType T> 173 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 174 { 175 auto dci = impls_cast_(dctx); 176 177 PetscFunctionBegin; 178 if (dci->stream) CHKERRCUPM(cupmStreamDestroy(dci->stream)); 179 if (dci->event) CHKERRCUPM(cupmEventDestroy(dci->event)); 180 if (dci->begin) CHKERRCUPM(cupmEventDestroy(dci->begin)); 181 if (dci->end) CHKERRCUPM(cupmEventDestroy(dci->end)); 182 CHKERRQ(PetscFree(dctx->data)); 183 PetscFunctionReturn(0); 184 } 185 186 template <DeviceType T> 187 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 188 { 189 auto dci = impls_cast_(dctx); 190 191 PetscFunctionBegin; 192 if (dci->stream) { 193 CHKERRCUPM(cupmStreamDestroy(dci->stream)); 194 dci->stream = nullptr; 195 } 196 // set these to null so they aren't usable until setup is called again 197 dci->blas = nullptr; 198 dci->solver = nullptr; 199 PetscFunctionReturn(0); 200 } 201 202 template <DeviceType T> 203 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 204 { 205 auto dci = impls_cast_(dctx); 206 207 PetscFunctionBegin; 208 if (dci->stream) { 209 CHKERRCUPM(cupmStreamDestroy(dci->stream)); 210 dci->stream = nullptr; 211 } 212 switch (dctx->streamType) { 213 case PETSC_STREAM_GLOBAL_BLOCKING: 214 // don't create a stream for global blocking 215 break; 216 case PETSC_STREAM_DEFAULT_BLOCKING: 217 CHKERRCUPM(cupmStreamCreate(&dci->stream)); 218 break; 219 case PETSC_STREAM_GLOBAL_NONBLOCKING: 220 CHKERRCUPM(cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking)); 221 break; 222 default: 223 SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]); 224 break; 225 } 226 if (!dci->event) CHKERRCUPM(cupmEventCreate(&dci->event)); 227 #if PetscDefined(USE_DEBUG) 228 dci->timerInUse = PETSC_FALSE; 229 #endif 230 CHKERRQ(initialize_(dctx->device->deviceId,dci)); 231 PetscFunctionReturn(0); 232 } 233 234 template <DeviceType T> 235 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 236 { 237 cupmError_t cerr; 238 239 PetscFunctionBegin; 240 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 241 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 242 else { 243 // somethings gone wrong 244 if (PetscUnlikely(cerr != cupmErrorNotReady)) CHKERRCUPM(cerr); 245 *idle = PETSC_FALSE; 246 } 247 PetscFunctionReturn(0); 248 } 249 250 template <DeviceType T> 251 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 252 { 253 auto dcib = impls_cast_(dctxb); 254 255 PetscFunctionBegin; 256 CHKERRCUPM(cupmEventRecord(dcib->event,dcib->stream)); 257 CHKERRCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0)); 258 PetscFunctionReturn(0); 259 } 260 261 template <DeviceType T> 262 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 263 { 264 auto dci = impls_cast_(dctx); 265 266 PetscFunctionBegin; 267 // in case anything was queued on the event 268 CHKERRCUPM(cupmStreamWaitEvent(dci->stream,dci->event,0)); 269 CHKERRCUPM(cupmStreamSynchronize(dci->stream)); 270 PetscFunctionReturn(0); 271 } 272 273 template <DeviceType T> 274 template <typename handle_t> 275 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 276 { 277 PetscFunctionBegin; 278 *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t()); 279 PetscFunctionReturn(0); 280 } 281 282 template <DeviceType T> 283 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 284 { 285 auto dci = impls_cast_(dctx); 286 287 PetscFunctionBegin; 288 #if PetscDefined(USE_DEBUG) 289 PetscCheck(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 290 dci->timerInUse = PETSC_TRUE; 291 #endif 292 if (!dci->begin) { 293 CHKERRCUPM(cupmEventCreate(&dci->begin)); 294 CHKERRCUPM(cupmEventCreate(&dci->end)); 295 } 296 CHKERRCUPM(cupmEventRecord(dci->begin,dci->stream)); 297 PetscFunctionReturn(0); 298 } 299 300 template <DeviceType T> 301 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 302 { 303 float gtime; 304 auto dci = impls_cast_(dctx); 305 306 PetscFunctionBegin; 307 #if PetscDefined(USE_DEBUG) 308 PetscCheck(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 309 dci->timerInUse = PETSC_FALSE; 310 #endif 311 CHKERRCUPM(cupmEventRecord(dci->end,dci->stream)); 312 CHKERRCUPM(cupmEventSynchronize(dci->end)); 313 CHKERRCUPM(cupmEventElapsedTime(>ime,dci->begin,dci->end)); 314 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 315 PetscFunctionReturn(0); 316 } 317 318 // initialize the static member variables 319 template <DeviceType T> bool DeviceContext<T>::initialized_ = false; 320 321 template <DeviceType T> 322 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 323 324 template <DeviceType T> 325 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 326 327 } // namespace Impl 328 329 // shorten this one up a bit (and instantiate the templates) 330 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 331 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 332 333 // shorthand for what is an EXTREMELY long name 334 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 335 336 } // namespace CUPM 337 338 } // namespace Device 339 340 } // namespace Petsc 341 342 #endif // PETSCDEVICECONTEXTCUDA_HPP 343