xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision b5404d5da366e284b3ef54a63de743df457e0da0)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "../cuda/ceed-cuda-common.h"
15 #include "ceed-cuda-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
21   Ceed ceed;
22   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
23   CeedQFunctionContext_Cuda *impl;
24   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
25 
26   if (!impl->h_data) {
27     // LCOV_EXCL_START
28     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29     // LCOV_EXCL_STOP
30   }
31 
32   size_t ctxsize;
33   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
34 
35   if (impl->d_data_borrowed) {
36     impl->d_data = impl->d_data_borrowed;
37   } else if (impl->d_data_owned) {
38     impl->d_data = impl->d_data_owned;
39   } else {
40     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize));
41     impl->d_data = impl->d_data_owned;
42   }
43 
44   CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctxsize, cudaMemcpyHostToDevice));
45 
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 //------------------------------------------------------------------------------
50 // Sync device to host
51 //------------------------------------------------------------------------------
52 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
53   Ceed ceed;
54   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
55   CeedQFunctionContext_Cuda *impl;
56   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
57 
58   if (!impl->d_data) {
59     // LCOV_EXCL_START
60     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
61     // LCOV_EXCL_STOP
62   }
63 
64   size_t ctxsize;
65   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
66 
67   if (impl->h_data_borrowed) {
68     impl->h_data = impl->h_data_borrowed;
69   } else if (impl->h_data_owned) {
70     impl->h_data = impl->h_data_owned;
71   } else {
72     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
73     impl->h_data = impl->h_data_owned;
74   }
75 
76   CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctxsize, cudaMemcpyDeviceToHost));
77 
78   return CEED_ERROR_SUCCESS;
79 }
80 
81 //------------------------------------------------------------------------------
82 // Sync data of type
83 //------------------------------------------------------------------------------
84 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
85   switch (mem_type) {
86     case CEED_MEM_HOST:
87       return CeedQFunctionContextSyncD2H_Cuda(ctx);
88     case CEED_MEM_DEVICE:
89       return CeedQFunctionContextSyncH2D_Cuda(ctx);
90   }
91   return CEED_ERROR_UNSUPPORTED;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Set all pointers as invalid
96 //------------------------------------------------------------------------------
97 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
98   CeedQFunctionContext_Cuda *impl;
99   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
100 
101   impl->h_data = NULL;
102   impl->d_data = NULL;
103 
104   return CEED_ERROR_SUCCESS;
105 }
106 
107 //------------------------------------------------------------------------------
108 // Check if ctx has valid data
109 //------------------------------------------------------------------------------
110 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
111   CeedQFunctionContext_Cuda *impl;
112   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
113 
114   *has_valid_data = impl && (!!impl->h_data || !!impl->d_data);
115 
116   return CEED_ERROR_SUCCESS;
117 }
118 
119 //------------------------------------------------------------------------------
120 // Check if ctx has borrowed data
121 //------------------------------------------------------------------------------
122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
123                                                                  bool *has_borrowed_data_of_type) {
124   CeedQFunctionContext_Cuda *impl;
125   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
126 
127   switch (mem_type) {
128     case CEED_MEM_HOST:
129       *has_borrowed_data_of_type = !!impl->h_data_borrowed;
130       break;
131     case CEED_MEM_DEVICE:
132       *has_borrowed_data_of_type = !!impl->d_data_borrowed;
133       break;
134   }
135 
136   return CEED_ERROR_SUCCESS;
137 }
138 
139 //------------------------------------------------------------------------------
140 // Check if data of given type needs sync
141 //------------------------------------------------------------------------------
142 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
143   CeedQFunctionContext_Cuda *impl;
144   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
145 
146   bool has_valid_data = true;
147   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
148   switch (mem_type) {
149     case CEED_MEM_HOST:
150       *need_sync = has_valid_data && !impl->h_data;
151       break;
152     case CEED_MEM_DEVICE:
153       *need_sync = has_valid_data && !impl->d_data;
154       break;
155   }
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Set data from host
162 //------------------------------------------------------------------------------
163 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
164   CeedQFunctionContext_Cuda *impl;
165   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
166 
167   CeedCallBackend(CeedFree(&impl->h_data_owned));
168   switch (copy_mode) {
169     case CEED_COPY_VALUES: {
170       size_t ctxsize;
171       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
172       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
173       impl->h_data_borrowed = NULL;
174       impl->h_data          = impl->h_data_owned;
175       memcpy(impl->h_data, data, ctxsize);
176     } break;
177     case CEED_OWN_POINTER:
178       impl->h_data_owned    = data;
179       impl->h_data_borrowed = NULL;
180       impl->h_data          = data;
181       break;
182     case CEED_USE_POINTER:
183       impl->h_data_borrowed = data;
184       impl->h_data          = data;
185       break;
186   }
187 
188   return CEED_ERROR_SUCCESS;
189 }
190 
191 //------------------------------------------------------------------------------
192 // Set data from device
193 //------------------------------------------------------------------------------
194 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
195   Ceed ceed;
196   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
197   CeedQFunctionContext_Cuda *impl;
198   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
199 
200   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
201   impl->d_data_owned = NULL;
202   switch (copy_mode) {
203     case CEED_COPY_VALUES: {
204       size_t ctxsize;
205       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
206       CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize));
207       impl->d_data_borrowed = NULL;
208       impl->d_data          = impl->d_data_owned;
209       CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctxsize, cudaMemcpyDeviceToDevice));
210     } break;
211     case CEED_OWN_POINTER:
212       impl->d_data_owned    = data;
213       impl->d_data_borrowed = NULL;
214       impl->d_data          = data;
215       break;
216     case CEED_USE_POINTER:
217       impl->d_data_owned    = NULL;
218       impl->d_data_borrowed = data;
219       impl->d_data          = data;
220       break;
221   }
222 
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Set the data used by a user context,
228 //   freeing any previously allocated data if applicable
229 //------------------------------------------------------------------------------
230 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
231   Ceed ceed;
232   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
233 
234   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
235   switch (mem_type) {
236     case CEED_MEM_HOST:
237       return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
238     case CEED_MEM_DEVICE:
239       return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
240   }
241 
242   return CEED_ERROR_UNSUPPORTED;
243 }
244 
245 //------------------------------------------------------------------------------
246 // Take data
247 //------------------------------------------------------------------------------
248 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
249   Ceed ceed;
250   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
251   CeedQFunctionContext_Cuda *impl;
252   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
253 
254   // Sync data to requested mem_type
255   bool need_sync = false;
256   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
257   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
258 
259   // Update pointer
260   switch (mem_type) {
261     case CEED_MEM_HOST:
262       *(void **)data        = impl->h_data_borrowed;
263       impl->h_data_borrowed = NULL;
264       impl->h_data          = NULL;
265       break;
266     case CEED_MEM_DEVICE:
267       *(void **)data        = impl->d_data_borrowed;
268       impl->d_data_borrowed = NULL;
269       impl->d_data          = NULL;
270       break;
271   }
272 
273   return CEED_ERROR_SUCCESS;
274 }
275 
276 //------------------------------------------------------------------------------
277 // Core logic for GetData.
278 //   If a different memory type is most up to date, this will perform a copy
279 //------------------------------------------------------------------------------
280 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
281   Ceed ceed;
282   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
283   CeedQFunctionContext_Cuda *impl;
284   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
285 
286   // Sync data to requested mem_type
287   bool need_sync = false;
288   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
289   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
290 
291   // Update pointer
292   switch (mem_type) {
293     case CEED_MEM_HOST:
294       *(void **)data = impl->h_data;
295       break;
296     case CEED_MEM_DEVICE:
297       *(void **)data = impl->d_data;
298       break;
299   }
300 
301   return CEED_ERROR_SUCCESS;
302 }
303 
304 //------------------------------------------------------------------------------
305 // Get read-only access to the data
306 //------------------------------------------------------------------------------
307 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
308   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
309 }
310 
311 //------------------------------------------------------------------------------
312 // Get read/write access to the data
313 //------------------------------------------------------------------------------
314 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
315   CeedQFunctionContext_Cuda *impl;
316   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
317 
318   CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
319 
320   // Mark only pointer for requested memory as valid
321   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
322   switch (mem_type) {
323     case CEED_MEM_HOST:
324       impl->h_data = *(void **)data;
325       break;
326     case CEED_MEM_DEVICE:
327       impl->d_data = *(void **)data;
328       break;
329   }
330 
331   return CEED_ERROR_SUCCESS;
332 }
333 
334 //------------------------------------------------------------------------------
335 // Destroy the user context
336 //------------------------------------------------------------------------------
337 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
338   Ceed ceed;
339   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
340   CeedQFunctionContext_Cuda *impl;
341   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
342 
343   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
344   CeedCallBackend(CeedFree(&impl->h_data_owned));
345   CeedCallBackend(CeedFree(&impl));
346 
347   return CEED_ERROR_SUCCESS;
348 }
349 
350 //------------------------------------------------------------------------------
351 // QFunctionContext Create
352 //------------------------------------------------------------------------------
353 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
354   CeedQFunctionContext_Cuda *impl;
355   Ceed                       ceed;
356   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
357 
358   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
359   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
360   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
361   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
362   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
363   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
364   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
365 
366   CeedCallBackend(CeedCalloc(1, &impl));
367   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
368 
369   return CEED_ERROR_SUCCESS;
370 }
371 //------------------------------------------------------------------------------
372