xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunctioncontext.sycl.cpp (revision 49a40c8a2d720db341b0b117b89656b473cbebfb)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 
11 #include <string>
12 #include <sycl/sycl.hpp>
13 
14 #include "ceed-sycl-ref.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Sync host to device
18 //------------------------------------------------------------------------------
19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) {
20   Ceed                       ceed;
21   Ceed_Sycl                 *sycl_data;
22   size_t                     ctx_size;
23   CeedQFunctionContext_Sycl *impl;
24 
25   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
26   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
27   CeedCallBackend(CeedGetData(ceed, &sycl_data));
28 
29   if (!impl->h_data) {
30     // LCOV_EXCL_START
31     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
32     // LCOV_EXCL_STOP
33   }
34 
35   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
36 
37   if (impl->d_data_borrowed) {
38     impl->d_data = impl->d_data_borrowed;
39   } else if (impl->d_data_owned) {
40     impl->d_data = impl->d_data_owned;
41   } else {
42     CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
43     impl->d_data = impl->d_data_owned;
44   }
45   // Order queue
46   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
47   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, {e});
48   CeedCallSycl(ceed, copy_event.wait_and_throw());
49   return CEED_ERROR_SUCCESS;
50 }
51 
52 //------------------------------------------------------------------------------
53 // Sync device to host
54 //------------------------------------------------------------------------------
55 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) {
56   Ceed                       ceed;
57   Ceed_Sycl                 *sycl_data;
58   size_t                     ctx_size;
59   CeedQFunctionContext_Sycl *impl;
60 
61   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
62   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
63   CeedCallBackend(CeedGetData(ceed, &sycl_data));
64 
65   if (!impl->d_data) {
66     // LCOV_EXCL_START
67     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
68     // LCOV_EXCL_STOP
69   }
70 
71   CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
72 
73   if (impl->h_data_borrowed) {
74     impl->h_data = impl->h_data_borrowed;
75   } else if (impl->h_data_owned) {
76     impl->h_data = impl->h_data_owned;
77   } else {
78     CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
79     impl->h_data = impl->h_data_owned;
80   }
81 
82   // Order queue
83   sycl::event e          = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
84   sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, {e});
85   CeedCallSycl(ceed, copy_event.wait_and_throw());
86   return CEED_ERROR_SUCCESS;
87 }
88 
89 //------------------------------------------------------------------------------
90 // Sync data of type
91 //------------------------------------------------------------------------------
92 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) {
93   switch (mem_type) {
94     case CEED_MEM_HOST:
95       return CeedQFunctionContextSyncD2H_Sycl(ctx);
96     case CEED_MEM_DEVICE:
97       return CeedQFunctionContextSyncH2D_Sycl(ctx);
98   }
99   return CEED_ERROR_UNSUPPORTED;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Set all pointers as invalid
104 //------------------------------------------------------------------------------
105 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) {
106   CeedQFunctionContext_Sycl *impl;
107 
108   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
109   impl->h_data = NULL;
110   impl->d_data = NULL;
111   return CEED_ERROR_SUCCESS;
112 }
113 
114 //------------------------------------------------------------------------------
115 // Check if ctx has valid data
116 //------------------------------------------------------------------------------
117 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) {
118   CeedQFunctionContext_Sycl *impl;
119 
120   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
121   *has_valid_data = impl && (impl->h_data || impl->d_data);
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if ctx has borrowed data
127 //------------------------------------------------------------------------------
128 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type,
129                                                                  bool *has_borrowed_data_of_type) {
130   CeedQFunctionContext_Sycl *impl;
131 
132   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
133   switch (mem_type) {
134     case CEED_MEM_HOST:
135       *has_borrowed_data_of_type = impl->h_data_borrowed;
136       break;
137     case CEED_MEM_DEVICE:
138       *has_borrowed_data_of_type = impl->d_data_borrowed;
139       break;
140   }
141   return CEED_ERROR_SUCCESS;
142 }
143 
144 //------------------------------------------------------------------------------
145 // Check if data of given type needs sync
146 //------------------------------------------------------------------------------
147 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) {
148   bool                       has_valid_data = true;
149   CeedQFunctionContext_Sycl *impl;
150 
151   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
152   CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data));
153   switch (mem_type) {
154     case CEED_MEM_HOST:
155       *need_sync = has_valid_data && !impl->h_data;
156       break;
157     case CEED_MEM_DEVICE:
158       *need_sync = has_valid_data && !impl->d_data;
159       break;
160   }
161   return CEED_ERROR_SUCCESS;
162 }
163 
164 //------------------------------------------------------------------------------
165 // Set data from host
166 //------------------------------------------------------------------------------
167 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
168   CeedQFunctionContext_Sycl *impl;
169 
170   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
171   CeedCallBackend(CeedFree(&impl->h_data_owned));
172   switch (copy_mode) {
173     case CEED_COPY_VALUES:
174       size_t ctx_size;
175 
176       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
177       CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned));
178       impl->h_data_borrowed = NULL;
179       impl->h_data          = impl->h_data_owned;
180       memcpy(impl->h_data, data, ctx_size);
181       break;
182     case CEED_OWN_POINTER:
183       impl->h_data_owned    = data;
184       impl->h_data_borrowed = NULL;
185       impl->h_data          = data;
186       break;
187     case CEED_USE_POINTER:
188       impl->h_data_borrowed = data;
189       impl->h_data          = data;
190       break;
191   }
192   return CEED_ERROR_SUCCESS;
193 }
194 
195 //------------------------------------------------------------------------------
196 // Set data from device
197 //------------------------------------------------------------------------------
198 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) {
199   Ceed                       ceed;
200   Ceed_Sycl                 *sycl_data;
201   CeedQFunctionContext_Sycl *impl;
202 
203   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
204   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
205   CeedCallBackend(CeedGetData(ceed, &sycl_data));
206 
207   // Order queue
208   sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier();
209 
210   // Wait for all work to finish before freeing memory
211   if (impl->d_data_owned) {
212     CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
213     CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
214     impl->d_data_owned = NULL;
215   }
216 
217   switch (copy_mode) {
218     case CEED_COPY_VALUES: {
219       size_t ctx_size;
220 
221       CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size));
222       CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context));
223       impl->d_data_borrowed  = NULL;
224       impl->d_data           = impl->d_data_owned;
225       sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, {e});
226       CeedCallSycl(ceed, copy_event.wait_and_throw());
227     } break;
228     case CEED_OWN_POINTER: {
229       impl->d_data_owned    = data;
230       impl->d_data_borrowed = NULL;
231       impl->d_data          = data;
232     } break;
233     case CEED_USE_POINTER: {
234       impl->d_data_owned    = NULL;
235       impl->d_data_borrowed = data;
236       impl->d_data          = data;
237     } break;
238   }
239   return CEED_ERROR_SUCCESS;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Set the data used by a user context,
244 //   freeing any previously allocated data if applicable
245 //------------------------------------------------------------------------------
246 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) {
247   Ceed ceed;
248 
249   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
250   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
251   switch (mem_type) {
252     case CEED_MEM_HOST:
253       return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data);
254     case CEED_MEM_DEVICE:
255       return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data);
256   }
257   return CEED_ERROR_UNSUPPORTED;
258 }
259 
260 //------------------------------------------------------------------------------
261 // Take data
262 //------------------------------------------------------------------------------
263 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
264   Ceed                       ceed;
265   Ceed_Sycl                 *ceedSycl;
266   bool                       need_sync = false;
267   CeedQFunctionContext_Sycl *impl;
268 
269   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
270   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
271   CeedCallBackend(CeedGetData(ceed, &ceedSycl));
272 
273   // Order queue
274   ceedSycl->sycl_queue.ext_oneapi_submit_barrier();
275 
276   // Sync data to requested mem_type
277   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
278   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
279 
280   // Update pointer
281   switch (mem_type) {
282     case CEED_MEM_HOST:
283       *(void **)data        = impl->h_data_borrowed;
284       impl->h_data_borrowed = NULL;
285       impl->h_data          = NULL;
286       break;
287     case CEED_MEM_DEVICE:
288       *(void **)data        = impl->d_data_borrowed;
289       impl->d_data_borrowed = NULL;
290       impl->d_data          = NULL;
291       break;
292   }
293   return CEED_ERROR_SUCCESS;
294 }
295 
296 //------------------------------------------------------------------------------
297 // Core logic for GetData.
298 //   If a different memory type is most up to date, this will perform a copy
299 //------------------------------------------------------------------------------
300 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
301   Ceed                       ceed;
302   bool                       need_sync = false;
303   CeedQFunctionContext_Sycl *impl;
304 
305   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
306   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
307 
308   // Sync data to requested mem_type
309   CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync));
310   if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type));
311 
312   // Update pointer
313   switch (mem_type) {
314     case CEED_MEM_HOST:
315       *(void **)data = impl->h_data;
316       break;
317     case CEED_MEM_DEVICE:
318       *(void **)data = impl->d_data;
319       break;
320   }
321   return CEED_ERROR_SUCCESS;
322 }
323 
324 //------------------------------------------------------------------------------
325 // Get read-only access to the data
326 //------------------------------------------------------------------------------
327 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
328   return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data);
329 }
330 
331 //------------------------------------------------------------------------------
332 // Get read/write access to the data
333 //------------------------------------------------------------------------------
334 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) {
335   Ceed                       ceed;
336   CeedQFunctionContext_Sycl *impl;
337 
338   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
339   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
340   CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data));
341 
342   // Mark only pointer for requested memory as valid
343   CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx));
344   switch (mem_type) {
345     case CEED_MEM_HOST:
346       impl->h_data = *(void **)data;
347       break;
348     case CEED_MEM_DEVICE:
349       impl->d_data = *(void **)data;
350       break;
351   }
352   return CEED_ERROR_SUCCESS;
353 }
354 
355 //------------------------------------------------------------------------------
356 // Destroy the user context
357 //------------------------------------------------------------------------------
358 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) {
359   Ceed                       ceed;
360   Ceed_Sycl                 *sycl_data;
361   CeedQFunctionContext_Sycl *impl;
362 
363   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
364   CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl));
365   CeedCallBackend(CeedGetData(ceed, &sycl_data));
366 
367   // Wait for all work to finish before freeing memory
368   CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw());
369   CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context));
370   CeedCallBackend(CeedFree(&impl->h_data_owned));
371   CeedCallBackend(CeedFree(&impl));
372   return CEED_ERROR_SUCCESS;
373 }
374 
375 //------------------------------------------------------------------------------
376 // QFunctionContext Create
377 //------------------------------------------------------------------------------
378 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) {
379   Ceed                       ceed;
380   CeedQFunctionContext_Sycl *impl;
381 
382   CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));
383   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl));
384   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl));
385   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl));
386   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl));
387   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl));
388   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl));
389   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl));
390   CeedCallBackend(CeedCalloc(1, &impl));
391   CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl));
392   return CEED_ERROR_SUCCESS;
393 }
394 
395 //------------------------------------------------------------------------------
396