1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13
14 #include "../cuda/ceed-cuda-common.h"
15 #include "ceed-cuda-ref.h"
16
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx)20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
21 Ceed ceed;
22 size_t ctx_size;
23 CeedQFunctionContext_Cuda *impl;
24
25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27
28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29
30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31 if (impl->d_data_borrowed) {
32 impl->d_data = impl->d_data_borrowed;
33 } else if (impl->d_data_owned) {
34 impl->d_data = impl->d_data_owned;
35 } else {
36 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
37 impl->d_data = impl->d_data_owned;
38 }
39 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice));
40 CeedCallBackend(CeedDestroy(&ceed));
41 return CEED_ERROR_SUCCESS;
42 }
43
44 //------------------------------------------------------------------------------
45 // Sync device to host
46 //------------------------------------------------------------------------------
CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx)47 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
48 Ceed ceed;
49 size_t ctx_size;
50 CeedQFunctionContext_Cuda *impl;
51
52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
54
55 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
56
57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
58
59 if (impl->h_data_borrowed) {
60 impl->h_data = impl->h_data_borrowed;
61 } else if (impl->h_data_owned) {
62 impl->h_data = impl->h_data_owned;
63 } else {
64 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
65 impl->h_data = impl->h_data_owned;
66 }
67 CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost));
68 CeedCallBackend(CeedDestroy(&ceed));
69 return CEED_ERROR_SUCCESS;
70 }
71
72 //------------------------------------------------------------------------------
73 // Sync data of type
74 //------------------------------------------------------------------------------
CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx,CeedMemType mem_type)75 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
76 switch (mem_type) {
77 case CEED_MEM_HOST:
78 return CeedQFunctionContextSyncD2H_Cuda(ctx);
79 case CEED_MEM_DEVICE:
80 return CeedQFunctionContextSyncH2D_Cuda(ctx);
81 }
82 return CEED_ERROR_UNSUPPORTED;
83 }
84
85 //------------------------------------------------------------------------------
86 // Set all pointers as invalid
87 //------------------------------------------------------------------------------
CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx)88 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
89 CeedQFunctionContext_Cuda *impl;
90
91 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
92 impl->h_data = NULL;
93 impl->d_data = NULL;
94 return CEED_ERROR_SUCCESS;
95 }
96
97 //------------------------------------------------------------------------------
98 // Check if ctx has valid data
99 //------------------------------------------------------------------------------
CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx,bool * has_valid_data)100 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
101 CeedQFunctionContext_Cuda *impl;
102
103 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
104 *has_valid_data = impl && (impl->h_data || impl->d_data);
105 return CEED_ERROR_SUCCESS;
106 }
107
108 //------------------------------------------------------------------------------
109 // Check if ctx has borrowed data
110 //------------------------------------------------------------------------------
CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx,CeedMemType mem_type,bool * has_borrowed_data_of_type)111 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
112 bool *has_borrowed_data_of_type) {
113 CeedQFunctionContext_Cuda *impl;
114
115 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
116 switch (mem_type) {
117 case CEED_MEM_HOST:
118 *has_borrowed_data_of_type = impl->h_data_borrowed;
119 break;
120 case CEED_MEM_DEVICE:
121 *has_borrowed_data_of_type = impl->d_data_borrowed;
122 break;
123 }
124 return CEED_ERROR_SUCCESS;
125 }
126
127 //------------------------------------------------------------------------------
128 // Check if data of given type needs sync
129 //------------------------------------------------------------------------------
CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx,CeedMemType mem_type,bool * need_sync)130 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
131 bool has_valid_data = true;
132 CeedQFunctionContext_Cuda *impl;
133
134 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
135 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
136 switch (mem_type) {
137 case CEED_MEM_HOST:
138 *need_sync = has_valid_data && !impl->h_data;
139 break;
140 case CEED_MEM_DEVICE:
141 *need_sync = has_valid_data && !impl->d_data;
142 break;
143 }
144 return CEED_ERROR_SUCCESS;
145 }
146
147 //------------------------------------------------------------------------------
148 // Set data from host
149 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)150 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
151 CeedQFunctionContext_Cuda *impl;
152
153 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
154
155 CeedCallBackend(CeedFree(&impl->h_data_owned));
156 switch (copy_mode) {
157 case CEED_COPY_VALUES: {
158 size_t ctx_size;
159 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
160 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
161 impl->h_data_borrowed = NULL;
162 impl->h_data = impl->h_data_owned;
163 memcpy(impl->h_data, data, ctx_size);
164 } break;
165 case CEED_OWN_POINTER:
166 impl->h_data_owned = data;
167 impl->h_data_borrowed = NULL;
168 impl->h_data = data;
169 break;
170 case CEED_USE_POINTER:
171 impl->h_data_borrowed = data;
172 impl->h_data = data;
173 break;
174 }
175 return CEED_ERROR_SUCCESS;
176 }
177
178 //------------------------------------------------------------------------------
179 // Set data from device
180 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)181 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
182 Ceed ceed;
183 CeedQFunctionContext_Cuda *impl;
184
185 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
186 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
187
188 CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
189 impl->d_data_owned = NULL;
190 switch (copy_mode) {
191 case CEED_COPY_VALUES: {
192 size_t ctx_size;
193 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
194 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
195 impl->d_data_borrowed = NULL;
196 impl->d_data = impl->d_data_owned;
197 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice));
198 } break;
199 case CEED_OWN_POINTER:
200 impl->d_data_owned = data;
201 impl->d_data_borrowed = NULL;
202 impl->d_data = data;
203 break;
204 case CEED_USE_POINTER:
205 impl->d_data_owned = NULL;
206 impl->d_data_borrowed = data;
207 impl->d_data = data;
208 break;
209 }
210 CeedCallBackend(CeedDestroy(&ceed));
211 return CEED_ERROR_SUCCESS;
212 }
213
214 //------------------------------------------------------------------------------
215 // Set the data used by a user context,
216 // freeing any previously allocated data if applicable
217 //------------------------------------------------------------------------------
CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,const CeedCopyMode copy_mode,void * data)218 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
219 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
220 switch (mem_type) {
221 case CEED_MEM_HOST:
222 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
223 case CEED_MEM_DEVICE:
224 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
225 }
226 return CEED_ERROR_UNSUPPORTED;
227 }
228
229 //------------------------------------------------------------------------------
230 // Take data
231 //------------------------------------------------------------------------------
CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)232 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
233 CeedQFunctionContext_Cuda *impl;
234
235 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
236
237 // Sync data to requested mem_type
238 bool need_sync = false;
239 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
240 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
241
242 // Update pointer
243 switch (mem_type) {
244 case CEED_MEM_HOST:
245 *(void **)data = impl->h_data_borrowed;
246 impl->h_data_borrowed = NULL;
247 impl->h_data = NULL;
248 break;
249 case CEED_MEM_DEVICE:
250 *(void **)data = impl->d_data_borrowed;
251 impl->d_data_borrowed = NULL;
252 impl->d_data = NULL;
253 break;
254 }
255 return CEED_ERROR_SUCCESS;
256 }
257
258 //------------------------------------------------------------------------------
259 // Core logic for GetData.
260 // If a different memory type is most up to date, this will perform a copy
261 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)262 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
263 bool need_sync = false;
264 CeedQFunctionContext_Cuda *impl;
265
266 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
267
268 // Sync data to requested mem_type
269 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
270 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
271
272 // Update pointer
273 switch (mem_type) {
274 case CEED_MEM_HOST:
275 *(void **)data = impl->h_data;
276 break;
277 case CEED_MEM_DEVICE:
278 *(void **)data = impl->d_data;
279 break;
280 }
281 return CEED_ERROR_SUCCESS;
282 }
283
284 //------------------------------------------------------------------------------
285 // Get read-only access to the data
286 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)287 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
288 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
289 }
290
291 //------------------------------------------------------------------------------
292 // Get read/write access to the data
293 //------------------------------------------------------------------------------
CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)294 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
295 CeedQFunctionContext_Cuda *impl;
296
297 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
298 CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
299
300 // Mark only pointer for requested memory as valid
301 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
302 switch (mem_type) {
303 case CEED_MEM_HOST:
304 impl->h_data = *(void **)data;
305 break;
306 case CEED_MEM_DEVICE:
307 impl->d_data = *(void **)data;
308 break;
309 }
310 return CEED_ERROR_SUCCESS;
311 }
312
313 //------------------------------------------------------------------------------
314 // Destroy the user context
315 //------------------------------------------------------------------------------
CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx)316 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
317 CeedQFunctionContext_Cuda *impl;
318
319 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
320 CeedCallCuda(CeedQFunctionContextReturnCeed(ctx), cudaFree(impl->d_data_owned));
321 CeedCallBackend(CeedFree(&impl->h_data_owned));
322 CeedCallBackend(CeedFree(&impl));
323 return CEED_ERROR_SUCCESS;
324 }
325
326 //------------------------------------------------------------------------------
327 // QFunctionContext Create
328 //------------------------------------------------------------------------------
CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx)329 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
330 CeedQFunctionContext_Cuda *impl;
331 Ceed ceed;
332
333 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
334 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
335 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
336 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
337 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
338 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
339 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
340 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
341 CeedCallBackend(CeedDestroy(&ceed));
342 CeedCallBackend(CeedCalloc(1, &impl));
343 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
344 return CEED_ERROR_SUCCESS;
345 }
346
347 //------------------------------------------------------------------------------
348