1 #include "petscdevice_interface_internal.hpp" /*I <petscdevice.h> I*/ 2 #include <petscdevice_cupm.h> 3 4 static auto rootDeviceType = PETSC_DEVICE_CONTEXT_DEFAULT_DEVICE_TYPE; 5 static auto rootStreamType = PETSC_DEVICE_CONTEXT_DEFAULT_STREAM_TYPE; 6 static PetscDeviceContext globalContext = nullptr; 7 8 /* when PetscDevice initializes PetscDeviceContext eagerly the type of device created should 9 * match whatever device is eagerly initialized */ 10 PetscErrorCode PetscDeviceContextSetRootDeviceType_Internal(PetscDeviceType type) 11 { 12 PetscFunctionBegin; 13 PetscValidDeviceType(type, 1); 14 rootDeviceType = type; 15 PetscFunctionReturn(PETSC_SUCCESS); 16 } 17 18 PetscErrorCode PetscDeviceContextSetRootStreamType_Internal(PetscStreamType type) 19 { 20 PetscFunctionBegin; 21 PetscValidStreamType(type, 1); 22 rootStreamType = type; 23 PetscFunctionReturn(PETSC_SUCCESS); 24 } 25 26 static inline PetscErrorCode PetscSetDefaultCUPMStreamFromDeviceContext(PetscDeviceContext dctx, PetscDeviceType dtype) 27 { 28 PetscFunctionBegin; 29 #if PetscDefined(HAVE_CUDA) 30 if (dtype == PETSC_DEVICE_CUDA) { 31 void *handle; 32 33 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &handle)); 34 PetscDefaultCudaStream = *static_cast<cudaStream_t *>(handle); 35 } 36 #endif 37 #if PetscDefined(HAVE_HIP) 38 if (dtype == PETSC_DEVICE_HIP) { 39 void *handle; 40 41 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &handle)); 42 PetscDefaultHipStream = *static_cast<hipStream_t *>(handle); 43 } 44 #endif 45 #if !PetscDefined(HAVE_CUDA) && !PetscDefined(HAVE_HIP) 46 (void)dctx, (void)dtype; 47 #endif 48 PetscFunctionReturn(PETSC_SUCCESS); 49 } 50 51 static PetscErrorCode PetscDeviceContextSetupGlobalContext_Private() noexcept 52 { 53 PetscFunctionBegin; 54 if (PetscUnlikely(!globalContext)) { 55 PetscObject pobj; 56 const auto dtype = rootDeviceType; 57 const auto finalizer = [] { 58 PetscDeviceType dtype; 59 60 PetscFunctionBegin; 61 PetscCall(PetscDeviceContextGetDeviceType(globalContext, &dtype)); 62 PetscCall(PetscInfo(globalContext, "Destroying global PetscDeviceContext with device type %s\n", PetscDeviceTypes[dtype])); 63 PetscCall(PetscDeviceContextDestroy(&globalContext)); 64 rootDeviceType = PETSC_DEVICE_CONTEXT_DEFAULT_DEVICE_TYPE; 65 rootStreamType = PETSC_DEVICE_CONTEXT_DEFAULT_STREAM_TYPE; 66 PetscFunctionReturn(PETSC_SUCCESS); 67 }; 68 69 /* this exists purely as a valid device check. */ 70 PetscCall(PetscDeviceInitializePackage()); 71 PetscCall(PetscRegisterFinalize(std::move(finalizer))); 72 PetscCall(PetscDeviceContextCreate(&globalContext)); 73 PetscCall(PetscInfo(globalContext, "Initializing global PetscDeviceContext with device type %s\n", PetscDeviceTypes[dtype])); 74 pobj = PetscObjectCast(globalContext); 75 PetscCall(PetscObjectSetName(pobj, "global root")); 76 PetscCall(PetscObjectSetOptionsPrefix(pobj, "root_")); 77 PetscCall(PetscDeviceContextSetStreamType(globalContext, rootStreamType)); 78 PetscCall(PetscDeviceContextSetDefaultDeviceForType_Internal(globalContext, dtype)); 79 PetscCall(PetscDeviceContextSetUp(globalContext)); 80 PetscCall(PetscSetDefaultCUPMStreamFromDeviceContext(globalContext, dtype)); 81 } 82 PetscFunctionReturn(PETSC_SUCCESS); 83 } 84 85 /*@C 86 PetscDeviceContextGetCurrentContext - Get the current active `PetscDeviceContext` 87 88 Not Collective 89 90 Output Parameter: 91 . dctx - The `PetscDeviceContext` 92 93 Notes: 94 The user generally should not destroy contexts retrieved with this routine unless they 95 themselves have created them. There exists no protection against destroying the root 96 context. 97 98 Developer Notes: 99 Unless the user has set their own, this routine creates the "root" context the first time it 100 is called, registering its destructor to `PetscFinalize()`. 101 102 Level: beginner 103 104 .seealso: `PetscDeviceContextSetCurrentContext()`, `PetscDeviceContextFork()`, 105 `PetscDeviceContextJoin()`, `PetscDeviceContextCreate()` 106 @*/ 107 PetscErrorCode PetscDeviceContextGetCurrentContext(PetscDeviceContext *dctx) 108 { 109 PetscFunctionBegin; 110 PetscAssertPointer(dctx, 1); 111 PetscCall(PetscDeviceContextSetupGlobalContext_Private()); 112 /* while the static analyzer can find global variables, it will throw a warning about not 113 * being able to connect this back to the function arguments */ 114 PetscDisableStaticAnalyzerForExpressionUnderstandingThatThisIsDangerousAndBugprone(PetscValidDeviceContext(globalContext, -1)); 115 *dctx = globalContext; 116 PetscFunctionReturn(PETSC_SUCCESS); 117 } 118 119 /*@C 120 PetscDeviceContextSetCurrentContext - Set the current active `PetscDeviceContext` 121 122 Not Collective 123 124 Input Parameter: 125 . dctx - The `PetscDeviceContext` 126 127 Notes: 128 This routine can be used to set the defacto "root" `PetscDeviceContext` to a user-defined 129 implementation by calling this routine immediately after `PetscInitialize()` and ensuring that 130 `PetscDevice` is not greedily initialized. In this case the user is responsible for destroying 131 their `PetscDeviceContext` before `PetscFinalize()` returns. 132 133 The old context is not stored in any way by this routine; if one is overriding a context that 134 they themselves do not control, one should take care to temporarily store it by calling 135 `PetscDeviceContextGetCurrentContext()` before calling this routine. 136 137 Level: beginner 138 139 .seealso: `PetscDeviceContextGetCurrentContext()`, `PetscDeviceContextFork()`, 140 `PetscDeviceContextJoin()`, `PetscDeviceContextCreate()` 141 @*/ 142 PetscErrorCode PetscDeviceContextSetCurrentContext(PetscDeviceContext dctx) 143 { 144 PetscDeviceType dtype; 145 146 PetscFunctionBegin; 147 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx)); 148 PetscAssert(dctx->setup, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "PetscDeviceContext %" PetscInt64_FMT " must be set up before being set as global context", PetscObjectCast(dctx)->id); 149 PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype)); 150 PetscCall(PetscDeviceSetDefaultDeviceType(dtype)); 151 globalContext = dctx; 152 PetscCall(PetscInfo(dctx, "Set global PetscDeviceContext id %" PetscInt64_FMT "\n", PetscObjectCast(dctx)->id)); 153 PetscCall(PetscSetDefaultCUPMStreamFromDeviceContext(globalContext, dtype)); 154 PetscFunctionReturn(PETSC_SUCCESS); 155 } 156