1 #ifndef PETSCDEVICECONTEXTCUPM_HPP 2 #define PETSCDEVICECONTEXTCUPM_HPP 3 4 #include <petsc/private/deviceimpl.h> 5 #include <petsc/private/cupmblasinterface.hpp> 6 7 #include <array> 8 9 namespace Petsc 10 { 11 12 namespace Device 13 { 14 15 namespace CUPM 16 { 17 18 namespace Impl 19 { 20 21 namespace detail 22 { 23 24 // for tag-based dispatch of handle retrieval 25 template <typename T> struct HandleTag { using type = T; }; 26 27 } // namespace detail 28 29 // Forward declare 30 template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext; 31 32 template <DeviceType T> 33 class DeviceContext : Impl::BlasInterface<T> 34 { 35 public: 36 PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T); 37 38 private: 39 template <typename H> using HandleTag = typename detail::HandleTag<H>; 40 using stream_tag = HandleTag<cupmStream_t>; 41 using blas_tag = HandleTag<cupmBlasHandle_t>; 42 using solver_tag = HandleTag<cupmSolverHandle_t>; 43 44 public: 45 // This is the canonical PETSc "impls" struct that normally resides in a standalone impls 46 // header, but since we are using the power of templates it must be declared part of 47 // this class to have easy access the same typedefs. Technically one can make a 48 // templated struct outside the class but it's more code for the same result. 49 struct PetscDeviceContext_IMPLS 50 { 51 cupmStream_t stream; 52 cupmEvent_t event; 53 cupmEvent_t begin; // timer-only 54 cupmEvent_t end; // timer-only 55 #if PetscDefined(USE_DEBUG) 56 PetscBool timerInUse; 57 #endif 58 cupmBlasHandle_t blas; 59 cupmSolverHandle_t solver; 60 61 PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; } 62 PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; } 63 PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; } 64 }; 65 66 private: 67 static bool initialized_; 68 static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_; 69 static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_; 70 71 PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr)) 72 { 73 return static_cast<PetscDeviceContext_IMPLS*>(ptr->data); 74 } 75 76 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle)) 77 { 78 PetscFunctionBegin; 79 if (handle) PetscFunctionReturn(0); 80 for (auto i = 0; i < 3; ++i) { 81 auto cberr = cupmBlasCreate(&handle); 82 if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break; 83 if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr); 84 if (i != 2) { 85 PetscCall(PetscSleep(3)); 86 continue; 87 } 88 PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS,PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"Unable to initialize %s",cupmBlasName()); 89 } 90 PetscFunctionReturn(0); 91 } 92 93 PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream)) 94 { 95 cupmStream_t cupmStream; 96 97 PetscFunctionBegin; 98 PetscCallCUPMBLAS(cupmBlasGetStream(handle,&cupmStream)); 99 if (cupmStream != stream) PetscCallCUPMBLAS(cupmBlasSetStream(handle,stream)); 100 PetscFunctionReturn(0); 101 } 102 103 PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_()) 104 { 105 PetscFunctionBegin; 106 for (auto&& handle : blashandles_) { 107 if (handle) { 108 PetscCallCUPMBLAS(cupmBlasDestroy(handle)); 109 handle = nullptr; 110 } 111 } 112 for (auto&& handle : solverhandles_) { 113 if (handle) { 114 PetscCall(cupmBlasInterface_t::DestroyHandle(handle)); 115 handle = nullptr; 116 } 117 } 118 initialized_ = false; 119 PetscFunctionReturn(0); 120 } 121 122 PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci)) 123 { 124 PetscFunctionBegin; 125 PetscCall(PetscDeviceCheckDeviceCount_Internal(id)); 126 if (!initialized_) { 127 initialized_ = true; 128 PetscCall(PetscRegisterFinalize(finalize_)); 129 } 130 // use the blashandle as a canary 131 if (!blashandles_[id]) { 132 PetscCall(initialize_handle_(blashandles_[id])); 133 PetscCall(cupmBlasInterface_t::InitializeHandle(solverhandles_[id])); 134 } 135 PetscCall(set_handle_stream_(blashandles_[id],dci->stream)); 136 PetscCall(cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream)); 137 dci->blas = blashandles_[id]; 138 dci->solver = solverhandles_[id]; 139 PetscFunctionReturn(0); 140 } 141 142 public: 143 const struct _DeviceContextOps ops = { 144 destroy, 145 changeStreamType, 146 setUp, 147 query, 148 waitForContext, 149 synchronize, 150 getHandle<blas_tag>, 151 getHandle<solver_tag>, 152 getHandle<stream_tag>, 153 beginTimer, 154 endTimer, 155 }; 156 157 // All of these functions MUST be static in order to be callable from C, otherwise they 158 // get the implicit 'this' pointer tacked on 159 PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext)); 160 PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType)); 161 PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext)); 162 PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*)); 163 PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext)); 164 PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext)); 165 template <typename Handle_t> 166 PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*)); 167 PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext)); 168 PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*)); 169 }; 170 171 template <DeviceType T> 172 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx)) 173 { 174 auto dci = impls_cast_(dctx); 175 176 PetscFunctionBegin; 177 if (dci->stream) PetscCallCUPM(cupmStreamDestroy(dci->stream)); 178 if (dci->event) PetscCallCUPM(cupmEventDestroy(dci->event)); 179 if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin)); 180 if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end)); 181 PetscCall(PetscFree(dctx->data)); 182 PetscFunctionReturn(0); 183 } 184 185 template <DeviceType T> 186 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype)) 187 { 188 auto dci = impls_cast_(dctx); 189 190 PetscFunctionBegin; 191 if (dci->stream) { 192 PetscCallCUPM(cupmStreamDestroy(dci->stream)); 193 dci->stream = nullptr; 194 } 195 // set these to null so they aren't usable until setup is called again 196 dci->blas = nullptr; 197 dci->solver = nullptr; 198 PetscFunctionReturn(0); 199 } 200 201 template <DeviceType T> 202 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx)) 203 { 204 auto dci = impls_cast_(dctx); 205 206 PetscFunctionBegin; 207 if (dci->stream) { 208 PetscCallCUPM(cupmStreamDestroy(dci->stream)); 209 dci->stream = nullptr; 210 } 211 switch (dctx->streamType) { 212 case PETSC_STREAM_GLOBAL_BLOCKING: 213 // don't create a stream for global blocking 214 break; 215 case PETSC_STREAM_DEFAULT_BLOCKING: 216 PetscCallCUPM(cupmStreamCreate(&dci->stream)); 217 break; 218 case PETSC_STREAM_GLOBAL_NONBLOCKING: 219 PetscCallCUPM(cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking)); 220 break; 221 default: 222 SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]); 223 break; 224 } 225 if (!dci->event) PetscCallCUPM(cupmEventCreate(&dci->event)); 226 #if PetscDefined(USE_DEBUG) 227 dci->timerInUse = PETSC_FALSE; 228 #endif 229 PetscCall(initialize_(dctx->device->deviceId,dci)); 230 PetscFunctionReturn(0); 231 } 232 233 template <DeviceType T> 234 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle)) 235 { 236 cupmError_t cerr; 237 238 PetscFunctionBegin; 239 cerr = cupmStreamQuery(impls_cast_(dctx)->stream); 240 if (cerr == cupmSuccess) *idle = PETSC_TRUE; 241 else { 242 // somethings gone wrong 243 if (PetscUnlikely(cerr != cupmErrorNotReady)) PetscCallCUPM(cerr); 244 *idle = PETSC_FALSE; 245 } 246 PetscFunctionReturn(0); 247 } 248 249 template <DeviceType T> 250 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb)) 251 { 252 auto dcib = impls_cast_(dctxb); 253 254 PetscFunctionBegin; 255 PetscCallCUPM(cupmEventRecord(dcib->event,dcib->stream)); 256 PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0)); 257 PetscFunctionReturn(0); 258 } 259 260 template <DeviceType T> 261 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx)) 262 { 263 auto dci = impls_cast_(dctx); 264 265 PetscFunctionBegin; 266 // in case anything was queued on the event 267 PetscCallCUPM(cupmStreamWaitEvent(dci->stream,dci->event,0)); 268 PetscCallCUPM(cupmStreamSynchronize(dci->stream)); 269 PetscFunctionReturn(0); 270 } 271 272 template <DeviceType T> 273 template <typename handle_t> 274 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle)) 275 { 276 PetscFunctionBegin; 277 *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t()); 278 PetscFunctionReturn(0); 279 } 280 281 template <DeviceType T> 282 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx)) 283 { 284 auto dci = impls_cast_(dctx); 285 286 PetscFunctionBegin; 287 #if PetscDefined(USE_DEBUG) 288 PetscCheck(!dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeEnd()?"); 289 dci->timerInUse = PETSC_TRUE; 290 #endif 291 if (!dci->begin) { 292 PetscCallCUPM(cupmEventCreate(&dci->begin)); 293 PetscCallCUPM(cupmEventCreate(&dci->end)); 294 } 295 PetscCallCUPM(cupmEventRecord(dci->begin,dci->stream)); 296 PetscFunctionReturn(0); 297 } 298 299 template <DeviceType T> 300 PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed)) 301 { 302 float gtime; 303 auto dci = impls_cast_(dctx); 304 305 PetscFunctionBegin; 306 #if PetscDefined(USE_DEBUG) 307 PetscCheck(dci->timerInUse,PETSC_COMM_SELF,PETSC_ERR_PLIB,"Forgot to call PetscLogGpuTimeBegin()?"); 308 dci->timerInUse = PETSC_FALSE; 309 #endif 310 PetscCallCUPM(cupmEventRecord(dci->end,dci->stream)); 311 PetscCallCUPM(cupmEventSynchronize(dci->end)); 312 PetscCallCUPM(cupmEventElapsedTime(>ime,dci->begin,dci->end)); 313 *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime); 314 PetscFunctionReturn(0); 315 } 316 317 // initialize the static member variables 318 template <DeviceType T> bool DeviceContext<T>::initialized_ = false; 319 320 template <DeviceType T> 321 std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {}; 322 323 template <DeviceType T> 324 std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {}; 325 326 } // namespace Impl 327 328 // shorten this one up a bit (and instantiate the templates) 329 using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>; 330 using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>; 331 332 // shorthand for what is an EXTREMELY long name 333 #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS 334 335 } // namespace CUPM 336 337 } // namespace Device 338 339 } // namespace Petsc 340 341 #endif // PETSCDEVICECONTEXTCUDA_HPP 342