xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision 019b76820d7ff306c177822c4e76ffe5939c204b)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <string.h>
12 #include "ceed-cuda-ref.h"
13 
14 //------------------------------------------------------------------------------
15 // Sync host to device
16 //------------------------------------------------------------------------------
17 static inline int CeedQFunctionContextSyncH2D_Cuda(
18   const CeedQFunctionContext ctx) {
19   int ierr;
20   Ceed ceed;
21   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
22   CeedQFunctionContext_Cuda *impl;
23   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
24 
25   if (!impl->h_data)
26     // LCOV_EXCL_START
27     return CeedError(ceed, CEED_ERROR_BACKEND,
28                      "No valid host data to sync to device");
29   // LCOV_EXCL_STOP
30 
31   size_t ctxsize;
32   ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
33 
34   if (impl->d_data_borrowed) {
35     impl->d_data = impl->d_data_borrowed;
36   } else if (impl->d_data_owned) {
37     impl->d_data = impl->d_data_owned;
38   } else {
39     ierr = cudaMalloc((void **)&impl->d_data_owned, ctxsize);
40     CeedChk_Cu(ceed, ierr);
41     impl->d_data = impl->d_data_owned;
42   }
43 
44   ierr = cudaMemcpy(impl->d_data, impl->h_data, ctxsize,
45                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
46 
47   return CEED_ERROR_SUCCESS;
48 }
49 
50 //------------------------------------------------------------------------------
51 // Sync device to host
52 //------------------------------------------------------------------------------
53 static inline int CeedQFunctionContextSyncD2H_Cuda(
54   const CeedQFunctionContext ctx) {
55   int ierr;
56   Ceed ceed;
57   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
58   CeedQFunctionContext_Cuda *impl;
59   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
60 
61   if (!impl->d_data)
62     // LCOV_EXCL_START
63     return CeedError(ceed, CEED_ERROR_BACKEND,
64                      "No valid device data to sync to host");
65   // LCOV_EXCL_STOP
66 
67   size_t ctxsize;
68   ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
69 
70   if (impl->h_data_borrowed) {
71     impl->h_data = impl->h_data_borrowed;
72   } else if (impl->h_data_owned) {
73     impl->h_data = impl->h_data_owned;
74   } else {
75     ierr = CeedMallocArray(1, ctxsize, &impl->h_data_owned);
76     CeedChkBackend(ierr);
77     impl->h_data = impl->h_data_owned;
78   }
79 
80   ierr = cudaMemcpy(impl->h_data, impl->d_data, ctxsize,
81                     cudaMemcpyDeviceToHost); CeedChk_Cu(ceed, ierr);
82 
83   return CEED_ERROR_SUCCESS;
84 }
85 
86 //------------------------------------------------------------------------------
87 // Sync data of type
88 //------------------------------------------------------------------------------
89 static inline int CeedQFunctionContextSync_Cuda(
90   const CeedQFunctionContext ctx, CeedMemType mem_type) {
91   switch (mem_type) {
92   case CEED_MEM_HOST: return CeedQFunctionContextSyncD2H_Cuda(ctx);
93   case CEED_MEM_DEVICE: return CeedQFunctionContextSyncH2D_Cuda(ctx);
94   }
95   return CEED_ERROR_UNSUPPORTED;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Set all pointers as invalid
100 //------------------------------------------------------------------------------
101 static inline int CeedQFunctionContextSetAllInvalid_Cuda(
102   const CeedQFunctionContext ctx) {
103   int ierr;
104   CeedQFunctionContext_Cuda *impl;
105   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
106 
107   impl->h_data = NULL;
108   impl->d_data = NULL;
109 
110   return CEED_ERROR_SUCCESS;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Check if ctx has valid data
115 //------------------------------------------------------------------------------
116 static inline int CeedQFunctionContextHasValidData_Cuda(
117   const CeedQFunctionContext ctx, bool *has_valid_data) {
118   int ierr;
119   CeedQFunctionContext_Cuda *impl;
120   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
121 
122   *has_valid_data = !!impl->h_data || !!impl->d_data;
123 
124   return CEED_ERROR_SUCCESS;
125 }
126 
127 //------------------------------------------------------------------------------
128 // Check if ctx has borrowed data
129 //------------------------------------------------------------------------------
130 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(
131   const CeedQFunctionContext ctx, CeedMemType mem_type,
132   bool *has_borrowed_data_of_type) {
133   int ierr;
134   CeedQFunctionContext_Cuda *impl;
135   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
136 
137   switch (mem_type) {
138   case CEED_MEM_HOST:
139     *has_borrowed_data_of_type = !!impl->h_data_borrowed;
140     break;
141   case CEED_MEM_DEVICE:
142     *has_borrowed_data_of_type = !!impl->d_data_borrowed;
143     break;
144   }
145 
146   return CEED_ERROR_SUCCESS;
147 }
148 
149 //------------------------------------------------------------------------------
150 // Check if data of given type needs sync
151 //------------------------------------------------------------------------------
152 static inline int CeedQFunctionContextNeedSync_Cuda(
153   const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
154   int ierr;
155   CeedQFunctionContext_Cuda *impl;
156   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
157 
158   bool has_valid_data = true;
159   ierr = CeedQFunctionContextHasValidData(ctx, &has_valid_data);
160   CeedChkBackend(ierr);
161   switch (mem_type) {
162   case CEED_MEM_HOST:
163     *need_sync = has_valid_data && !impl->h_data;
164     break;
165   case CEED_MEM_DEVICE:
166     *need_sync = has_valid_data && !impl->d_data;
167     break;
168   }
169 
170   return CEED_ERROR_SUCCESS;
171 }
172 
173 //------------------------------------------------------------------------------
174 // Set data from host
175 //------------------------------------------------------------------------------
176 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx,
177     const CeedCopyMode copy_mode, void *data) {
178   int ierr;
179   CeedQFunctionContext_Cuda *impl;
180   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
181 
182   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
183   switch (copy_mode) {
184   case CEED_COPY_VALUES: {
185     size_t ctxsize;
186     ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
187     ierr = CeedMallocArray(1, ctxsize, &impl->h_data_owned);
188     CeedChkBackend(ierr);
189     impl->h_data_borrowed = NULL;
190     impl->h_data = impl->h_data_owned;
191     memcpy(impl->h_data, data, ctxsize);
192   } break;
193   case CEED_OWN_POINTER:
194     impl->h_data_owned = data;
195     impl->h_data_borrowed = NULL;
196     impl->h_data = data;
197     break;
198   case CEED_USE_POINTER:
199     impl->h_data_borrowed = data;
200     impl->h_data = data;
201     break;
202   }
203 
204   return CEED_ERROR_SUCCESS;
205 }
206 
207 //------------------------------------------------------------------------------
208 // Set data from device
209 //------------------------------------------------------------------------------
210 static int CeedQFunctionContextSetDataDevice_Cuda(
211   const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
212   int ierr;
213   Ceed ceed;
214   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
215   CeedQFunctionContext_Cuda *impl;
216   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
217 
218   ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr);
219   impl->d_data_owned = NULL;
220   switch (copy_mode) {
221   case CEED_COPY_VALUES: {
222     size_t ctxsize;
223     ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
224     ierr = cudaMalloc((void **)&impl->d_data_owned, ctxsize);
225     CeedChk_Cu(ceed, ierr);
226     impl->d_data_borrowed = NULL;
227     impl->d_data = impl->d_data_owned;
228     ierr = cudaMemcpy(impl->d_data, data, ctxsize,
229                       cudaMemcpyDeviceToDevice); CeedChk_Cu(ceed, ierr);
230   } break;
231   case CEED_OWN_POINTER:
232     impl->d_data_owned = data;
233     impl->d_data_borrowed = NULL;
234     impl->d_data = data;
235     break;
236   case CEED_USE_POINTER:
237     impl->d_data_owned = NULL;
238     impl->d_data_borrowed = data;
239     impl->d_data = data;
240     break;
241   }
242 
243   return CEED_ERROR_SUCCESS;
244 }
245 
246 //------------------------------------------------------------------------------
247 // Set the data used by a user context,
248 //   freeing any previously allocated data if applicable
249 //------------------------------------------------------------------------------
250 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx,
251     const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
252   int ierr;
253   Ceed ceed;
254   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
255 
256   ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr);
257   switch (mem_type) {
258   case CEED_MEM_HOST:
259     return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
260   case CEED_MEM_DEVICE:
261     return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
262   }
263 
264   return CEED_ERROR_UNSUPPORTED;
265 }
266 
267 //------------------------------------------------------------------------------
268 // Take data
269 //------------------------------------------------------------------------------
270 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx,
271     const CeedMemType mem_type, void *data) {
272   int ierr;
273   Ceed ceed;
274   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
275   CeedQFunctionContext_Cuda *impl;
276   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
277 
278   // Sync data to requested mem_type
279   bool need_sync = false;
280   ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync);
281   CeedChkBackend(ierr);
282   if (need_sync) {
283     ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr);
284   }
285 
286   // Update pointer
287   switch (mem_type) {
288   case CEED_MEM_HOST:
289     *(void **)data = impl->h_data_borrowed;
290     impl->h_data_borrowed = NULL;
291     impl->h_data = NULL;
292     break;
293   case CEED_MEM_DEVICE:
294     *(void **)data = impl->d_data_borrowed;
295     impl->d_data_borrowed = NULL;
296     impl->d_data = NULL;
297     break;
298   }
299 
300   return CEED_ERROR_SUCCESS;
301 }
302 
303 //------------------------------------------------------------------------------
304 // Core logic for GetData.
305 //   If a different memory type is most up to date, this will perform a copy
306 //------------------------------------------------------------------------------
307 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx,
308     const CeedMemType mem_type, void *data) {
309   int ierr;
310   Ceed ceed;
311   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
312   CeedQFunctionContext_Cuda *impl;
313   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
314 
315   // Sync data to requested mem_type
316   bool need_sync = false;
317   ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync);
318   CeedChkBackend(ierr);
319   if (need_sync) {
320     ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr);
321   }
322 
323   // Update pointer
324   switch (mem_type) {
325   case CEED_MEM_HOST:
326     *(void **)data = impl->h_data;
327     break;
328   case CEED_MEM_DEVICE:
329     *(void **)data = impl->d_data;
330     break;
331   }
332 
333   return CEED_ERROR_SUCCESS;
334 }
335 
336 //------------------------------------------------------------------------------
337 // Get read-only access to the data
338 //------------------------------------------------------------------------------
339 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx,
340     const CeedMemType mem_type, void *data) {
341   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
342 }
343 
344 //------------------------------------------------------------------------------
345 // Get read/write access to the data
346 //------------------------------------------------------------------------------
347 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx,
348     const CeedMemType mem_type, void *data) {
349   int ierr;
350   CeedQFunctionContext_Cuda *impl;
351   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
352 
353   ierr = CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
354   CeedChkBackend(ierr);
355 
356   // Mark only pointer for requested memory as valid
357   ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr);
358   switch (mem_type) {
359   case CEED_MEM_HOST:
360     impl->h_data = *(void **)data;
361     break;
362   case CEED_MEM_DEVICE:
363     impl->d_data = *(void **)data;
364     break;
365   }
366 
367   return CEED_ERROR_SUCCESS;
368 }
369 
370 //------------------------------------------------------------------------------
371 // Destroy the user context
372 //------------------------------------------------------------------------------
373 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
374   int ierr;
375   Ceed ceed;
376   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
377   CeedQFunctionContext_Cuda *impl;
378   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
379 
380   ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr);
381   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
382   ierr = CeedFree(&impl); CeedChkBackend(ierr);
383 
384   return CEED_ERROR_SUCCESS;
385 }
386 
387 //------------------------------------------------------------------------------
388 // QFunctionContext Create
389 //------------------------------------------------------------------------------
390 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
391   int ierr;
392   CeedQFunctionContext_Cuda *impl;
393   Ceed ceed;
394   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
395 
396   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData",
397                                 CeedQFunctionContextHasValidData_Cuda);
398   CeedChkBackend(ierr);
399   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx,
400                                 "HasBorrowedDataOfType",
401                                 CeedQFunctionContextHasBorrowedDataOfType_Cuda);
402   CeedChkBackend(ierr);
403   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData",
404                                 CeedQFunctionContextSetData_Cuda); CeedChkBackend(ierr);
405   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData",
406                                 CeedQFunctionContextTakeData_Cuda); CeedChkBackend(ierr);
407   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData",
408                                 CeedQFunctionContextGetData_Cuda); CeedChkBackend(ierr);
409   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead",
410                                 CeedQFunctionContextGetDataRead_Cuda); CeedChkBackend(ierr);
411   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy",
412                                 CeedQFunctionContextDestroy_Cuda); CeedChkBackend(ierr);
413 
414   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
415   ierr = CeedQFunctionContextSetBackendData(ctx, impl); CeedChkBackend(ierr);
416 
417   return CEED_ERROR_SUCCESS;
418 }
419 //------------------------------------------------------------------------------
420