xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-qfunctioncontext.c (revision f5d1e50421556545666f89e18ad21fef6dcea5ba)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <stdbool.h>
12 #include <string.h>
13 
14 #include "../cuda/ceed-cuda-common.h"
15 #include "ceed-cuda-ref.h"
16 
17 //------------------------------------------------------------------------------
18 // Sync host to device
19 //------------------------------------------------------------------------------
20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) {
21   Ceed                       ceed;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Cuda *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
26   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
27 
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31   if (impl->d_data_borrowed) {
32     impl->d_data = impl->d_data_borrowed;
33   } else if (impl->d_data_owned) {
34     impl->d_data = impl->d_data_owned;
35   } else {
36     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
37     impl->d_data = impl->d_data_owned;
38   }
39   CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice));
40   return CEED_ERROR_SUCCESS;
41 }
42 
43 //------------------------------------------------------------------------------
44 // Sync device to host
45 //------------------------------------------------------------------------------
46 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) {
47   Ceed                       ceed;
48   size_t                     ctx_size;
49   CeedQFunctionContext_Cuda *impl;
50 
51   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
52   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
53 
54   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
55 
56   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
57 
58   if (impl->h_data_borrowed) {
59     impl->h_data = impl->h_data_borrowed;
60   } else if (impl->h_data_owned) {
61     impl->h_data = impl->h_data_owned;
62   } else {
63     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
64     impl->h_data = impl->h_data_owned;
65   }
66   CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost));
67   return CEED_ERROR_SUCCESS;
68 }
69 
70 //------------------------------------------------------------------------------
71 // Sync data of type
72 //------------------------------------------------------------------------------
73 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) {
74   switch (mem_type) {
75     case CEED_MEM_HOST:
76       return CeedQFunctionContextSyncD2H_Cuda(ctx);
77     case CEED_MEM_DEVICE:
78       return CeedQFunctionContextSyncH2D_Cuda(ctx);
79   }
80   return CEED_ERROR_UNSUPPORTED;
81 }
82 
83 //------------------------------------------------------------------------------
84 // Set all pointers as invalid
85 //------------------------------------------------------------------------------
86 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) {
87   CeedQFunctionContext_Cuda *impl;
88 
89   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
90   impl->h_data = NULL;
91   impl->d_data = NULL;
92   return CEED_ERROR_SUCCESS;
93 }
94 
95 //------------------------------------------------------------------------------
96 // Check if ctx has valid data
97 //------------------------------------------------------------------------------
98 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) {
99   CeedQFunctionContext_Cuda *impl;
100 
101   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
102   *has_valid_data = impl && (impl->h_data || impl->d_data);
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 //------------------------------------------------------------------------------
107 // Check if ctx has borrowed data
108 //------------------------------------------------------------------------------
109 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type,
110                                                                  bool *has_borrowed_data_of_type) {
111   CeedQFunctionContext_Cuda *impl;
112 
113   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
114   switch (mem_type) {
115     case CEED_MEM_HOST:
116       *has_borrowed_data_of_type = impl->h_data_borrowed;
117       break;
118     case CEED_MEM_DEVICE:
119       *has_borrowed_data_of_type = impl->d_data_borrowed;
120       break;
121   }
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if data of given type needs sync
127 //------------------------------------------------------------------------------
128 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
129   bool                       has_valid_data = true;
130   CeedQFunctionContext_Cuda *impl;
131 
132   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
133   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
134   switch (mem_type) {
135     case CEED_MEM_HOST:
136       *need_sync = has_valid_data && !impl->h_data;
137       break;
138     case CEED_MEM_DEVICE:
139       *need_sync = has_valid_data && !impl->d_data;
140       break;
141   }
142   return CEED_ERROR_SUCCESS;
143 }
144 
145 //------------------------------------------------------------------------------
146 // Set data from host
147 //------------------------------------------------------------------------------
148 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
149   CeedQFunctionContext_Cuda *impl;
150 
151   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
152 
153   CeedCallBackend(CeedFree(&impl->h_data_owned));
154   switch (copy_mode) {
155     case CEED_COPY_VALUES: {
156       size_t ctx_size;
157       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
158       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
159       impl->h_data_borrowed = NULL;
160       impl->h_data          = impl->h_data_owned;
161       memcpy(impl->h_data, data, ctx_size);
162     } break;
163     case CEED_OWN_POINTER:
164       impl->h_data_owned    = data;
165       impl->h_data_borrowed = NULL;
166       impl->h_data          = data;
167       break;
168     case CEED_USE_POINTER:
169       impl->h_data_borrowed = data;
170       impl->h_data          = data;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set data from device
178 //------------------------------------------------------------------------------
179 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
180   Ceed                       ceed;
181   CeedQFunctionContext_Cuda *impl;
182 
183   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
184   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
185 
186   CeedCallCuda(ceed, cudaFree(impl->d_data_owned));
187   impl->d_data_owned = NULL;
188   switch (copy_mode) {
189     case CEED_COPY_VALUES: {
190       size_t ctx_size;
191       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
192       CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size));
193       impl->d_data_borrowed = NULL;
194       impl->d_data          = impl->d_data_owned;
195       CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice));
196     } break;
197     case CEED_OWN_POINTER:
198       impl->d_data_owned    = data;
199       impl->d_data_borrowed = NULL;
200       impl->d_data          = data;
201       break;
202     case CEED_USE_POINTER:
203       impl->d_data_owned    = NULL;
204       impl->d_data_borrowed = data;
205       impl->d_data          = data;
206       break;
207   }
208   return CEED_ERROR_SUCCESS;
209 }
210 
211 //------------------------------------------------------------------------------
212 // Set the data used by a user context,
213 //   freeing any previously allocated data if applicable
214 //------------------------------------------------------------------------------
215 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
216   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
217   switch (mem_type) {
218     case CEED_MEM_HOST:
219       return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data);
220     case CEED_MEM_DEVICE:
221       return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data);
222   }
223   return CEED_ERROR_UNSUPPORTED;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Take data
228 //------------------------------------------------------------------------------
229 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
230   CeedQFunctionContext_Cuda *impl;
231 
232   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
233 
234   // Sync data to requested mem_type
235   bool need_sync = false;
236   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
237   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
238 
239   // Update pointer
240   switch (mem_type) {
241     case CEED_MEM_HOST:
242       *(void **)data        = impl->h_data_borrowed;
243       impl->h_data_borrowed = NULL;
244       impl->h_data          = NULL;
245       break;
246     case CEED_MEM_DEVICE:
247       *(void **)data        = impl->d_data_borrowed;
248       impl->d_data_borrowed = NULL;
249       impl->d_data          = NULL;
250       break;
251   }
252   return CEED_ERROR_SUCCESS;
253 }
254 
255 //------------------------------------------------------------------------------
256 // Core logic for GetData.
257 //   If a different memory type is most up to date, this will perform a copy
258 //------------------------------------------------------------------------------
259 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
260   bool                       need_sync = false;
261   CeedQFunctionContext_Cuda *impl;
262 
263   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
264 
265   // Sync data to requested mem_type
266   CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync));
267   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type));
268 
269   // Update pointer
270   switch (mem_type) {
271     case CEED_MEM_HOST:
272       *(void **)data = impl->h_data;
273       break;
274     case CEED_MEM_DEVICE:
275       *(void **)data = impl->d_data;
276       break;
277   }
278   return CEED_ERROR_SUCCESS;
279 }
280 
281 //------------------------------------------------------------------------------
282 // Get read-only access to the data
283 //------------------------------------------------------------------------------
284 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
285   return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data);
286 }
287 
288 //------------------------------------------------------------------------------
289 // Get read/write access to the data
290 //------------------------------------------------------------------------------
291 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
292   CeedQFunctionContext_Cuda *impl;
293 
294   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
295   CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data));
296 
297   // Mark only pointer for requested memory as valid
298   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx));
299   switch (mem_type) {
300     case CEED_MEM_HOST:
301       impl->h_data = *(void **)data;
302       break;
303     case CEED_MEM_DEVICE:
304       impl->d_data = *(void **)data;
305       break;
306   }
307   return CEED_ERROR_SUCCESS;
308 }
309 
310 //------------------------------------------------------------------------------
311 // Destroy the user context
312 //------------------------------------------------------------------------------
313 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) {
314   CeedQFunctionContext_Cuda *impl;
315 
316   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
317   CeedCallCuda(CeedQFunctionContextReturnCeed(ctx), cudaFree(impl->d_data_owned));
318   CeedCallBackend(CeedFree(&impl->h_data_owned));
319   CeedCallBackend(CeedFree(&impl));
320   return CEED_ERROR_SUCCESS;
321 }
322 
323 //------------------------------------------------------------------------------
324 // QFunctionContext Create
325 //------------------------------------------------------------------------------
326 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) {
327   CeedQFunctionContext_Cuda *impl;
328   Ceed                       ceed;
329 
330   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
331   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda));
332   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda));
333   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda));
334   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda));
335   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda));
336   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda));
337   CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda));
338   CeedCallBackend(CeedCalloc(1, &impl));
339   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
340   return CEED_ERROR_SUCCESS;
341 }
342 
343 //------------------------------------------------------------------------------
344