1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10
11 #include <string>
12 #include <sycl/sycl.hpp>
13
14 #include "ceed-sycl-ref.hpp"
15
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx)19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20 Ceed ceed;
21 Ceed_Sycl *sycl_data;
22 size_t ctx_size;
23 CeedQFunctionContext_Sycl *impl;
24
25 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
26 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
27 CeedCallBackend(CeedGetData(ceed, &sycl_data));
28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
29
30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
31
32 if (impl->d_data_borrowed) {
33 impl->d_data = impl->d_data_borrowed;
34 } else if (impl->d_data_owned) {
35 impl->d_data = impl->d_data_owned;
36 } else {
37 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
38 impl->d_data = impl->d_data_owned;
39 }
40 std::vector<sycl::event> e;
41
42 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
43 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, e);
44 CeedCallSycl(ceed, copy_event.wait_and_throw());
45 CeedCallBackend(CeedDestroy(&ceed));
46 return CEED_ERROR_SUCCESS;
47 }
48
49 //------------------------------------------------------------------------------
50 // Sync device to host
51 //------------------------------------------------------------------------------
CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx)52 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
53 Ceed ceed;
54 Ceed_Sycl *sycl_data;
55 size_t ctx_size;
56 CeedQFunctionContext_Sycl *impl;
57
58 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
59 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
60 CeedCallBackend(CeedGetData(ceed, &sycl_data));
61 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
62
63 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
64
65 if (impl->h_data_borrowed) {
66 impl->h_data = impl->h_data_borrowed;
67 } else if (impl->h_data_owned) {
68 impl->h_data = impl->h_data_owned;
69 } else {
70 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
71 impl->h_data = impl->h_data_owned;
72 }
73
74 std::vector<sycl::event> e;
75
76 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
77 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, e);
78 CeedCallSycl(ceed, copy_event.wait_and_throw());
79 CeedCallBackend(CeedDestroy(&ceed));
80 return CEED_ERROR_SUCCESS;
81 }
82
83 //------------------------------------------------------------------------------
84 // Sync data of type
85 //------------------------------------------------------------------------------
CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx,CeedMemType mem_type)86 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
87 switch (mem_type) {
88 case CEED_MEM_HOST:
89 return CeedQFunctionContextSyncD2H_Sycl(ctx);
90 case CEED_MEM_DEVICE:
91 return CeedQFunctionContextSyncH2D_Sycl(ctx);
92 }
93 return CEED_ERROR_UNSUPPORTED;
94 }
95
96 //------------------------------------------------------------------------------
97 // Set all pointers as invalid
98 //------------------------------------------------------------------------------
CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx)99 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
100 CeedQFunctionContext_Sycl *impl;
101
102 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
103 impl->h_data = NULL;
104 impl->d_data = NULL;
105 return CEED_ERROR_SUCCESS;
106 }
107
108 //------------------------------------------------------------------------------
109 // Check if ctx has valid data
110 //------------------------------------------------------------------------------
CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx,bool * has_valid_data)111 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
112 CeedQFunctionContext_Sycl *impl;
113
114 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
115 *has_valid_data = impl && (impl->h_data || impl->d_data);
116 return CEED_ERROR_SUCCESS;
117 }
118
119 //------------------------------------------------------------------------------
120 // Check if ctx has borrowed data
121 //------------------------------------------------------------------------------
CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx,CeedMemType mem_type,bool * has_borrowed_data_of_type)122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
123 bool *has_borrowed_data_of_type) {
124 CeedQFunctionContext_Sycl *impl;
125
126 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
127 switch (mem_type) {
128 case CEED_MEM_HOST:
129 *has_borrowed_data_of_type = impl->h_data_borrowed;
130 break;
131 case CEED_MEM_DEVICE:
132 *has_borrowed_data_of_type = impl->d_data_borrowed;
133 break;
134 }
135 return CEED_ERROR_SUCCESS;
136 }
137
138 //------------------------------------------------------------------------------
139 // Check if data of given type needs sync
140 //------------------------------------------------------------------------------
CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx,CeedMemType mem_type,bool * need_sync)141 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
142 bool has_valid_data = true;
143 CeedQFunctionContext_Sycl *impl;
144
145 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
146 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
147 switch (mem_type) {
148 case CEED_MEM_HOST:
149 *need_sync = has_valid_data && !impl->h_data;
150 break;
151 case CEED_MEM_DEVICE:
152 *need_sync = has_valid_data && !impl->d_data;
153 break;
154 }
155 return CEED_ERROR_SUCCESS;
156 }
157
158 //------------------------------------------------------------------------------
159 // Set data from host
160 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)161 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
162 CeedQFunctionContext_Sycl *impl;
163
164 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
165 CeedCallBackend(CeedFree(&impl->h_data_owned));
166 switch (copy_mode) {
167 case CEED_COPY_VALUES:
168 size_t ctx_size;
169
170 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
171 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
172 impl->h_data_borrowed = NULL;
173 impl->h_data = impl->h_data_owned;
174 memcpy(impl->h_data, data, ctx_size);
175 break;
176 case CEED_OWN_POINTER:
177 impl->h_data_owned = data;
178 impl->h_data_borrowed = NULL;
179 impl->h_data = data;
180 break;
181 case CEED_USE_POINTER:
182 impl->h_data_borrowed = data;
183 impl->h_data = data;
184 break;
185 }
186 return CEED_ERROR_SUCCESS;
187 }
188
189 //------------------------------------------------------------------------------
190 // Set data from device
191 //------------------------------------------------------------------------------
CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx,const CeedCopyMode copy_mode,void * data)192 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
193 Ceed ceed;
194 Ceed_Sycl *sycl_data;
195 CeedQFunctionContext_Sycl *impl;
196
197 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
198 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
199 CeedCallBackend(CeedGetData(ceed, &sycl_data));
200
201 std::vector<sycl::event> e;
202
203 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()};
204
205 // Wait for all work to finish before freeing memory
206 if (impl->d_data_owned) {
207 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
208 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
209 impl->d_data_owned = NULL;
210 }
211
212 switch (copy_mode) {
213 case CEED_COPY_VALUES: {
214 size_t ctx_size;
215
216 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
217 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
218 impl->d_data_borrowed = NULL;
219 impl->d_data = impl->d_data_owned;
220 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, e);
221 CeedCallSycl(ceed, copy_event.wait_and_throw());
222 } break;
223 case CEED_OWN_POINTER: {
224 impl->d_data_owned = data;
225 impl->d_data_borrowed = NULL;
226 impl->d_data = data;
227 } break;
228 case CEED_USE_POINTER: {
229 impl->d_data_owned = NULL;
230 impl->d_data_borrowed = data;
231 impl->d_data = data;
232 } break;
233 }
234 CeedCallBackend(CeedDestroy(&ceed));
235 return CEED_ERROR_SUCCESS;
236 }
237
238 //------------------------------------------------------------------------------
239 // Set the data used by a user context,
240 // freeing any previously allocated data if applicable
241 //------------------------------------------------------------------------------
CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx,const CeedMemType mem_type,const CeedCopyMode copy_mode,void * data)242 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
243 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
244 switch (mem_type) {
245 case CEED_MEM_HOST:
246 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
247 case CEED_MEM_DEVICE:
248 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
249 }
250 return CEED_ERROR_UNSUPPORTED;
251 }
252
253 //------------------------------------------------------------------------------
254 // Take data
255 //------------------------------------------------------------------------------
CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)256 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
257 Ceed ceed;
258 Ceed_Sycl *ceedSycl;
259 bool need_sync = false;
260 CeedQFunctionContext_Sycl *impl;
261
262 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
263 CeedCallBackend(CeedGetData(ceed, &ceedSycl));
264 CeedCallBackend(CeedDestroy(&ceed));
265 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
266
267 // Order queue if needed
268 if (!ceedSycl->sycl_queue.is_in_order()) ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
269
270 // Sync data to requested mem_type
271 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
272 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
273
274 // Update pointer
275 switch (mem_type) {
276 case CEED_MEM_HOST:
277 *(void **)data = impl->h_data_borrowed;
278 impl->h_data_borrowed = NULL;
279 impl->h_data = NULL;
280 break;
281 case CEED_MEM_DEVICE:
282 *(void **)data = impl->d_data_borrowed;
283 impl->d_data_borrowed = NULL;
284 impl->d_data = NULL;
285 break;
286 }
287 return CEED_ERROR_SUCCESS;
288 }
289
290 //------------------------------------------------------------------------------
291 // Core logic for GetData.
292 // If a different memory type is most up to date, this will perform a copy
293 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)294 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
295 bool need_sync = false;
296 CeedQFunctionContext_Sycl *impl;
297
298 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
299
300 // Sync data to requested mem_type
301 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
302 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
303
304 // Update pointer
305 switch (mem_type) {
306 case CEED_MEM_HOST:
307 *(void **)data = impl->h_data;
308 break;
309 case CEED_MEM_DEVICE:
310 *(void **)data = impl->d_data;
311 break;
312 }
313 return CEED_ERROR_SUCCESS;
314 }
315
316 //------------------------------------------------------------------------------
317 // Get read-only access to the data
318 //------------------------------------------------------------------------------
CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)319 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
320 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
321 }
322
323 //------------------------------------------------------------------------------
324 // Get read/write access to the data
325 //------------------------------------------------------------------------------
CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx,const CeedMemType mem_type,void * data)326 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
327 CeedQFunctionContext_Sycl *impl;
328
329 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
330 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
331
332 // Mark only pointer for requested memory as valid
333 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
334 switch (mem_type) {
335 case CEED_MEM_HOST:
336 impl->h_data = *(void **)data;
337 break;
338 case CEED_MEM_DEVICE:
339 impl->d_data = *(void **)data;
340 break;
341 }
342 return CEED_ERROR_SUCCESS;
343 }
344
345 //------------------------------------------------------------------------------
346 // Destroy the user context
347 //------------------------------------------------------------------------------
CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx)348 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
349 Ceed ceed;
350 Ceed_Sycl *sycl_data;
351 CeedQFunctionContext_Sycl *impl;
352
353 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
354 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
355 CeedCallBackend(CeedGetData(ceed, &sycl_data));
356
357 // Wait for all work to finish before freeing memory
358 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
359 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
360 CeedCallBackend(CeedDestroy(&ceed));
361 CeedCallBackend(CeedFree(&impl->h_data_owned));
362 CeedCallBackend(CeedFree(&impl));
363 return CEED_ERROR_SUCCESS;
364 }
365
366 //------------------------------------------------------------------------------
367 // QFunctionContext Create
368 //------------------------------------------------------------------------------
CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx)369 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
370 Ceed ceed;
371 CeedQFunctionContext_Sycl *impl;
372
373 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
374 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
375 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
376 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
377 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
378 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
379 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
380 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
381 CeedCallBackend(CeedDestroy(&ceed));
382 CeedCallBackend(CeedCalloc(1, &impl));
383 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
384 return CEED_ERROR_SUCCESS;
385 }
386
387 //------------------------------------------------------------------------------
388