1 #include "petscdevice_interface_internal.hpp" /*I <petscdevice.h> I*/
2 #include <petscdevice_cupm.h>
3
4 static auto rootDeviceType = PETSC_DEVICE_CONTEXT_DEFAULT_DEVICE_TYPE;
5 static auto rootStreamType = PETSC_DEVICE_CONTEXT_DEFAULT_STREAM_TYPE;
6 static PetscDeviceContext globalContext = nullptr;
7
8 /* when PetscDevice initializes PetscDeviceContext eagerly the type of device created should
9 * match whatever device is eagerly initialized */
PetscDeviceContextSetRootDeviceType_Internal(PetscDeviceType type)10 PetscErrorCode PetscDeviceContextSetRootDeviceType_Internal(PetscDeviceType type)
11 {
12 PetscFunctionBegin;
13 PetscValidDeviceType(type, 1);
14 rootDeviceType = type;
15 PetscFunctionReturn(PETSC_SUCCESS);
16 }
17
PetscDeviceContextSetRootStreamType_Internal(PetscStreamType type)18 PetscErrorCode PetscDeviceContextSetRootStreamType_Internal(PetscStreamType type)
19 {
20 PetscFunctionBegin;
21 PetscValidStreamType(type, 1);
22 rootStreamType = type;
23 PetscFunctionReturn(PETSC_SUCCESS);
24 }
25
PetscSetDefaultCUPMStreamFromDeviceContext(PetscDeviceContext dctx,PetscDeviceType dtype)26 static inline PetscErrorCode PetscSetDefaultCUPMStreamFromDeviceContext(PetscDeviceContext dctx, PetscDeviceType dtype)
27 {
28 PetscFunctionBegin;
29 #if PetscDefined(HAVE_CUDA)
30 if (dtype == PETSC_DEVICE_CUDA) {
31 void *handle;
32
33 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &handle));
34 PetscDefaultCudaStream = *static_cast<cudaStream_t *>(handle);
35 }
36 #endif
37 #if PetscDefined(HAVE_HIP)
38 if (dtype == PETSC_DEVICE_HIP) {
39 void *handle;
40
41 PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &handle));
42 PetscDefaultHipStream = *static_cast<hipStream_t *>(handle);
43 }
44 #endif
45 #if !PetscDefined(HAVE_CUDA) && !PetscDefined(HAVE_HIP)
46 (void)dctx, (void)dtype;
47 #endif
48 PetscFunctionReturn(PETSC_SUCCESS);
49 }
50
PetscDeviceContextSetupGlobalContext_Private()51 static PetscErrorCode PetscDeviceContextSetupGlobalContext_Private() noexcept
52 {
53 PetscFunctionBegin;
54 if (PetscUnlikely(!globalContext)) {
55 PetscObject pobj;
56 const auto dtype = rootDeviceType;
57 const auto finalizer = [] {
58 PetscDeviceType dtype;
59
60 PetscFunctionBegin;
61 PetscCall(PetscDeviceContextGetDeviceType(globalContext, &dtype));
62 PetscCall(PetscInfo(globalContext, "Destroying global PetscDeviceContext with device type %s\n", PetscDeviceTypes[dtype]));
63 PetscCall(PetscDeviceContextDestroy(&globalContext));
64 rootDeviceType = PETSC_DEVICE_CONTEXT_DEFAULT_DEVICE_TYPE;
65 rootStreamType = PETSC_DEVICE_CONTEXT_DEFAULT_STREAM_TYPE;
66 PetscFunctionReturn(PETSC_SUCCESS);
67 };
68
69 /* this exists purely as a valid device check. */
70 PetscCall(PetscDeviceInitializePackage());
71 PetscCall(PetscRegisterFinalize(std::move(finalizer)));
72 PetscCall(PetscDeviceContextCreate(&globalContext));
73 PetscCall(PetscInfo(globalContext, "Initializing global PetscDeviceContext with device type %s\n", PetscDeviceTypes[dtype]));
74 pobj = PetscObjectCast(globalContext);
75 PetscCall(PetscObjectSetName(pobj, "global root"));
76 PetscCall(PetscObjectSetOptionsPrefix(pobj, "root_"));
77 PetscCall(PetscDeviceContextSetStreamType(globalContext, rootStreamType));
78 PetscCall(PetscDeviceContextSetDefaultDeviceForType_Internal(globalContext, dtype));
79 PetscCall(PetscDeviceContextSetUp(globalContext));
80 PetscCall(PetscSetDefaultCUPMStreamFromDeviceContext(globalContext, dtype));
81 }
82 PetscFunctionReturn(PETSC_SUCCESS);
83 }
84
85 /*@C
86 PetscDeviceContextGetCurrentContext - Get the current active `PetscDeviceContext`
87
88 Not Collective
89
90 Output Parameter:
91 . dctx - The `PetscDeviceContext`
92
93 Notes:
94 The user generally should not destroy contexts retrieved with this routine unless they
95 themselves have created them. There exists no protection against destroying the root
96 context.
97
98 Developer Notes:
99 Unless the user has set their own, this routine creates the "root" context the first time it
100 is called, registering its destructor to `PetscFinalize()`.
101
102 Level: beginner
103
104 .seealso: `PetscDeviceContextSetCurrentContext()`, `PetscDeviceContextFork()`,
105 `PetscDeviceContextJoin()`, `PetscDeviceContextCreate()`
106 @*/
PetscDeviceContextGetCurrentContext(PetscDeviceContext * dctx)107 PetscErrorCode PetscDeviceContextGetCurrentContext(PetscDeviceContext *dctx)
108 {
109 PetscFunctionBegin;
110 PetscAssertPointer(dctx, 1);
111 PetscCall(PetscDeviceContextSetupGlobalContext_Private());
112 /* while the static analyzer can find global variables, it will throw a warning about not
113 * being able to connect this back to the function arguments */
114 PetscDisableStaticAnalyzerForExpressionUnderstandingThatThisIsDangerousAndBugprone(PetscValidDeviceContext(globalContext, -1));
115 *dctx = globalContext;
116 PetscFunctionReturn(PETSC_SUCCESS);
117 }
118
119 /*@C
120 PetscDeviceContextSetCurrentContext - Set the current active `PetscDeviceContext`
121
122 Not Collective
123
124 Input Parameter:
125 . dctx - The `PetscDeviceContext`
126
127 Notes:
128 This routine can be used to set the defacto "root" `PetscDeviceContext` to a user-defined
129 implementation by calling this routine immediately after `PetscInitialize()` and ensuring that
130 `PetscDevice` is not greedily initialized. In this case the user is responsible for destroying
131 their `PetscDeviceContext` before `PetscFinalize()` returns.
132
133 The old context is not stored in any way by this routine; if one is overriding a context that
134 they themselves do not control, one should take care to temporarily store it by calling
135 `PetscDeviceContextGetCurrentContext()` before calling this routine.
136
137 Level: beginner
138
139 .seealso: `PetscDeviceContextGetCurrentContext()`, `PetscDeviceContextFork()`,
140 `PetscDeviceContextJoin()`, `PetscDeviceContextCreate()`
141 @*/
PetscDeviceContextSetCurrentContext(PetscDeviceContext dctx)142 PetscErrorCode PetscDeviceContextSetCurrentContext(PetscDeviceContext dctx)
143 {
144 PetscDeviceType dtype;
145
146 PetscFunctionBegin;
147 PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
148 PetscAssert(dctx->setup, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "PetscDeviceContext %" PetscInt64_FMT " must be set up before being set as global context", PetscObjectCast(dctx)->id);
149 PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
150 PetscCall(PetscDeviceSetDefaultDeviceType(dtype));
151 globalContext = dctx;
152 PetscCall(PetscInfo(dctx, "Set global PetscDeviceContext id %" PetscInt64_FMT "\n", PetscObjectCast(dctx)->id));
153 PetscCall(PetscSetDefaultCUPMStreamFromDeviceContext(globalContext, dtype));
154 PetscFunctionReturn(PETSC_SUCCESS);
155 }
156