xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision 2459f3f1cd4d7d2e210e1c26d669bd2fde41a0b6)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <string.h>
12 #include "ceed-cuda-ref.h"
13 
14 //------------------------------------------------------------------------------
15 // * Bytes used
16 //------------------------------------------------------------------------------
17 static inline size_t bytes(const CeedQFunctionContext ctx) {
18   int ierr;
19   size_t ctxsize;
20   ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr);
21   return ctxsize;
22 }
23 
24 //------------------------------------------------------------------------------
25 // Sync host to device
26 //------------------------------------------------------------------------------
27 static inline int CeedQFunctionContextSyncH2D_Cuda(
28   const CeedQFunctionContext ctx) {
29   int ierr;
30   Ceed ceed;
31   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
32   CeedQFunctionContext_Cuda *impl;
33   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
34 
35   if (!impl->h_data)
36     // LCOV_EXCL_START
37     return CeedError(ceed, CEED_ERROR_BACKEND,
38                      "No valid host data to sync to device");
39   // LCOV_EXCL_STOP
40 
41   if (impl->d_data_borrowed) {
42     impl->d_data = impl->d_data_borrowed;
43   } else if (impl->d_data_owned) {
44     impl->d_data = impl->d_data_owned;
45   } else {
46     ierr = cudaMalloc((void **)&impl->d_data_owned, bytes(ctx));
47     CeedChk_Cu(ceed, ierr);
48     impl->d_data = impl->d_data_owned;
49   }
50 
51   ierr = cudaMemcpy(impl->d_data, impl->h_data, bytes(ctx),
52                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
53 
54   return CEED_ERROR_SUCCESS;
55 }
56 
57 //------------------------------------------------------------------------------
58 // Sync device to host
59 //------------------------------------------------------------------------------
60 static inline int CeedQFunctionContextSyncD2H_Cuda(
61   const CeedQFunctionContext ctx) {
62   int ierr;
63   Ceed ceed;
64   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
65   CeedQFunctionContext_Cuda *impl;
66   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
67 
68   if (!impl->d_data)
69     // LCOV_EXCL_START
70     return CeedError(ceed, CEED_ERROR_BACKEND,
71                      "No valid device data to sync to host");
72   // LCOV_EXCL_STOP
73 
74   if (impl->h_data_borrowed) {
75     impl->h_data = impl->h_data_borrowed;
76   } else if (impl->h_data_owned) {
77     impl->h_data = impl->h_data_owned;
78   } else {
79     ierr = CeedMalloc(bytes(ctx), &impl->h_data_owned);
80     CeedChkBackend(ierr);
81     impl->h_data = impl->h_data_owned;
82   }
83 
84   ierr = cudaMemcpy(impl->h_data, impl->d_data, bytes(ctx),
85                     cudaMemcpyDeviceToHost); CeedChk_Cu(ceed, ierr);
86 
87   return CEED_ERROR_SUCCESS;
88 }
89 
90 //------------------------------------------------------------------------------
91 // Sync data of type
92 //------------------------------------------------------------------------------
93 static inline int CeedQFunctionContextSync_Cuda(
94   const CeedQFunctionContext ctx, CeedMemType mem_type) {
95   switch (mem_type) {
96   case CEED_MEM_HOST: return CeedQFunctionContextSyncD2H_Cuda(ctx);
97   case CEED_MEM_DEVICE: return CeedQFunctionContextSyncH2D_Cuda(ctx);
98   }
99   return CEED_ERROR_UNSUPPORTED;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Set all pointers as invalid
104 //------------------------------------------------------------------------------
105 static inline int CeedQFunctionContextSetAllInvalid_Cuda(
106   const CeedQFunctionContext ctx) {
107   int ierr;
108   CeedQFunctionContext_Cuda *impl;
109   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
110 
111   impl->h_data = NULL;
112   impl->d_data = NULL;
113 
114   return CEED_ERROR_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Check if ctx has valid data
119 //------------------------------------------------------------------------------
120 static inline int CeedQFunctionContextHasValidData_Cuda(
121   const CeedQFunctionContext ctx, bool *has_valid_data) {
122   int ierr;
123   CeedQFunctionContext_Cuda *impl;
124   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
125 
126   *has_valid_data = !!impl->h_data || !!impl->d_data;
127 
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 //------------------------------------------------------------------------------
132 // Check if ctx has borrowed data
133 //------------------------------------------------------------------------------
134 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(
135   const CeedQFunctionContext ctx, CeedMemType mem_type,
136   bool *has_borrowed_data_of_type) {
137   int ierr;
138   CeedQFunctionContext_Cuda *impl;
139   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
140 
141   switch (mem_type) {
142   case CEED_MEM_HOST:
143     *has_borrowed_data_of_type = !!impl->h_data_borrowed;
144     break;
145   case CEED_MEM_DEVICE:
146     *has_borrowed_data_of_type = !!impl->d_data_borrowed;
147     break;
148   }
149 
150   return CEED_ERROR_SUCCESS;
151 }
152 
153 //------------------------------------------------------------------------------
154 // Check if data of given type needs sync
155 //------------------------------------------------------------------------------
156 static inline int CeedQFunctionContextNeedSync_Cuda(
157   const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
158   int ierr;
159   CeedQFunctionContext_Cuda *impl;
160   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
161 
162   bool has_valid_data = true;
163   ierr = CeedQFunctionContextHasValidData(ctx, &has_valid_data);
164   CeedChkBackend(ierr);
165   switch (mem_type) {
166   case CEED_MEM_HOST:
167     *need_sync = has_valid_data && !impl->h_data;
168     break;
169   case CEED_MEM_DEVICE:
170     *need_sync = has_valid_data && !impl->d_data;
171     break;
172   }
173 
174   return CEED_ERROR_SUCCESS;
175 }
176 
177 //------------------------------------------------------------------------------
178 // Set data from host
179 //------------------------------------------------------------------------------
180 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx,
181     const CeedCopyMode copy_mode, void *data) {
182   int ierr;
183   CeedQFunctionContext_Cuda *impl;
184   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
185 
186   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
187   switch (copy_mode) {
188   case CEED_COPY_VALUES: {
189     ierr = CeedMalloc(bytes(ctx), &impl->h_data_owned); CeedChkBackend(ierr);
190     impl->h_data_borrowed = NULL;
191     impl->h_data = impl->h_data_owned;
192     memcpy(impl->h_data, data, bytes(ctx));
193   } break;
194   case CEED_OWN_POINTER:
195     impl->h_data_owned = data;
196     impl->h_data_borrowed = NULL;
197     impl->h_data = data;
198     break;
199   case CEED_USE_POINTER:
200     impl->h_data_borrowed = data;
201     impl->h_data = data;
202     break;
203   }
204 
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 //------------------------------------------------------------------------------
209 // Set data from device
210 //------------------------------------------------------------------------------
211 static int CeedQFunctionContextSetDataDevice_Cuda(
212   const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
213   int ierr;
214   Ceed ceed;
215   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
216   CeedQFunctionContext_Cuda *impl;
217   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
218 
219   ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr);
220   impl->d_data_owned = NULL;
221   switch (copy_mode) {
222   case CEED_COPY_VALUES:
223     ierr = cudaMalloc((void **)&impl->d_data_owned, bytes(ctx));
224     CeedChk_Cu(ceed, ierr);
225     impl->d_data_borrowed = NULL;
226     impl->d_data = impl->d_data_owned;
227     ierr = cudaMemcpy(impl->d_data, data, bytes(ctx),
228                       cudaMemcpyDeviceToDevice); CeedChk_Cu(ceed, ierr);
229     break;
230   case CEED_OWN_POINTER:
231     impl->d_data_owned = data;
232     impl->d_data_borrowed = NULL;
233     impl->d_data = data;
234     break;
235   case CEED_USE_POINTER:
236     impl->d_data_owned = NULL;
237     impl->d_data_borrowed = data;
238     impl->d_data = data;
239     break;
240   }
241 
242   return CEED_ERROR_SUCCESS;
243 }
244 
245 //------------------------------------------------------------------------------
246 // Set the data used by a user context,
247 //   freeing any previously allocated data if applicable
248 //------------------------------------------------------------------------------
249 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx,
250     const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
251   int ierr;
252   Ceed ceed;
253   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
254 
255   ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr);
256   switch (mem_type) {
257   case CEED_MEM_HOST:
258     return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
259   case CEED_MEM_DEVICE:
260     return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
261   }
262 
263   return CEED_ERROR_UNSUPPORTED;
264 }
265 
266 //------------------------------------------------------------------------------
267 // Take data
268 //------------------------------------------------------------------------------
269 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx,
270     const CeedMemType mem_type, void *data) {
271   int ierr;
272   Ceed ceed;
273   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
274   CeedQFunctionContext_Cuda *impl;
275   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
276 
277   // Sync data to requested mem_type
278   bool need_sync = false;
279   ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync);
280   CeedChkBackend(ierr);
281   if (need_sync) {
282     ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr);
283   }
284 
285   // Update pointer
286   switch (mem_type) {
287   case CEED_MEM_HOST:
288     *(void **)data = impl->h_data_borrowed;
289     impl->h_data_borrowed = NULL;
290     impl->h_data = NULL;
291     break;
292   case CEED_MEM_DEVICE:
293     *(void **)data = impl->d_data_borrowed;
294     impl->d_data_borrowed = NULL;
295     impl->d_data = NULL;
296     break;
297   }
298 
299   return CEED_ERROR_SUCCESS;
300 }
301 
302 //------------------------------------------------------------------------------
303 // Core logic for GetData.
304 //   If a different memory type is most up to date, this will perform a copy
305 //------------------------------------------------------------------------------
306 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx,
307     const CeedMemType mem_type, void *data) {
308   int ierr;
309   Ceed ceed;
310   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
311   CeedQFunctionContext_Cuda *impl;
312   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
313 
314   // Sync data to requested mem_type
315   bool need_sync = false;
316   ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync);
317   CeedChkBackend(ierr);
318   if (need_sync) {
319     ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr);
320   }
321 
322   // Update pointer
323   switch (mem_type) {
324   case CEED_MEM_HOST:
325     *(void **)data = impl->h_data;
326     break;
327   case CEED_MEM_DEVICE:
328     *(void **)data = impl->d_data;
329     break;
330   }
331 
332   return CEED_ERROR_SUCCESS;
333 }
334 
335 //------------------------------------------------------------------------------
336 // Get read-only access to the data
337 //------------------------------------------------------------------------------
338 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx,
339     const CeedMemType mem_type, void *data) {
340   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
341 }
342 
343 //------------------------------------------------------------------------------
344 // Get read/write access to the data
345 //------------------------------------------------------------------------------
346 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx,
347     const CeedMemType mem_type, void *data) {
348   int ierr;
349   CeedQFunctionContext_Cuda *impl;
350   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
351 
352   ierr = CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
353   CeedChkBackend(ierr);
354 
355   // Mark only pointer for requested memory as valid
356   ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr);
357   switch (mem_type) {
358   case CEED_MEM_HOST:
359     impl->h_data = *(void **)data;
360     break;
361   case CEED_MEM_DEVICE:
362     impl->d_data = *(void **)data;
363     break;
364   }
365 
366   return CEED_ERROR_SUCCESS;
367 }
368 
369 //------------------------------------------------------------------------------
370 // Destroy the user context
371 //------------------------------------------------------------------------------
372 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
373   int ierr;
374   Ceed ceed;
375   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
376   CeedQFunctionContext_Cuda *impl;
377   ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr);
378 
379   ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr);
380   ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr);
381   ierr = CeedFree(&impl); CeedChkBackend(ierr);
382 
383   return CEED_ERROR_SUCCESS;
384 }
385 
386 //------------------------------------------------------------------------------
387 // QFunctionContext Create
388 //------------------------------------------------------------------------------
389 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
390   int ierr;
391   CeedQFunctionContext_Cuda *impl;
392   Ceed ceed;
393   ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr);
394 
395   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData",
396                                 CeedQFunctionContextHasValidData_Cuda);
397   CeedChkBackend(ierr);
398   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx,
399                                 "HasBorrowedDataOfType",
400                                 CeedQFunctionContextHasBorrowedDataOfType_Cuda);
401   CeedChkBackend(ierr);
402   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData",
403                                 CeedQFunctionContextSetData_Cuda); CeedChkBackend(ierr);
404   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData",
405                                 CeedQFunctionContextTakeData_Cuda); CeedChkBackend(ierr);
406   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData",
407                                 CeedQFunctionContextGetData_Cuda); CeedChkBackend(ierr);
408   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead",
409                                 CeedQFunctionContextGetDataRead_Cuda); CeedChkBackend(ierr);
410   ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy",
411                                 CeedQFunctionContextDestroy_Cuda); CeedChkBackend(ierr);
412 
413   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
414   ierr = CeedQFunctionContextSetBackendData(ctx, impl); CeedChkBackend(ierr);
415 
416   return CEED_ERROR_SUCCESS;
417 }
418 //------------------------------------------------------------------------------
419