xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision ca7355308a14805110a7d98dabf1f658d5ec16d5)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "../cuda/ceed-cuda-common.h"
15 #include "ceed-cuda-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
21   Ceed                       ceed;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Cuda *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice));
40   return CEED_ERROR_SUCCESS;
41 }
42 
43 //------------------------------------------------------------------------------
44 // Sync device to host
45 //------------------------------------------------------------------------------
46 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
47   Ceed                       ceed;
48   size_t                     ctx_size;
49   CeedQFunctionContext_Cuda *impl;
50 
51   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
52   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
53 
54   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
55 
56   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
57 
58   if (impl->h_data_borrowed) {
59     impl->h_data = impl->h_data_borrowed;
60   } else if (impl->h_data_owned) {
61     impl->h_data = impl->h_data_owned;
62   } else {
63     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
64     impl->h_data = impl->h_data_owned;
65   }
66   CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost));
67   return CEED_ERROR_SUCCESS;
68 }
69 
70 //------------------------------------------------------------------------------
71 // Sync data of type
72 //------------------------------------------------------------------------------
73 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
74   switch (mem_type) {
75     case CEED_MEM_HOST:
76       return CeedQFunctionContextSyncD2H_Cuda(ctx);
77     case CEED_MEM_DEVICE:
78       return CeedQFunctionContextSyncH2D_Cuda(ctx);
79   }
80   return CEED_ERROR_UNSUPPORTED;
81 }
82 
83 //------------------------------------------------------------------------------
84 // Set all pointers as invalid
85 //------------------------------------------------------------------------------
86 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
87   CeedQFunctionContext_Cuda *impl;
88 
89   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
90   impl->h_data = NULL;
91   impl->d_data = NULL;
92   return CEED_ERROR_SUCCESS;
93 }
94 
95 //------------------------------------------------------------------------------
96 // Check if ctx has valid data
97 //------------------------------------------------------------------------------
98 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
99   CeedQFunctionContext_Cuda *impl;
100 
101   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
102   *has_valid_data = impl && (impl->h_data || impl->d_data);
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 //------------------------------------------------------------------------------
107 // Check if ctx has borrowed data
108 //------------------------------------------------------------------------------
109 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
110                                                                  bool *has_borrowed_data_of_type) {
111   CeedQFunctionContext_Cuda *impl;
112 
113   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
114   switch (mem_type) {
115     case CEED_MEM_HOST:
116       *has_borrowed_data_of_type = impl->h_data_borrowed;
117       break;
118     case CEED_MEM_DEVICE:
119       *has_borrowed_data_of_type = impl->d_data_borrowed;
120       break;
121   }
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if data of given type needs sync
127 //------------------------------------------------------------------------------
128 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
129   bool                       has_valid_data = true;
130   CeedQFunctionContext_Cuda *impl;
131 
132   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
133   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
134   switch (mem_type) {
135     case CEED_MEM_HOST:
136       *need_sync = has_valid_data && !impl->h_data;
137       break;
138     case CEED_MEM_DEVICE:
139       *need_sync = has_valid_data && !impl->d_data;
140       break;
141   }
142   return CEED_ERROR_SUCCESS;
143 }
144 
145 //------------------------------------------------------------------------------
146 // Set data from host
147 //------------------------------------------------------------------------------
148 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
149   CeedQFunctionContext_Cuda *impl;
150 
151   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
152 
153   CeedCallBackend(CeedFree(&impl->h_data_owned));
154   switch (copy_mode) {
155     case CEED_COPY_VALUES: {
156       size_t ctx_size;
157       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
158       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
159       impl->h_data_borrowed = NULL;
160       impl->h_data          = impl->h_data_owned;
161       memcpy(impl->h_data, data, ctx_size);
162     } break;
163     case CEED_OWN_POINTER:
164       impl->h_data_owned    = data;
165       impl->h_data_borrowed = NULL;
166       impl->h_data          = data;
167       break;
168     case CEED_USE_POINTER:
169       impl->h_data_borrowed = data;
170       impl->h_data          = data;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set data from device
178 //------------------------------------------------------------------------------
179 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
180   Ceed                       ceed;
181   CeedQFunctionContext_Cuda *impl;
182 
183   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
184   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
185 
186   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
187   impl->d_data_owned = NULL;
188   switch (copy_mode) {
189     case CEED_COPY_VALUES: {
190       size_t ctx_size;
191       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
192       CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
193       impl->d_data_borrowed = NULL;
194       impl->d_data          = impl->d_data_owned;
195       CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice));
196     } break;
197     case CEED_OWN_POINTER:
198       impl->d_data_owned    = data;
199       impl->d_data_borrowed = NULL;
200       impl->d_data          = data;
201       break;
202     case CEED_USE_POINTER:
203       impl->d_data_owned    = NULL;
204       impl->d_data_borrowed = data;
205       impl->d_data          = data;
206       break;
207   }
208   return CEED_ERROR_SUCCESS;
209 }
210 
211 //------------------------------------------------------------------------------
212 // Set the data used by a user context,
213 //   freeing any previously allocated data if applicable
214 //------------------------------------------------------------------------------
215 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
216   Ceed ceed;
217 
218   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
219   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
220   switch (mem_type) {
221     case CEED_MEM_HOST:
222       return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
223     case CEED_MEM_DEVICE:
224       return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
225   }
226   return CEED_ERROR_UNSUPPORTED;
227 }
228 
229 //------------------------------------------------------------------------------
230 // Take data
231 //------------------------------------------------------------------------------
232 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
233   Ceed                       ceed;
234   CeedQFunctionContext_Cuda *impl;
235 
236   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
237   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
238 
239   // Sync data to requested mem_type
240   bool need_sync = false;
241   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
242   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
243 
244   // Update pointer
245   switch (mem_type) {
246     case CEED_MEM_HOST:
247       *(void **)data        = impl->h_data_borrowed;
248       impl->h_data_borrowed = NULL;
249       impl->h_data          = NULL;
250       break;
251     case CEED_MEM_DEVICE:
252       *(void **)data        = impl->d_data_borrowed;
253       impl->d_data_borrowed = NULL;
254       impl->d_data          = NULL;
255       break;
256   }
257   return CEED_ERROR_SUCCESS;
258 }
259 
260 //------------------------------------------------------------------------------
261 // Core logic for GetData.
262 //   If a different memory type is most up to date, this will perform a copy
263 //------------------------------------------------------------------------------
264 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
265   Ceed                       ceed;
266   bool                       need_sync = false;
267   CeedQFunctionContext_Cuda *impl;
268 
269   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
270   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
271 
272   // Sync data to requested mem_type
273   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
274   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
275 
276   // Update pointer
277   switch (mem_type) {
278     case CEED_MEM_HOST:
279       *(void **)data = impl->h_data;
280       break;
281     case CEED_MEM_DEVICE:
282       *(void **)data = impl->d_data;
283       break;
284   }
285   return CEED_ERROR_SUCCESS;
286 }
287 
288 //------------------------------------------------------------------------------
289 // Get read-only access to the data
290 //------------------------------------------------------------------------------
291 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
292   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
293 }
294 
295 //------------------------------------------------------------------------------
296 // Get read/write access to the data
297 //------------------------------------------------------------------------------
298 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
299   CeedQFunctionContext_Cuda *impl;
300 
301   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
302   CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
303 
304   // Mark only pointer for requested memory as valid
305   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
306   switch (mem_type) {
307     case CEED_MEM_HOST:
308       impl->h_data = *(void **)data;
309       break;
310     case CEED_MEM_DEVICE:
311       impl->d_data = *(void **)data;
312       break;
313   }
314   return CEED_ERROR_SUCCESS;
315 }
316 
317 //------------------------------------------------------------------------------
318 // Destroy the user context
319 //------------------------------------------------------------------------------
320 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
321   Ceed                       ceed;
322   CeedQFunctionContext_Cuda *impl;
323 
324   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
325   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
326   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
327   CeedCallBackend(CeedFree(&impl->h_data_owned));
328   CeedCallBackend(CeedFree(&impl));
329   return CEED_ERROR_SUCCESS;
330 }
331 
332 //------------------------------------------------------------------------------
333 // QFunctionContext Create
334 //------------------------------------------------------------------------------
335 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
336   CeedQFunctionContext_Cuda *impl;
337   Ceed                       ceed;
338 
339   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
340   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
341   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
342   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
343   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
344   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
345   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
346   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
347   CeedCallBackend(CeedCalloc(1, &impl));
348   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 //------------------------------------------------------------------------------
353