xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref-qfunctioncontext.sycl.cpp (revision 9ba83ac0e4b1fca39d6fa6737a318a9f0cbc172d)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 
11 #include <string>
12 #include <sycl/sycl.hpp>
13 
14 #include "ceed-sycl-ref.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20   Ceed                       ceed;
21   Ceed_Sycl                 *sycl_data;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Sycl *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
26   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
27   CeedCallBackend(CeedGetData(ceed, &sycl_data));
28   CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29 
30   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31 
32   if (impl->d_data_borrowed) {
33     impl->d_data = impl->d_data_borrowed;
34   } else if (impl->d_data_owned) {
35     impl->d_data = impl->d_data_owned;
36   } else {
37     CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
38     impl->d_data = impl->d_data_owned;
39   }
40   std::vector<sycl::event> e;
41 
42   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
43   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, e);
44   CeedCallSycl(ceed, copy_event.wait_and_throw());
45   CeedCallBackend(CeedDestroy(&ceed));
46   return CEED_ERROR_SUCCESS;
47 }
48 
49 //------------------------------------------------------------------------------
50 // Sync device to host
51 //------------------------------------------------------------------------------
52 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
53   Ceed                       ceed;
54   Ceed_Sycl                 *sycl_data;
55   size_t                     ctx_size;
56   CeedQFunctionContext_Sycl *impl;
57 
58   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
59   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
60   CeedCallBackend(CeedGetData(ceed, &sycl_data));
61   CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
62 
63   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
64 
65   if (impl->h_data_borrowed) {
66     impl->h_data = impl->h_data_borrowed;
67   } else if (impl->h_data_owned) {
68     impl->h_data = impl->h_data_owned;
69   } else {
70     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
71     impl->h_data = impl->h_data_owned;
72   }
73 
74   std::vector<sycl::event> e;
75 
76   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
77   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, e);
78   CeedCallSycl(ceed, copy_event.wait_and_throw());
79   CeedCallBackend(CeedDestroy(&ceed));
80   return CEED_ERROR_SUCCESS;
81 }
82 
83 //------------------------------------------------------------------------------
84 // Sync data of type
85 //------------------------------------------------------------------------------
86 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
87   switch (mem_type) {
88     case CEED_MEM_HOST:
89       return CeedQFunctionContextSyncD2H_Sycl(ctx);
90     case CEED_MEM_DEVICE:
91       return CeedQFunctionContextSyncH2D_Sycl(ctx);
92   }
93   return CEED_ERROR_UNSUPPORTED;
94 }
95 
96 //------------------------------------------------------------------------------
97 // Set all pointers as invalid
98 //------------------------------------------------------------------------------
99 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
100   CeedQFunctionContext_Sycl *impl;
101 
102   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
103   impl->h_data = NULL;
104   impl->d_data = NULL;
105   return CEED_ERROR_SUCCESS;
106 }
107 
108 //------------------------------------------------------------------------------
109 // Check if ctx has valid data
110 //------------------------------------------------------------------------------
111 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
112   CeedQFunctionContext_Sycl *impl;
113 
114   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
115   *has_valid_data = impl && (impl->h_data || impl->d_data);
116   return CEED_ERROR_SUCCESS;
117 }
118 
119 //------------------------------------------------------------------------------
120 // Check if ctx has borrowed data
121 //------------------------------------------------------------------------------
122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
123                                                                  bool *has_borrowed_data_of_type) {
124   CeedQFunctionContext_Sycl *impl;
125 
126   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
127   switch (mem_type) {
128     case CEED_MEM_HOST:
129       *has_borrowed_data_of_type = impl->h_data_borrowed;
130       break;
131     case CEED_MEM_DEVICE:
132       *has_borrowed_data_of_type = impl->d_data_borrowed;
133       break;
134   }
135   return CEED_ERROR_SUCCESS;
136 }
137 
138 //------------------------------------------------------------------------------
139 // Check if data of given type needs sync
140 //------------------------------------------------------------------------------
141 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
142   bool                       has_valid_data = true;
143   CeedQFunctionContext_Sycl *impl;
144 
145   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
146   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *need_sync = has_valid_data && !impl->h_data;
150       break;
151     case CEED_MEM_DEVICE:
152       *need_sync = has_valid_data && !impl->d_data;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Set data from host
160 //------------------------------------------------------------------------------
161 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
162   CeedQFunctionContext_Sycl *impl;
163 
164   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
165   CeedCallBackend(CeedFree(&impl->h_data_owned));
166   switch (copy_mode) {
167     case CEED_COPY_VALUES:
168       size_t ctx_size;
169 
170       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
171       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
172       impl->h_data_borrowed = NULL;
173       impl->h_data          = impl->h_data_owned;
174       memcpy(impl->h_data, data, ctx_size);
175       break;
176     case CEED_OWN_POINTER:
177       impl->h_data_owned    = data;
178       impl->h_data_borrowed = NULL;
179       impl->h_data          = data;
180       break;
181     case CEED_USE_POINTER:
182       impl->h_data_borrowed = data;
183       impl->h_data          = data;
184       break;
185   }
186   return CEED_ERROR_SUCCESS;
187 }
188 
189 //------------------------------------------------------------------------------
190 // Set data from device
191 //------------------------------------------------------------------------------
192 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
193   Ceed                       ceed;
194   Ceed_Sycl                 *sycl_data;
195   CeedQFunctionContext_Sycl *impl;
196 
197   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
198   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
199   CeedCallBackend(CeedGetData(ceed, &sycl_data));
200 
201   std::vector<sycl::event> e;
202 
203   if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
204 
205   // Wait for all work to finish before freeing memory
206   if (impl->d_data_owned) {
207     CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
208     CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
209     impl->d_data_owned = NULL;
210   }
211 
212   switch (copy_mode) {
213     case CEED_COPY_VALUES: {
214       size_t ctx_size;
215 
216       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
217       CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
218       impl->d_data_borrowed  = NULL;
219       impl->d_data           = impl->d_data_owned;
220       sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, e);
221       CeedCallSycl(ceed, copy_event.wait_and_throw());
222     } break;
223     case CEED_OWN_POINTER: {
224       impl->d_data_owned    = data;
225       impl->d_data_borrowed = NULL;
226       impl->d_data          = data;
227     } break;
228     case CEED_USE_POINTER: {
229       impl->d_data_owned    = NULL;
230       impl->d_data_borrowed = data;
231       impl->d_data          = data;
232     } break;
233   }
234   CeedCallBackend(CeedDestroy(&ceed));
235   return CEED_ERROR_SUCCESS;
236 }
237 
238 //------------------------------------------------------------------------------
239 // Set the data used by a user context,
240 //   freeing any previously allocated data if applicable
241 //------------------------------------------------------------------------------
242 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
243   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
244   switch (mem_type) {
245     case CEED_MEM_HOST:
246       return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
247     case CEED_MEM_DEVICE:
248       return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
249   }
250   return CEED_ERROR_UNSUPPORTED;
251 }
252 
253 //------------------------------------------------------------------------------
254 // Take data
255 //------------------------------------------------------------------------------
256 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
257   Ceed                       ceed;
258   Ceed_Sycl                 *ceedSycl;
259   bool                       need_sync = false;
260   CeedQFunctionContext_Sycl *impl;
261 
262   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
263   CeedCallBackend(CeedGetData(ceed, &ceedSycl));
264   CeedCallBackend(CeedDestroy(&ceed));
265   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
266 
267   // Order queue if needed
268   if (!ceedSycl->sycl_queue.is_in_order()) ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
269 
270   // Sync data to requested mem_type
271   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
272   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
273 
274   // Update pointer
275   switch (mem_type) {
276     case CEED_MEM_HOST:
277       *(void **)data        = impl->h_data_borrowed;
278       impl->h_data_borrowed = NULL;
279       impl->h_data          = NULL;
280       break;
281     case CEED_MEM_DEVICE:
282       *(void **)data        = impl->d_data_borrowed;
283       impl->d_data_borrowed = NULL;
284       impl->d_data          = NULL;
285       break;
286   }
287   return CEED_ERROR_SUCCESS;
288 }
289 
290 //------------------------------------------------------------------------------
291 // Core logic for GetData.
292 //   If a different memory type is most up to date, this will perform a copy
293 //------------------------------------------------------------------------------
294 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
295   bool                       need_sync = false;
296   CeedQFunctionContext_Sycl *impl;
297 
298   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
299 
300   // Sync data to requested mem_type
301   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
302   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
303 
304   // Update pointer
305   switch (mem_type) {
306     case CEED_MEM_HOST:
307       *(void **)data = impl->h_data;
308       break;
309     case CEED_MEM_DEVICE:
310       *(void **)data = impl->d_data;
311       break;
312   }
313   return CEED_ERROR_SUCCESS;
314 }
315 
316 //------------------------------------------------------------------------------
317 // Get read-only access to the data
318 //------------------------------------------------------------------------------
319 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
320   return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
321 }
322 
323 //------------------------------------------------------------------------------
324 // Get read/write access to the data
325 //------------------------------------------------------------------------------
326 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
327   CeedQFunctionContext_Sycl *impl;
328 
329   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
330   CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
331 
332   // Mark only pointer for requested memory as valid
333   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
334   switch (mem_type) {
335     case CEED_MEM_HOST:
336       impl->h_data = *(void **)data;
337       break;
338     case CEED_MEM_DEVICE:
339       impl->d_data = *(void **)data;
340       break;
341   }
342   return CEED_ERROR_SUCCESS;
343 }
344 
345 //------------------------------------------------------------------------------
346 // Destroy the user context
347 //------------------------------------------------------------------------------
348 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
349   Ceed                       ceed;
350   Ceed_Sycl                 *sycl_data;
351   CeedQFunctionContext_Sycl *impl;
352 
353   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
354   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
355   CeedCallBackend(CeedGetData(ceed, &sycl_data));
356 
357   // Wait for all work to finish before freeing memory
358   CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
359   CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
360   CeedCallBackend(CeedDestroy(&ceed));
361   CeedCallBackend(CeedFree(&impl->h_data_owned));
362   CeedCallBackend(CeedFree(&impl));
363   return CEED_ERROR_SUCCESS;
364 }
365 
366 //------------------------------------------------------------------------------
367 // QFunctionContext Create
368 //------------------------------------------------------------------------------
369 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
370   Ceed                       ceed;
371   CeedQFunctionContext_Sycl *impl;
372 
373   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
374   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
375   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
376   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
377   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
378   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
379   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
380   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
381   CeedCallBackend(CeedDestroy(&ceed));
382   CeedCallBackend(CeedCalloc(1, &impl));
383   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
384   return CEED_ERROR_SUCCESS;
385 }
386 
387 //------------------------------------------------------------------------------
388