xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunctioncontext.sycl.cpp (revision b94338b90626f5663d73dd1094a9176150930a44)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 
11 #include <string>
12 #include <sycl/sycl.hpp>
13 
14 #include "ceed-sycl-ref.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20   CeedQFunctionContext_Sycl *impl;
21   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
22   Ceed ceed;
23   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
24   Ceed_Sycl *sycl_data;
25   CeedCallBackend(CeedGetData(ceed, &sycl_data));
26 
27   if (!impl->h_data) {
28     // LCOV_EXCL_START
29     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
30     // LCOV_EXCL_STOP
31   }
32 
33   size_t ctxsize;
34   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
35 
36   if (impl->d_data_borrowed) {
37     impl->d_data = impl->d_data_borrowed;
38   } else if (impl->d_data_owned) {
39     impl->d_data = impl->d_data_owned;
40   } else {
41     CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctxsize, sycl_data->sycl_device, sycl_data->sycl_context));
42     impl->d_data = impl->d_data_owned;
43   }
44   // Order queue
45   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
46   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctxsize, {e});
47   CeedCallSycl(ceed, copy_event.wait_and_throw());
48 
49   return CEED_ERROR_SUCCESS;
50 }
51 
52 //------------------------------------------------------------------------------
53 // Sync device to host
54 //------------------------------------------------------------------------------
55 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
56   CeedQFunctionContext_Sycl *impl;
57   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
58   Ceed ceed;
59   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
60   Ceed_Sycl *sycl_data;
61   CeedCallBackend(CeedGetData(ceed, &sycl_data));
62 
63   if (!impl->d_data) {
64     // LCOV_EXCL_START
65     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
66     // LCOV_EXCL_STOP
67   }
68 
69   size_t ctxsize;
70   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
71 
72   if (impl->h_data_borrowed) {
73     impl->h_data = impl->h_data_borrowed;
74   } else if (impl->h_data_owned) {
75     impl->h_data = impl->h_data_owned;
76   } else {
77     CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
78     impl->h_data = impl->h_data_owned;
79   }
80 
81   // Order queue
82   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
83   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctxsize, {e});
84   CeedCallSycl(ceed, copy_event.wait_and_throw());
85 
86   return CEED_ERROR_SUCCESS;
87 }
88 
89 //------------------------------------------------------------------------------
90 // Sync data of type
91 //------------------------------------------------------------------------------
92 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
93   switch (mem_type) {
94     case CEED_MEM_HOST:
95       return CeedQFunctionContextSyncD2H_Sycl(ctx);
96     case CEED_MEM_DEVICE:
97       return CeedQFunctionContextSyncH2D_Sycl(ctx);
98   }
99   return CEED_ERROR_UNSUPPORTED;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Set all pointers as invalid
104 //------------------------------------------------------------------------------
105 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
106   CeedQFunctionContext_Sycl *impl;
107   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
108 
109   impl->h_data = NULL;
110   impl->d_data = NULL;
111 
112   return CEED_ERROR_SUCCESS;
113 }
114 
115 //------------------------------------------------------------------------------
116 // Check if ctx has valid data
117 //------------------------------------------------------------------------------
118 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
119   CeedQFunctionContext_Sycl *impl;
120   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
121 
122   *has_valid_data = impl && (!!impl->h_data || !!impl->d_data);
123 
124   return CEED_ERROR_SUCCESS;
125 }
126 
127 //------------------------------------------------------------------------------
128 // Check if ctx has borrowed data
129 //------------------------------------------------------------------------------
130 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
131                                                                  bool *has_borrowed_data_of_type) {
132   CeedQFunctionContext_Sycl *impl;
133   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
134 
135   switch (mem_type) {
136     case CEED_MEM_HOST:
137       *has_borrowed_data_of_type = !!impl->h_data_borrowed;
138       break;
139     case CEED_MEM_DEVICE:
140       *has_borrowed_data_of_type = !!impl->d_data_borrowed;
141       break;
142   }
143 
144   return CEED_ERROR_SUCCESS;
145 }
146 
147 //------------------------------------------------------------------------------
148 // Check if data of given type needs sync
149 //------------------------------------------------------------------------------
150 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
151   CeedQFunctionContext_Sycl *impl;
152   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
153 
154   bool has_valid_data = true;
155   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
156   switch (mem_type) {
157     case CEED_MEM_HOST:
158       *need_sync = has_valid_data && !impl->h_data;
159       break;
160     case CEED_MEM_DEVICE:
161       *need_sync = has_valid_data && !impl->d_data;
162       break;
163   }
164 
165   return CEED_ERROR_SUCCESS;
166 }
167 
168 //------------------------------------------------------------------------------
169 // Set data from host
170 //------------------------------------------------------------------------------
171 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
172   CeedQFunctionContext_Sycl *impl;
173   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
174 
175   CeedCallBackend(CeedFree(&impl->h_data_owned));
176   switch (copy_mode) {
177     case CEED_COPY_VALUES:
178       size_t ctxsize;
179       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
180       CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned));
181       impl->h_data_borrowed = NULL;
182       impl->h_data          = impl->h_data_owned;
183       memcpy(impl->h_data, data, ctxsize);
184       break;
185     case CEED_OWN_POINTER:
186       impl->h_data_owned    = data;
187       impl->h_data_borrowed = NULL;
188       impl->h_data          = data;
189       break;
190     case CEED_USE_POINTER:
191       impl->h_data_borrowed = data;
192       impl->h_data          = data;
193       break;
194   }
195 
196   return CEED_ERROR_SUCCESS;
197 }
198 
199 //------------------------------------------------------------------------------
200 // Set data from device
201 //------------------------------------------------------------------------------
202 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
203   CeedQFunctionContext_Sycl *impl;
204   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
205   Ceed ceed;
206   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
207   Ceed_Sycl *sycl_data;
208   CeedCallBackend(CeedGetData(ceed, &sycl_data));
209 
210   // Order queue
211   sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
212 
213   // Wait for all work to finish before freeing memory
214   if (impl->d_data_owned) {
215     CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
216     CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
217     impl->d_data_owned = NULL;
218   }
219 
220   switch (copy_mode) {
221     case CEED_COPY_VALUES: {
222       size_t ctxsize;
223       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize));
224       CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctxsize, sycl_data->sycl_device, sycl_data->sycl_context));
225       impl->d_data_borrowed  = NULL;
226       impl->d_data           = impl->d_data_owned;
227       sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctxsize, {e});
228       CeedCallSycl(ceed, copy_event.wait_and_throw());
229     } break;
230     case CEED_OWN_POINTER: {
231       impl->d_data_owned    = data;
232       impl->d_data_borrowed = NULL;
233       impl->d_data          = data;
234     } break;
235     case CEED_USE_POINTER: {
236       impl->d_data_owned    = NULL;
237       impl->d_data_borrowed = data;
238       impl->d_data          = data;
239     } break;
240   }
241 
242   return CEED_ERROR_SUCCESS;
243 }
244 
245 //------------------------------------------------------------------------------
246 // Set the data used by a user context,
247 //   freeing any previously allocated data if applicable
248 //------------------------------------------------------------------------------
249 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
250   Ceed ceed;
251   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
252 
253   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
254   switch (mem_type) {
255     case CEED_MEM_HOST:
256       return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
257     case CEED_MEM_DEVICE:
258       return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
259   }
260 
261   return CEED_ERROR_UNSUPPORTED;
262 }
263 
264 //------------------------------------------------------------------------------
265 // Take data
266 //------------------------------------------------------------------------------
267 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
268   Ceed ceed;
269   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
270   CeedQFunctionContext_Sycl *impl;
271   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
272 
273   Ceed_Sycl *ceedSycl;
274   CeedCallBackend(CeedGetData(ceed, &ceedSycl));
275 
276   // Order queue
277   ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
278 
279   // Sync data to requested mem_type
280   bool need_sync = false;
281   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
282   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
283 
284   // Update pointer
285   switch (mem_type) {
286     case CEED_MEM_HOST:
287       *(void **)data        = impl->h_data_borrowed;
288       impl->h_data_borrowed = NULL;
289       impl->h_data          = NULL;
290       break;
291     case CEED_MEM_DEVICE:
292       *(void **)data        = impl->d_data_borrowed;
293       impl->d_data_borrowed = NULL;
294       impl->d_data          = NULL;
295       break;
296   }
297 
298   return CEED_ERROR_SUCCESS;
299 }
300 
301 //------------------------------------------------------------------------------
302 // Core logic for GetData.
303 //   If a different memory type is most up to date, this will perform a copy
304 //------------------------------------------------------------------------------
305 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
306   Ceed ceed;
307   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
308   CeedQFunctionContext_Sycl *impl;
309   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
310 
311   // Sync data to requested mem_type
312   bool need_sync = false;
313   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
314   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
315 
316   // Update pointer
317   switch (mem_type) {
318     case CEED_MEM_HOST:
319       *(void **)data = impl->h_data;
320       break;
321     case CEED_MEM_DEVICE:
322       *(void **)data = impl->d_data;
323       break;
324   }
325 
326   return CEED_ERROR_SUCCESS;
327 }
328 
329 //------------------------------------------------------------------------------
330 // Get read-only access to the data
331 //------------------------------------------------------------------------------
332 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
333   return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
334 }
335 
336 //------------------------------------------------------------------------------
337 // Get read/write access to the data
338 //------------------------------------------------------------------------------
339 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
340   CeedQFunctionContext_Sycl *impl;
341   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
342   Ceed ceed;
343   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
344 
345   CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
346 
347   // Mark only pointer for requested memory as valid
348   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
349   switch (mem_type) {
350     case CEED_MEM_HOST:
351       impl->h_data = *(void **)data;
352       break;
353     case CEED_MEM_DEVICE:
354       impl->d_data = *(void **)data;
355       break;
356   }
357 
358   return CEED_ERROR_SUCCESS;
359 }
360 
361 //------------------------------------------------------------------------------
362 // Destroy the user context
363 //------------------------------------------------------------------------------
364 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
365   Ceed ceed;
366   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
367   CeedQFunctionContext_Sycl *impl;
368   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
369   Ceed_Sycl *sycl_data;
370   CeedCallBackend(CeedGetData(ceed, &sycl_data));
371 
372   // Wait for all work to finish before freeing memory
373   CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
374   CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
375   CeedCallBackend(CeedFree(&impl->h_data_owned));
376   CeedCallBackend(CeedFree(&impl));
377 
378   return CEED_ERROR_SUCCESS;
379 }
380 
381 //------------------------------------------------------------------------------
382 // QFunctionContext Create
383 //------------------------------------------------------------------------------
384 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
385   CeedQFunctionContext_Sycl *impl;
386   Ceed                       ceed;
387   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
388 
389   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
390   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
391   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
392   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
393   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
394   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
395   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
396 
397   CeedCallBackend(CeedCalloc(1, &impl));
398   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
399 
400   return CEED_ERROR_SUCCESS;
401 }
402 
403 //------------------------------------------------------------------------------
404