xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "../cuda/ceed-cuda-common.h"
15 #include "ceed-cuda-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx)20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
21   Ceed                       ceed;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Cuda *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice));
40   CeedCallBackend(CeedDestroy(&ceed));
41   return CEED_ERROR_SUCCESS;
42 }
43 
44 //------------------------------------------------------------------------------
45 // Sync device to host
46 //------------------------------------------------------------------------------
CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx)47 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
48   Ceed                       ceed;
49   size_t                     ctx_size;
50   CeedQFunctionContext_Cuda *impl;
51 
52   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
54 
55   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
56 
57   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
58 
59   if (impl->h_data_borrowed) {
60     impl->h_data = impl->h_data_borrowed;
61   } else if (impl->h_data_owned) {
62     impl->h_data = impl->h_data_owned;
63   } else {
64     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
65     impl->h_data = impl->h_data_owned;
66   }
67   CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost));
68   CeedCallBackend(CeedDestroy(&ceed));
69   return CEED_ERROR_SUCCESS;
70 }
71 
72 //------------------------------------------------------------------------------
73 // Sync data of type
74 //------------------------------------------------------------------------------
// Bring the array of the requested memory type up to date.
//   CEED_MEM_HOST copies device -> host, CEED_MEM_DEVICE copies host -> device;
//   any other memory type yields CEED_ERROR_UNSUPPORTED.
static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
  if (mem_type == CEED_MEM_HOST) return CeedQFunctionContextSyncD2H_Cuda(ctx);
  if (mem_type == CEED_MEM_DEVICE) return CeedQFunctionContextSyncH2D_Cuda(ctx);
  return CEED_ERROR_UNSUPPORTED;
}
84 
85 //------------------------------------------------------------------------------
86 // Set all pointers as invalid
87 //------------------------------------------------------------------------------
CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx)88 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
89   CeedQFunctionContext_Cuda *impl;
90 
91   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
92   impl->h_data = NULL;
93   impl->d_data = NULL;
94   return CEED_ERROR_SUCCESS;
95 }
96 
97 //------------------------------------------------------------------------------
98 // Check if ctx has valid data
99 //------------------------------------------------------------------------------
CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx,bool * has_valid_data)100 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
101   CeedQFunctionContext_Cuda *impl;
102 
103   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
104   *has_valid_data = impl && (impl->h_data || impl->d_data);
105   return CEED_ERROR_SUCCESS;
106 }
107 
108 //------------------------------------------------------------------------------
109 // Check if ctx has borrowed data
110 //------------------------------------------------------------------------------
// Report whether a borrowed array of the given memory type is currently set.
//   For a memory type other than CEED_MEM_HOST or CEED_MEM_DEVICE the output is left unmodified.
static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
                                                                 bool *has_borrowed_data_of_type) {
  CeedQFunctionContext_Cuda *backend_data;

  CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &backend_data));
  if (mem_type == CEED_MEM_HOST) {
    *has_borrowed_data_of_type = backend_data->h_data_borrowed != NULL;
  } else if (mem_type == CEED_MEM_DEVICE) {
    *has_borrowed_data_of_type = backend_data->d_data_borrowed != NULL;
  }
  return CEED_ERROR_SUCCESS;
}
126 
127 //------------------------------------------------------------------------------
128 // Check if data of given type needs sync
129 //------------------------------------------------------------------------------
// Decide whether the array of the given memory type must be refreshed before use.
//   A sync is needed when the context has valid data somewhere but the "valid" pointer
//   for the requested memory type is not set.
//   For a memory type other than CEED_MEM_HOST or CEED_MEM_DEVICE the output is left unmodified.
static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
  bool                       any_valid = true;
  CeedQFunctionContext_Cuda *backend_data;

  CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &backend_data));
  CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &any_valid));
  if (mem_type == CEED_MEM_HOST) {
    *need_sync = any_valid && !backend_data->h_data;
  } else if (mem_type == CEED_MEM_DEVICE) {
    *need_sync = any_valid && !backend_data->d_data;
  }
  return CEED_ERROR_SUCCESS;
}
146 
147 //------------------------------------------------------------------------------
148 // Set data from host
149 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)150 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
151   CeedQFunctionContext_Cuda *impl;
152 
153   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
154 
155   CeedCallBackend(CeedFree(&impl->h_data_owned));
156   switch (copy_mode) {
157     case CEED_COPY_VALUES: {
158       size_t ctx_size;
159       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
160       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
161       impl->h_data_borrowed = NULL;
162       impl->h_data          = impl->h_data_owned;
163       memcpy(impl->h_data, data, ctx_size);
164     } break;
165     case CEED_OWN_POINTER:
166       impl->h_data_owned    = data;
167       impl->h_data_borrowed = NULL;
168       impl->h_data          = data;
169       break;
170     case CEED_USE_POINTER:
171       impl->h_data_borrowed = data;
172       impl->h_data          = data;
173       break;
174   }
175   return CEED_ERROR_SUCCESS;
176 }
177 
178 //------------------------------------------------------------------------------
179 // Set data from device
180 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)181 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
182   Ceed                       ceed;
183   CeedQFunctionContext_Cuda *impl;
184 
185   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
186   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
187 
188   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
189   impl->d_data_owned = NULL;
190   switch (copy_mode) {
191     case CEED_COPY_VALUES: {
192       size_t ctx_size;
193       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
194       CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
195       impl->d_data_borrowed = NULL;
196       impl->d_data          = impl->d_data_owned;
197       CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice));
198     } break;
199     case CEED_OWN_POINTER:
200       impl->d_data_owned    = data;
201       impl->d_data_borrowed = NULL;
202       impl->d_data          = data;
203       break;
204     case CEED_USE_POINTER:
205       impl->d_data_owned    = NULL;
206       impl->d_data_borrowed = data;
207       impl->d_data          = data;
208       break;
209   }
210   CeedCallBackend(CeedDestroy(&ceed));
211   return CEED_ERROR_SUCCESS;
212 }
213 
214 //------------------------------------------------------------------------------
215 // Set the data used by a user context,
216 //   freeing any previously allocated data if applicable
217 //------------------------------------------------------------------------------
CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,const CeedCopyMode copy_mode,void * data)218 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
219   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
220   switch (mem_type) {
221     case CEED_MEM_HOST:
222       return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
223     case CEED_MEM_DEVICE:
224       return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
225   }
226   return CEED_ERROR_UNSUPPORTED;
227 }
228 
229 //------------------------------------------------------------------------------
230 // Take data
231 //------------------------------------------------------------------------------
CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)232 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
233   CeedQFunctionContext_Cuda *impl;
234 
235   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
236 
237   // Sync data to requested mem_type
238   bool need_sync = false;
239   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
240   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
241 
242   // Update pointer
243   switch (mem_type) {
244     case CEED_MEM_HOST:
245       *(void **)data        = impl->h_data_borrowed;
246       impl->h_data_borrowed = NULL;
247       impl->h_data          = NULL;
248       break;
249     case CEED_MEM_DEVICE:
250       *(void **)data        = impl->d_data_borrowed;
251       impl->d_data_borrowed = NULL;
252       impl->d_data          = NULL;
253       break;
254   }
255   return CEED_ERROR_SUCCESS;
256 }
257 
258 //------------------------------------------------------------------------------
259 // Core logic for GetData.
260 //   If a different memory type is most up to date, this will perform a copy
261 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)262 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
263   bool                       need_sync = false;
264   CeedQFunctionContext_Cuda *impl;
265 
266   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
267 
268   // Sync data to requested mem_type
269   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
270   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
271 
272   // Update pointer
273   switch (mem_type) {
274     case CEED_MEM_HOST:
275       *(void **)data = impl->h_data;
276       break;
277     case CEED_MEM_DEVICE:
278       *(void **)data = impl->d_data;
279       break;
280   }
281   return CEED_ERROR_SUCCESS;
282 }
283 
284 //------------------------------------------------------------------------------
285 // Get read-only access to the data
286 //------------------------------------------------------------------------------
// Backend entry point for read-only access.
//   Forwards to CeedQFunctionContextGetDataCore_Cuda and returns its status unchanged.
static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
  const int status = CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);

  return status;
}
290 
291 //------------------------------------------------------------------------------
292 // Get read/write access to the data
293 //------------------------------------------------------------------------------
CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)294 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
295   CeedQFunctionContext_Cuda *impl;
296 
297   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
298   CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
299 
300   // Mark only pointer for requested memory as valid
301   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
302   switch (mem_type) {
303     case CEED_MEM_HOST:
304       impl->h_data = *(void **)data;
305       break;
306     case CEED_MEM_DEVICE:
307       impl->d_data = *(void **)data;
308       break;
309   }
310   return CEED_ERROR_SUCCESS;
311 }
312 
313 //------------------------------------------------------------------------------
314 // Destroy the user context
315 //------------------------------------------------------------------------------
CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx)316 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
317   CeedQFunctionContext_Cuda *impl;
318 
319   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
320   CeedCallCuda(CeedQFunctionContextReturnCeed(ctx), cudaFree(impl->d_data_owned));
321   CeedCallBackend(CeedFree(&impl->h_data_owned));
322   CeedCallBackend(CeedFree(&impl));
323   return CEED_ERROR_SUCCESS;
324 }
325 
326 //------------------------------------------------------------------------------
327 // QFunctionContext Create
328 //------------------------------------------------------------------------------
// Set up a QFunctionContext for this backend.
//   Registers the HasValidData, HasBorrowedDataOfType, SetData, TakeData, GetData, GetDataRead, and Destroy
//   implementations from this file, then allocates the backend data with CeedCalloc and attaches it to the context.
int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
  CeedQFunctionContext_Cuda *impl;
  // The Ceed from CeedQFunctionContextReturnCeed is not paired with a CeedDestroy (same usage as the Destroy function in this file),
  //   so a failing registration below cannot leave a Ceed reference unreleased
  Ceed ceed = CeedQFunctionContextReturnCeed(ctx);

  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
  CeedCallBackend(CeedCalloc(1, &impl));
  CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
  return CEED_ERROR_SUCCESS;
}
346 
347 //------------------------------------------------------------------------------
348