xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision 833ca6b54922989c9ab81e79f7087a78c914e097)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <cuda_runtime.h>
11 #include <string.h>
12 
13 #include "ceed-cuda-ref.h"
14 
15 //------------------------------------------------------------------------------
16 // Sync host to device
17 //------------------------------------------------------------------------------
18 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
19   Ceed ceed;
20   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
21   CeedQFunctionContext_Cuda *impl;
22   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
23 
24   if (!impl->h_data) {
25     // LCOV_EXCL_START
26     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
27     // LCOV_EXCL_STOP
28   }
29 
30   size_t ctxsize;
31   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
32 
33   if (impl->d_data_borrowed) {
34     impl->d_data = impl->d_data_borrowed;
35   } else if (impl->d_data_owned) {
36     impl->d_data = impl->d_data_owned;
37   } else {
38     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize));
39     impl->d_data = impl->d_data_owned;
40   }
41 
42   CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctxsize, cudaMemcpyHostToDevice));
43 
44   return CEED_ERROR_SUCCESS;
45 }
46 
47 //------------------------------------------------------------------------------
48 // Sync device to host
49 //------------------------------------------------------------------------------
50 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
51   Ceed ceed;
52   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
53   CeedQFunctionContext_Cuda *impl;
54   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
55 
56   if (!impl->d_data) {
57     // LCOV_EXCL_START
58     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
59     // LCOV_EXCL_STOP
60   }
61 
62   size_t ctxsize;
63   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
64 
65   if (impl->h_data_borrowed) {
66     impl->h_data = impl->h_data_borrowed;
67   } else if (impl->h_data_owned) {
68     impl->h_data = impl->h_data_owned;
69   } else {
70     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
71     impl->h_data = impl->h_data_owned;
72   }
73 
74   CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctxsize, cudaMemcpyDeviceToHost));
75 
76   return CEED_ERROR_SUCCESS;
77 }
78 
79 //------------------------------------------------------------------------------
80 // Sync data of type
81 //------------------------------------------------------------------------------
82 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
83   switch (mem_type) {
84     case CEED_MEM_HOST:
85       return CeedQFunctionContextSyncD2H_Cuda(ctx);
86     case CEED_MEM_DEVICE:
87       return CeedQFunctionContextSyncH2D_Cuda(ctx);
88   }
89   return CEED_ERROR_UNSUPPORTED;
90 }
91 
92 //------------------------------------------------------------------------------
93 // Set all pointers as invalid
94 //------------------------------------------------------------------------------
95 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
96   CeedQFunctionContext_Cuda *impl;
97   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
98 
99   impl->h_data = NULL;
100   impl->d_data = NULL;
101 
102   return CEED_ERROR_SUCCESS;
103 }
104 
105 //------------------------------------------------------------------------------
106 // Check if ctx has valid data
107 //------------------------------------------------------------------------------
108 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
109   CeedQFunctionContext_Cuda *impl;
110   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
111 
112   *has_valid_data = impl && (!!impl->h_data || !!impl->d_data);
113 
114   return CEED_ERROR_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Check if ctx has borrowed data
119 //------------------------------------------------------------------------------
120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
121                                                                  bool *has_borrowed_data_of_type) {
122   CeedQFunctionContext_Cuda *impl;
123   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
124 
125   switch (mem_type) {
126     case CEED_MEM_HOST:
127       *has_borrowed_data_of_type = !!impl->h_data_borrowed;
128       break;
129     case CEED_MEM_DEVICE:
130       *has_borrowed_data_of_type = !!impl->d_data_borrowed;
131       break;
132   }
133 
134   return CEED_ERROR_SUCCESS;
135 }
136 
137 //------------------------------------------------------------------------------
138 // Check if data of given type needs sync
139 //------------------------------------------------------------------------------
140 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
141   CeedQFunctionContext_Cuda *impl;
142   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
143 
144   bool has_valid_data = true;
145   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
146   switch (mem_type) {
147     case CEED_MEM_HOST:
148       *need_sync = has_valid_data && !impl->h_data;
149       break;
150     case CEED_MEM_DEVICE:
151       *need_sync = has_valid_data && !impl->d_data;
152       break;
153   }
154 
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Set data from host
160 //------------------------------------------------------------------------------
161 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
162   CeedQFunctionContext_Cuda *impl;
163   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
164 
165   CeedCallBackend(CeedFree(&impl->h_data_owned));
166   switch (copy_mode) {
167     case CEED_COPY_VALUES: {
168       size_t ctxsize;
169       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
170       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
171       impl->h_data_borrowed = NULL;
172       impl->h_data          = impl->h_data_owned;
173       memcpy(impl->h_data, data, ctxsize);
174     } break;
175     case CEED_OWN_POINTER:
176       impl->h_data_owned    = data;
177       impl->h_data_borrowed = NULL;
178       impl->h_data          = data;
179       break;
180     case CEED_USE_POINTER:
181       impl->h_data_borrowed = data;
182       impl->h_data          = data;
183       break;
184   }
185 
186   return CEED_ERROR_SUCCESS;
187 }
188 
189 //------------------------------------------------------------------------------
190 // Set data from device
191 //------------------------------------------------------------------------------
192 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
193   Ceed ceed;
194   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
195   CeedQFunctionContext_Cuda *impl;
196   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
197 
198   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
199   impl->d_data_owned = NULL;
200   switch (copy_mode) {
201     case CEED_COPY_VALUES: {
202       size_t ctxsize;
203       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
204       CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize));
205       impl->d_data_borrowed = NULL;
206       impl->d_data          = impl->d_data_owned;
207       CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctxsize, cudaMemcpyDeviceToDevice));
208     } break;
209     case CEED_OWN_POINTER:
210       impl->d_data_owned    = data;
211       impl->d_data_borrowed = NULL;
212       impl->d_data          = data;
213       break;
214     case CEED_USE_POINTER:
215       impl->d_data_owned    = NULL;
216       impl->d_data_borrowed = data;
217       impl->d_data          = data;
218       break;
219   }
220 
221   return CEED_ERROR_SUCCESS;
222 }
223 
224 //------------------------------------------------------------------------------
225 // Set the data used by a user context,
226 //   freeing any previously allocated data if applicable
227 //------------------------------------------------------------------------------
228 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
229   Ceed ceed;
230   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
231 
232   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
233   switch (mem_type) {
234     case CEED_MEM_HOST:
235       return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
236     case CEED_MEM_DEVICE:
237       return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
238   }
239 
240   return CEED_ERROR_UNSUPPORTED;
241 }
242 
243 //------------------------------------------------------------------------------
244 // Take data
245 //------------------------------------------------------------------------------
246 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
247   Ceed ceed;
248   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
249   CeedQFunctionContext_Cuda *impl;
250   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
251 
252   // Sync data to requested mem_type
253   bool need_sync = false;
254   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
255   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
256 
257   // Update pointer
258   switch (mem_type) {
259     case CEED_MEM_HOST:
260       *(void **)data        = impl->h_data_borrowed;
261       impl->h_data_borrowed = NULL;
262       impl->h_data          = NULL;
263       break;
264     case CEED_MEM_DEVICE:
265       *(void **)data        = impl->d_data_borrowed;
266       impl->d_data_borrowed = NULL;
267       impl->d_data          = NULL;
268       break;
269   }
270 
271   return CEED_ERROR_SUCCESS;
272 }
273 
274 //------------------------------------------------------------------------------
275 // Core logic for GetData.
276 //   If a different memory type is most up to date, this will perform a copy
277 //------------------------------------------------------------------------------
278 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
279   Ceed ceed;
280   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
281   CeedQFunctionContext_Cuda *impl;
282   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
283 
284   // Sync data to requested mem_type
285   bool need_sync = false;
286   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
287   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
288 
289   // Update pointer
290   switch (mem_type) {
291     case CEED_MEM_HOST:
292       *(void **)data = impl->h_data;
293       break;
294     case CEED_MEM_DEVICE:
295       *(void **)data = impl->d_data;
296       break;
297   }
298 
299   return CEED_ERROR_SUCCESS;
300 }
301 
302 //------------------------------------------------------------------------------
303 // Get read-only access to the data
304 //------------------------------------------------------------------------------
305 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
306   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
307 }
308 
309 //------------------------------------------------------------------------------
310 // Get read/write access to the data
311 //------------------------------------------------------------------------------
312 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
313   CeedQFunctionContext_Cuda *impl;
314   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
315 
316   CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
317 
318   // Mark only pointer for requested memory as valid
319   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
320   switch (mem_type) {
321     case CEED_MEM_HOST:
322       impl->h_data = *(void **)data;
323       break;
324     case CEED_MEM_DEVICE:
325       impl->d_data = *(void **)data;
326       break;
327   }
328 
329   return CEED_ERROR_SUCCESS;
330 }
331 
332 //------------------------------------------------------------------------------
333 // Destroy the user context
334 //------------------------------------------------------------------------------
335 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
336   Ceed ceed;
337   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
338   CeedQFunctionContext_Cuda *impl;
339   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
340 
341   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
342   CeedCallBackend(CeedFree(&impl->h_data_owned));
343   CeedCallBackend(CeedFree(&impl));
344 
345   return CEED_ERROR_SUCCESS;
346 }
347 
348 //------------------------------------------------------------------------------
349 // QFunctionContext Create
350 //------------------------------------------------------------------------------
351 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
352   CeedQFunctionContext_Cuda *impl;
353   Ceed                       ceed;
354   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
355 
356   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
357   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
358   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
359   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
360   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
361   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
362   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
363 
364   CeedCallBackend(CeedCalloc(1, &impl));
365   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
366 
367   return CEED_ERROR_SUCCESS;
368 }
369 //------------------------------------------------------------------------------
370