xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunctioncontext.sycl.cpp (revision 4b3e95d5da7d4d1a3e8f0acb9d6f6cc36918381b)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 
11 #include <string>
12 #include <sycl/sycl.hpp>
13 
14 #include "ceed-sycl-ref.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20   Ceed                       ceed;
21   Ceed_Sycl                 *sycl_data;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Sycl *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
26   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
27   CeedCallBackend(CeedGetData(ceed, &sycl_data));
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31 
32   if (impl->d_data_borrowed) {
33     impl->d_data = impl->d_data_borrowed;
34   } else if (impl->d_data_owned) {
35     impl->d_data = impl->d_data_owned;
36   } else {
37     CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
38     impl->d_data = impl->d_data_owned;
39   }
40   std::vector<sycl::event> e;
41 
42   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
43   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, e);
44   CeedCallSycl(ceed, copy_event.wait_and_throw());
45   return CEED_ERROR_SUCCESS;
46 }
47 
48 //------------------------------------------------------------------------------
49 // Sync device to host
50 //------------------------------------------------------------------------------
51 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
52   Ceed                       ceed;
53   Ceed_Sycl                 *sycl_data;
54   size_t                     ctx_size;
55   CeedQFunctionContext_Sycl *impl;
56 
57   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
58   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
59   CeedCallBackend(CeedGetData(ceed, &sycl_data));
60   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
61 
62   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
63 
64   if (impl->h_data_borrowed) {
65     impl->h_data = impl->h_data_borrowed;
66   } else if (impl->h_data_owned) {
67     impl->h_data = impl->h_data_owned;
68   } else {
69     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
70     impl->h_data = impl->h_data_owned;
71   }
72 
73   std::vector<sycl::event> e;
74 
75   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
76   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, e);
77   CeedCallSycl(ceed, copy_event.wait_and_throw());
78   return CEED_ERROR_SUCCESS;
79 }
80 
81 //------------------------------------------------------------------------------
82 // Sync data of type
83 //------------------------------------------------------------------------------
84 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
85   switch (mem_type) {
86     case CEED_MEM_HOST:
87       return CeedQFunctionContextSyncD2H_Sycl(ctx);
88     case CEED_MEM_DEVICE:
89       return CeedQFunctionContextSyncH2D_Sycl(ctx);
90   }
91   return CEED_ERROR_UNSUPPORTED;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Set all pointers as invalid
96 //------------------------------------------------------------------------------
97 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
98   CeedQFunctionContext_Sycl *impl;
99 
100   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
101   impl->h_data = NULL;
102   impl->d_data = NULL;
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 //------------------------------------------------------------------------------
107 // Check if ctx has valid data
108 //------------------------------------------------------------------------------
109 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
110   CeedQFunctionContext_Sycl *impl;
111 
112   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
113   *has_valid_data = impl && (impl->h_data || impl->d_data);
114   return CEED_ERROR_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Check if ctx has borrowed data
119 //------------------------------------------------------------------------------
120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
121                                                                  bool *has_borrowed_data_of_type) {
122   CeedQFunctionContext_Sycl *impl;
123 
124   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
125   switch (mem_type) {
126     case CEED_MEM_HOST:
127       *has_borrowed_data_of_type = impl->h_data_borrowed;
128       break;
129     case CEED_MEM_DEVICE:
130       *has_borrowed_data_of_type = impl->d_data_borrowed;
131       break;
132   }
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if data of given type needs sync
138 //------------------------------------------------------------------------------
139 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
140   bool                       has_valid_data = true;
141   CeedQFunctionContext_Sycl *impl;
142 
143   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
144   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
145   switch (mem_type) {
146     case CEED_MEM_HOST:
147       *need_sync = has_valid_data && !impl->h_data;
148       break;
149     case CEED_MEM_DEVICE:
150       *need_sync = has_valid_data && !impl->d_data;
151       break;
152   }
153   return CEED_ERROR_SUCCESS;
154 }
155 
156 //------------------------------------------------------------------------------
157 // Set data from host
158 //------------------------------------------------------------------------------
159 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
160   CeedQFunctionContext_Sycl *impl;
161 
162   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
163   CeedCallBackend(CeedFree(&impl->h_data_owned));
164   switch (copy_mode) {
165     case CEED_COPY_VALUES:
166       size_t ctx_size;
167 
168       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
169       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
170       impl->h_data_borrowed = NULL;
171       impl->h_data          = impl->h_data_owned;
172       memcpy(impl->h_data, data, ctx_size);
173       break;
174     case CEED_OWN_POINTER:
175       impl->h_data_owned    = data;
176       impl->h_data_borrowed = NULL;
177       impl->h_data          = data;
178       break;
179     case CEED_USE_POINTER:
180       impl->h_data_borrowed = data;
181       impl->h_data          = data;
182       break;
183   }
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set data from device
189 //------------------------------------------------------------------------------
190 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
191   Ceed                       ceed;
192   Ceed_Sycl                 *sycl_data;
193   CeedQFunctionContext_Sycl *impl;
194 
195   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
196   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
197   CeedCallBackend(CeedGetData(ceed, &sycl_data));
198 
199   std::vector<sycl::event> e;
200 
201   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
202 
203   // Wait for all work to finish before freeing memory
204   if (impl->d_data_owned) {
205     CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
206     CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
207     impl->d_data_owned = NULL;
208   }
209 
210   switch (copy_mode) {
211     case CEED_COPY_VALUES: {
212       size_t ctx_size;
213 
214       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
215       CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
216       impl->d_data_borrowed  = NULL;
217       impl->d_data           = impl->d_data_owned;
218       sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, e);
219       CeedCallSycl(ceed, copy_event.wait_and_throw());
220     } break;
221     case CEED_OWN_POINTER: {
222       impl->d_data_owned    = data;
223       impl->d_data_borrowed = NULL;
224       impl->d_data          = data;
225     } break;
226     case CEED_USE_POINTER: {
227       impl->d_data_owned    = NULL;
228       impl->d_data_borrowed = data;
229       impl->d_data          = data;
230     } break;
231   }
232   return CEED_ERROR_SUCCESS;
233 }
234 
235 //------------------------------------------------------------------------------
236 // Set the data used by a user context,
237 //   freeing any previously allocated data if applicable
238 //------------------------------------------------------------------------------
239 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
240   Ceed ceed;
241 
242   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
243   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
244   switch (mem_type) {
245     case CEED_MEM_HOST:
246       return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
247     case CEED_MEM_DEVICE:
248       return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
249   }
250   return CEED_ERROR_UNSUPPORTED;
251 }
252 
253 //------------------------------------------------------------------------------
254 // Take data
255 //------------------------------------------------------------------------------
256 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
257   Ceed                       ceed;
258   Ceed_Sycl                 *ceedSycl;
259   bool                       need_sync = false;
260   CeedQFunctionContext_Sycl *impl;
261 
262   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
263   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
264   CeedCallBackend(CeedGetData(ceed, &ceedSycl));
265 
266   // Order queue if needed
267   if (!ceedSycl->sycl_queue.is_in_order()) ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
268 
269   // Sync data to requested mem_type
270   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
271   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
272 
273   // Update pointer
274   switch (mem_type) {
275     case CEED_MEM_HOST:
276       *(void **)data        = impl->h_data_borrowed;
277       impl->h_data_borrowed = NULL;
278       impl->h_data          = NULL;
279       break;
280     case CEED_MEM_DEVICE:
281       *(void **)data        = impl->d_data_borrowed;
282       impl->d_data_borrowed = NULL;
283       impl->d_data          = NULL;
284       break;
285   }
286   return CEED_ERROR_SUCCESS;
287 }
288 
289 //------------------------------------------------------------------------------
290 // Core logic for GetData.
291 //   If a different memory type is most up to date, this will perform a copy
292 //------------------------------------------------------------------------------
293 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
294   Ceed                       ceed;
295   bool                       need_sync = false;
296   CeedQFunctionContext_Sycl *impl;
297 
298   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
299   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
300 
301   // Sync data to requested mem_type
302   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
303   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
304 
305   // Update pointer
306   switch (mem_type) {
307     case CEED_MEM_HOST:
308       *(void **)data = impl->h_data;
309       break;
310     case CEED_MEM_DEVICE:
311       *(void **)data = impl->d_data;
312       break;
313   }
314   return CEED_ERROR_SUCCESS;
315 }
316 
317 //------------------------------------------------------------------------------
318 // Get read-only access to the data
319 //------------------------------------------------------------------------------
320 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
321   return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
322 }
323 
324 //------------------------------------------------------------------------------
325 // Get read/write access to the data
326 //------------------------------------------------------------------------------
327 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
328   Ceed                       ceed;
329   CeedQFunctionContext_Sycl *impl;
330 
331   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
332   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
333   CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
334 
335   // Mark only pointer for requested memory as valid
336   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
337   switch (mem_type) {
338     case CEED_MEM_HOST:
339       impl->h_data = *(void **)data;
340       break;
341     case CEED_MEM_DEVICE:
342       impl->d_data = *(void **)data;
343       break;
344   }
345   return CEED_ERROR_SUCCESS;
346 }
347 
348 //------------------------------------------------------------------------------
349 // Destroy the user context
350 //------------------------------------------------------------------------------
351 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
352   Ceed                       ceed;
353   Ceed_Sycl                 *sycl_data;
354   CeedQFunctionContext_Sycl *impl;
355 
356   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
357   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
358   CeedCallBackend(CeedGetData(ceed, &sycl_data));
359 
360   // Wait for all work to finish before freeing memory
361   CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
362   CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
363   CeedCallBackend(CeedFree(&impl->h_data_owned));
364   CeedCallBackend(CeedFree(&impl));
365   return CEED_ERROR_SUCCESS;
366 }
367 
368 //------------------------------------------------------------------------------
369 // QFunctionContext Create
370 //------------------------------------------------------------------------------
371 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
372   Ceed                       ceed;
373   CeedQFunctionContext_Sycl *impl;
374 
375   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
376   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
377   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
378   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
379   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
380   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
381   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
382   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
383   CeedCallBackend(CeedCalloc(1, &impl));
384   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
385   return CEED_ERROR_SUCCESS;
386 }
387 
388 //------------------------------------------------------------------------------
389