xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunctioncontext.sycl.cpp (revision b3de639df132887ed4e21d8d409407a2c1d2e41c)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 
11 #include <string>
12 #include <sycl/sycl.hpp>
13 
14 #include "ceed-sycl-ref.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20   Ceed                       ceed;
21   Ceed_Sycl                 *sycl_data;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Sycl *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
26   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
27   CeedCallBackend(CeedGetData(ceed, &sycl_data));
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31 
32   if (impl->d_data_borrowed) {
33     impl->d_data = impl->d_data_borrowed;
34   } else if (impl->d_data_owned) {
35     impl->d_data = impl->d_data_owned;
36   } else {
37     CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
38     impl->d_data = impl->d_data_owned;
39   }
40   // Order queue
41   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
42   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, {e});
43   CeedCallSycl(ceed, copy_event.wait_and_throw());
44   return CEED_ERROR_SUCCESS;
45 }
46 
47 //------------------------------------------------------------------------------
48 // Sync device to host
49 //------------------------------------------------------------------------------
50 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
51   Ceed                       ceed;
52   Ceed_Sycl                 *sycl_data;
53   size_t                     ctx_size;
54   CeedQFunctionContext_Sycl *impl;
55 
56   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
57   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
58   CeedCallBackend(CeedGetData(ceed, &sycl_data));
59   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
60 
61   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
62 
63   if (impl->h_data_borrowed) {
64     impl->h_data = impl->h_data_borrowed;
65   } else if (impl->h_data_owned) {
66     impl->h_data = impl->h_data_owned;
67   } else {
68     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
69     impl->h_data = impl->h_data_owned;
70   }
71 
72   // Order queue
73   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
74   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, {e});
75   CeedCallSycl(ceed, copy_event.wait_and_throw());
76   return CEED_ERROR_SUCCESS;
77 }
78 
79 //------------------------------------------------------------------------------
80 // Sync data of type
81 //------------------------------------------------------------------------------
82 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
83   switch (mem_type) {
84     case CEED_MEM_HOST:
85       return CeedQFunctionContextSyncD2H_Sycl(ctx);
86     case CEED_MEM_DEVICE:
87       return CeedQFunctionContextSyncH2D_Sycl(ctx);
88   }
89   return CEED_ERROR_UNSUPPORTED;
90 }
91 
92 //------------------------------------------------------------------------------
93 // Set all pointers as invalid
94 //------------------------------------------------------------------------------
95 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
96   CeedQFunctionContext_Sycl *impl;
97 
98   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
99   impl->h_data = NULL;
100   impl->d_data = NULL;
101   return CEED_ERROR_SUCCESS;
102 }
103 
104 //------------------------------------------------------------------------------
105 // Check if ctx has valid data
106 //------------------------------------------------------------------------------
107 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
108   CeedQFunctionContext_Sycl *impl;
109 
110   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
111   *has_valid_data = impl && (impl->h_data || impl->d_data);
112   return CEED_ERROR_SUCCESS;
113 }
114 
115 //------------------------------------------------------------------------------
116 // Check if ctx has borrowed data
117 //------------------------------------------------------------------------------
118 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
119                                                                  bool *has_borrowed_data_of_type) {
120   CeedQFunctionContext_Sycl *impl;
121 
122   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
123   switch (mem_type) {
124     case CEED_MEM_HOST:
125       *has_borrowed_data_of_type = impl->h_data_borrowed;
126       break;
127     case CEED_MEM_DEVICE:
128       *has_borrowed_data_of_type = impl->d_data_borrowed;
129       break;
130   }
131   return CEED_ERROR_SUCCESS;
132 }
133 
134 //------------------------------------------------------------------------------
135 // Check if data of given type needs sync
136 //------------------------------------------------------------------------------
137 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
138   bool                       has_valid_data = true;
139   CeedQFunctionContext_Sycl *impl;
140 
141   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
142   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *need_sync = has_valid_data && !impl->h_data;
146       break;
147     case CEED_MEM_DEVICE:
148       *need_sync = has_valid_data && !impl->d_data;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Set data from host
156 //------------------------------------------------------------------------------
157 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
158   CeedQFunctionContext_Sycl *impl;
159 
160   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
161   CeedCallBackend(CeedFree(&impl->h_data_owned));
162   switch (copy_mode) {
163     case CEED_COPY_VALUES:
164       size_t ctx_size;
165 
166       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
167       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
168       impl->h_data_borrowed = NULL;
169       impl->h_data          = impl->h_data_owned;
170       memcpy(impl->h_data, data, ctx_size);
171       break;
172     case CEED_OWN_POINTER:
173       impl->h_data_owned    = data;
174       impl->h_data_borrowed = NULL;
175       impl->h_data          = data;
176       break;
177     case CEED_USE_POINTER:
178       impl->h_data_borrowed = data;
179       impl->h_data          = data;
180       break;
181   }
182   return CEED_ERROR_SUCCESS;
183 }
184 
185 //------------------------------------------------------------------------------
186 // Set data from device
187 //------------------------------------------------------------------------------
188 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
189   Ceed                       ceed;
190   Ceed_Sycl                 *sycl_data;
191   CeedQFunctionContext_Sycl *impl;
192 
193   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
194   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
195   CeedCallBackend(CeedGetData(ceed, &sycl_data));
196 
197   // Order queue
198   sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
199 
200   // Wait for all work to finish before freeing memory
201   if (impl->d_data_owned) {
202     CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
203     CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
204     impl->d_data_owned = NULL;
205   }
206 
207   switch (copy_mode) {
208     case CEED_COPY_VALUES: {
209       size_t ctx_size;
210 
211       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
212       CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
213       impl->d_data_borrowed  = NULL;
214       impl->d_data           = impl->d_data_owned;
215       sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, {e});
216       CeedCallSycl(ceed, copy_event.wait_and_throw());
217     } break;
218     case CEED_OWN_POINTER: {
219       impl->d_data_owned    = data;
220       impl->d_data_borrowed = NULL;
221       impl->d_data          = data;
222     } break;
223     case CEED_USE_POINTER: {
224       impl->d_data_owned    = NULL;
225       impl->d_data_borrowed = data;
226       impl->d_data          = data;
227     } break;
228   }
229   return CEED_ERROR_SUCCESS;
230 }
231 
232 //------------------------------------------------------------------------------
233 // Set the data used by a user context,
234 //   freeing any previously allocated data if applicable
235 //------------------------------------------------------------------------------
236 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
237   Ceed ceed;
238 
239   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
240   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
241   switch (mem_type) {
242     case CEED_MEM_HOST:
243       return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
244     case CEED_MEM_DEVICE:
245       return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
246   }
247   return CEED_ERROR_UNSUPPORTED;
248 }
249 
250 //------------------------------------------------------------------------------
251 // Take data
252 //------------------------------------------------------------------------------
253 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
254   Ceed                       ceed;
255   Ceed_Sycl                 *ceedSycl;
256   bool                       need_sync = false;
257   CeedQFunctionContext_Sycl *impl;
258 
259   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
260   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
261   CeedCallBackend(CeedGetData(ceed, &ceedSycl));
262 
263   // Order queue
264   ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
265 
266   // Sync data to requested mem_type
267   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
268   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
269 
270   // Update pointer
271   switch (mem_type) {
272     case CEED_MEM_HOST:
273       *(void **)data        = impl->h_data_borrowed;
274       impl->h_data_borrowed = NULL;
275       impl->h_data          = NULL;
276       break;
277     case CEED_MEM_DEVICE:
278       *(void **)data        = impl->d_data_borrowed;
279       impl->d_data_borrowed = NULL;
280       impl->d_data          = NULL;
281       break;
282   }
283   return CEED_ERROR_SUCCESS;
284 }
285 
286 //------------------------------------------------------------------------------
287 // Core logic for GetData.
288 //   If a different memory type is most up to date, this will perform a copy
289 //------------------------------------------------------------------------------
290 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
291   Ceed                       ceed;
292   bool                       need_sync = false;
293   CeedQFunctionContext_Sycl *impl;
294 
295   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
296   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
297 
298   // Sync data to requested mem_type
299   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
300   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
301 
302   // Update pointer
303   switch (mem_type) {
304     case CEED_MEM_HOST:
305       *(void **)data = impl->h_data;
306       break;
307     case CEED_MEM_DEVICE:
308       *(void **)data = impl->d_data;
309       break;
310   }
311   return CEED_ERROR_SUCCESS;
312 }
313 
314 //------------------------------------------------------------------------------
315 // Get read-only access to the data
316 //------------------------------------------------------------------------------
317 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
318   return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
319 }
320 
321 //------------------------------------------------------------------------------
322 // Get read/write access to the data
323 //------------------------------------------------------------------------------
324 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
325   Ceed                       ceed;
326   CeedQFunctionContext_Sycl *impl;
327 
328   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
329   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
330   CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
331 
332   // Mark only pointer for requested memory as valid
333   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
334   switch (mem_type) {
335     case CEED_MEM_HOST:
336       impl->h_data = *(void **)data;
337       break;
338     case CEED_MEM_DEVICE:
339       impl->d_data = *(void **)data;
340       break;
341   }
342   return CEED_ERROR_SUCCESS;
343 }
344 
345 //------------------------------------------------------------------------------
346 // Destroy the user context
347 //------------------------------------------------------------------------------
348 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
349   Ceed                       ceed;
350   Ceed_Sycl                 *sycl_data;
351   CeedQFunctionContext_Sycl *impl;
352 
353   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
354   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
355   CeedCallBackend(CeedGetData(ceed, &sycl_data));
356 
357   // Wait for all work to finish before freeing memory
358   CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
359   CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
360   CeedCallBackend(CeedFree(&impl->h_data_owned));
361   CeedCallBackend(CeedFree(&impl));
362   return CEED_ERROR_SUCCESS;
363 }
364 
365 //------------------------------------------------------------------------------
366 // QFunctionContext Create
367 //------------------------------------------------------------------------------
368 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
369   Ceed                       ceed;
370   CeedQFunctionContext_Sycl *impl;
371 
372   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
373   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
374   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
375   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
376   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
377   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
378   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
379   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
380   CeedCallBackend(CeedCalloc(1, &impl));
381   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
382   return CEED_ERROR_SUCCESS;
383 }
384 
385 //------------------------------------------------------------------------------
386