xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-vector.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright #include <ceed/backend.h>
9*bd882c8aSJames Wright #include <ceed/ceed.h>
10*bd882c8aSJames Wright 
11*bd882c8aSJames Wright #include <cmath>
12*bd882c8aSJames Wright #include <string>
13*bd882c8aSJames Wright #include <sycl/sycl.hpp>
14*bd882c8aSJames Wright 
15*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
16*bd882c8aSJames Wright 
17*bd882c8aSJames Wright //------------------------------------------------------------------------------
18*bd882c8aSJames Wright // Check if host/device sync is needed
19*bd882c8aSJames Wright //------------------------------------------------------------------------------
20*bd882c8aSJames Wright static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
21*bd882c8aSJames Wright   CeedVector_Sycl *impl;
22*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
23*bd882c8aSJames Wright 
24*bd882c8aSJames Wright   bool has_valid_array = false;
25*bd882c8aSJames Wright   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
26*bd882c8aSJames Wright   switch (mem_type) {
27*bd882c8aSJames Wright     case CEED_MEM_HOST:
28*bd882c8aSJames Wright       *need_sync = has_valid_array && !impl->h_array;
29*bd882c8aSJames Wright       break;
30*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
31*bd882c8aSJames Wright       *need_sync = has_valid_array && !impl->d_array;
32*bd882c8aSJames Wright       break;
33*bd882c8aSJames Wright   }
34*bd882c8aSJames Wright 
35*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
36*bd882c8aSJames Wright }
37*bd882c8aSJames Wright 
38*bd882c8aSJames Wright //------------------------------------------------------------------------------
39*bd882c8aSJames Wright // Sync host to device
40*bd882c8aSJames Wright //------------------------------------------------------------------------------
41*bd882c8aSJames Wright static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) {
42*bd882c8aSJames Wright   Ceed ceed;
43*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
44*bd882c8aSJames Wright   CeedVector_Sycl *impl;
45*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
46*bd882c8aSJames Wright   Ceed_Sycl *data;
47*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
48*bd882c8aSJames Wright 
49*bd882c8aSJames Wright   if (!impl->h_array) {
50*bd882c8aSJames Wright     // LCOV_EXCL_START
51*bd882c8aSJames Wright     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
52*bd882c8aSJames Wright     // LCOV_EXCL_STOP
53*bd882c8aSJames Wright   }
54*bd882c8aSJames Wright 
55*bd882c8aSJames Wright   CeedSize length;
56*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(vec, &length));
57*bd882c8aSJames Wright 
58*bd882c8aSJames Wright   if (impl->d_array_borrowed) {
59*bd882c8aSJames Wright     impl->d_array = impl->d_array_borrowed;
60*bd882c8aSJames Wright   } else if (impl->d_array_owned) {
61*bd882c8aSJames Wright     impl->d_array = impl->d_array_owned;
62*bd882c8aSJames Wright   } else {
63*bd882c8aSJames Wright     CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context));
64*bd882c8aSJames Wright     impl->d_array = impl->d_array_owned;
65*bd882c8aSJames Wright   }
66*bd882c8aSJames Wright 
67*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
68*bd882c8aSJames Wright   // Copy from host to device
69*bd882c8aSJames Wright   sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e});
70*bd882c8aSJames Wright   // Wait for copy to finish and handle exceptions.
71*bd882c8aSJames Wright   CeedCallSycl(ceed, copy_event.wait_and_throw());
72*bd882c8aSJames Wright 
73*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
74*bd882c8aSJames Wright }
75*bd882c8aSJames Wright 
76*bd882c8aSJames Wright //------------------------------------------------------------------------------
77*bd882c8aSJames Wright // Sync device to host
78*bd882c8aSJames Wright //------------------------------------------------------------------------------
79*bd882c8aSJames Wright static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) {
80*bd882c8aSJames Wright   Ceed ceed;
81*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
82*bd882c8aSJames Wright   CeedVector_Sycl *impl;
83*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
84*bd882c8aSJames Wright   Ceed_Sycl *data;
85*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
86*bd882c8aSJames Wright 
87*bd882c8aSJames Wright   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
88*bd882c8aSJames Wright 
89*bd882c8aSJames Wright   CeedSize length;
90*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(vec, &length));
91*bd882c8aSJames Wright 
92*bd882c8aSJames Wright   if (impl->h_array_borrowed) {
93*bd882c8aSJames Wright     impl->h_array = impl->h_array_borrowed;
94*bd882c8aSJames Wright   } else if (impl->h_array_owned) {
95*bd882c8aSJames Wright     impl->h_array = impl->h_array_owned;
96*bd882c8aSJames Wright   } else {
97*bd882c8aSJames Wright     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
98*bd882c8aSJames Wright     impl->h_array = impl->h_array_owned;
99*bd882c8aSJames Wright   }
100*bd882c8aSJames Wright 
101*bd882c8aSJames Wright   // Order queue
102*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
103*bd882c8aSJames Wright   // Copy from device to host
104*bd882c8aSJames Wright   sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e});
105*bd882c8aSJames Wright 
106*bd882c8aSJames Wright   // Wait for copy to finish and handle exceptions.
107*bd882c8aSJames Wright   CeedCallSycl(ceed, copy_event.wait_and_throw());
108*bd882c8aSJames Wright 
109*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
110*bd882c8aSJames Wright }
111*bd882c8aSJames Wright 
112*bd882c8aSJames Wright //------------------------------------------------------------------------------
113*bd882c8aSJames Wright // Sync arrays
114*bd882c8aSJames Wright //------------------------------------------------------------------------------
115*bd882c8aSJames Wright static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) {
116*bd882c8aSJames Wright   // Check whether device/host sync is needed
117*bd882c8aSJames Wright   bool need_sync = false;
118*bd882c8aSJames Wright   CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync));
119*bd882c8aSJames Wright   if (!need_sync) return CEED_ERROR_SUCCESS;
120*bd882c8aSJames Wright 
121*bd882c8aSJames Wright   switch (mem_type) {
122*bd882c8aSJames Wright     case CEED_MEM_HOST:
123*bd882c8aSJames Wright       return CeedVectorSyncD2H_Sycl(vec);
124*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
125*bd882c8aSJames Wright       return CeedVectorSyncH2D_Sycl(vec);
126*bd882c8aSJames Wright   }
127*bd882c8aSJames Wright   return CEED_ERROR_UNSUPPORTED;
128*bd882c8aSJames Wright }
129*bd882c8aSJames Wright 
130*bd882c8aSJames Wright //------------------------------------------------------------------------------
131*bd882c8aSJames Wright // Set all pointers as invalid
132*bd882c8aSJames Wright //------------------------------------------------------------------------------
133*bd882c8aSJames Wright static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) {
134*bd882c8aSJames Wright   CeedVector_Sycl *impl;
135*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
136*bd882c8aSJames Wright 
137*bd882c8aSJames Wright   impl->h_array = NULL;
138*bd882c8aSJames Wright   impl->d_array = NULL;
139*bd882c8aSJames Wright 
140*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
141*bd882c8aSJames Wright }
142*bd882c8aSJames Wright 
143*bd882c8aSJames Wright //------------------------------------------------------------------------------
144*bd882c8aSJames Wright // Check if CeedVector has any valid pointer
145*bd882c8aSJames Wright //------------------------------------------------------------------------------
146*bd882c8aSJames Wright static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) {
147*bd882c8aSJames Wright   CeedVector_Sycl *impl;
148*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
149*bd882c8aSJames Wright 
150*bd882c8aSJames Wright   *has_valid_array = !!impl->h_array || !!impl->d_array;
151*bd882c8aSJames Wright 
152*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
153*bd882c8aSJames Wright }
154*bd882c8aSJames Wright 
155*bd882c8aSJames Wright //------------------------------------------------------------------------------
156*bd882c8aSJames Wright // Check if has array of given type
157*bd882c8aSJames Wright //------------------------------------------------------------------------------
158*bd882c8aSJames Wright static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
159*bd882c8aSJames Wright   CeedVector_Sycl *impl;
160*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
161*bd882c8aSJames Wright 
162*bd882c8aSJames Wright   switch (mem_type) {
163*bd882c8aSJames Wright     case CEED_MEM_HOST:
164*bd882c8aSJames Wright       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
165*bd882c8aSJames Wright       break;
166*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
167*bd882c8aSJames Wright       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
168*bd882c8aSJames Wright       break;
169*bd882c8aSJames Wright   }
170*bd882c8aSJames Wright 
171*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
172*bd882c8aSJames Wright }
173*bd882c8aSJames Wright 
174*bd882c8aSJames Wright //------------------------------------------------------------------------------
175*bd882c8aSJames Wright // Check if has borrowed array of given type
176*bd882c8aSJames Wright //------------------------------------------------------------------------------
177*bd882c8aSJames Wright static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
178*bd882c8aSJames Wright   CeedVector_Sycl *impl;
179*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
180*bd882c8aSJames Wright 
181*bd882c8aSJames Wright   switch (mem_type) {
182*bd882c8aSJames Wright     case CEED_MEM_HOST:
183*bd882c8aSJames Wright       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
184*bd882c8aSJames Wright       break;
185*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
186*bd882c8aSJames Wright       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
187*bd882c8aSJames Wright       break;
188*bd882c8aSJames Wright   }
189*bd882c8aSJames Wright 
190*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
191*bd882c8aSJames Wright }
192*bd882c8aSJames Wright 
193*bd882c8aSJames Wright //------------------------------------------------------------------------------
194*bd882c8aSJames Wright // Set array from host
195*bd882c8aSJames Wright //------------------------------------------------------------------------------
196*bd882c8aSJames Wright static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
197*bd882c8aSJames Wright   CeedVector_Sycl *impl;
198*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
199*bd882c8aSJames Wright 
200*bd882c8aSJames Wright   switch (copy_mode) {
201*bd882c8aSJames Wright     case CEED_COPY_VALUES: {
202*bd882c8aSJames Wright       if (!impl->h_array_owned) {
203*bd882c8aSJames Wright         CeedSize length;
204*bd882c8aSJames Wright         CeedCallBackend(CeedVectorGetLength(vec, &length));
205*bd882c8aSJames Wright         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
206*bd882c8aSJames Wright       }
207*bd882c8aSJames Wright       impl->h_array_borrowed = NULL;
208*bd882c8aSJames Wright       impl->h_array          = impl->h_array_owned;
209*bd882c8aSJames Wright       if (array) {
210*bd882c8aSJames Wright         CeedSize length;
211*bd882c8aSJames Wright         CeedCallBackend(CeedVectorGetLength(vec, &length));
212*bd882c8aSJames Wright         size_t bytes = length * sizeof(CeedScalar);
213*bd882c8aSJames Wright         memcpy(impl->h_array, array, bytes);
214*bd882c8aSJames Wright       }
215*bd882c8aSJames Wright     } break;
216*bd882c8aSJames Wright     case CEED_OWN_POINTER:
217*bd882c8aSJames Wright       CeedCallBackend(CeedFree(&impl->h_array_owned));
218*bd882c8aSJames Wright       impl->h_array_owned    = array;
219*bd882c8aSJames Wright       impl->h_array_borrowed = NULL;
220*bd882c8aSJames Wright       impl->h_array          = array;
221*bd882c8aSJames Wright       break;
222*bd882c8aSJames Wright     case CEED_USE_POINTER:
223*bd882c8aSJames Wright       CeedCallBackend(CeedFree(&impl->h_array_owned));
224*bd882c8aSJames Wright       impl->h_array_borrowed = array;
225*bd882c8aSJames Wright       impl->h_array          = array;
226*bd882c8aSJames Wright       break;
227*bd882c8aSJames Wright   }
228*bd882c8aSJames Wright 
229*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
230*bd882c8aSJames Wright }
231*bd882c8aSJames Wright 
232*bd882c8aSJames Wright //------------------------------------------------------------------------------
233*bd882c8aSJames Wright // Set array from device
234*bd882c8aSJames Wright //------------------------------------------------------------------------------
235*bd882c8aSJames Wright static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
236*bd882c8aSJames Wright   Ceed ceed;
237*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
238*bd882c8aSJames Wright   CeedVector_Sycl *impl;
239*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
240*bd882c8aSJames Wright   Ceed_Sycl *data;
241*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
242*bd882c8aSJames Wright 
243*bd882c8aSJames Wright   // Order queue
244*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
245*bd882c8aSJames Wright 
246*bd882c8aSJames Wright   switch (copy_mode) {
247*bd882c8aSJames Wright     case CEED_COPY_VALUES: {
248*bd882c8aSJames Wright       CeedSize length;
249*bd882c8aSJames Wright       CeedCallBackend(CeedVectorGetLength(vec, &length));
250*bd882c8aSJames Wright       if (!impl->d_array_owned) {
251*bd882c8aSJames Wright         CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context));
252*bd882c8aSJames Wright         impl->d_array = impl->d_array_owned;
253*bd882c8aSJames Wright       }
254*bd882c8aSJames Wright       if (array) {
255*bd882c8aSJames Wright         sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e});
256*bd882c8aSJames Wright         // Wait for copy to finish and handle exceptions.
257*bd882c8aSJames Wright         CeedCallSycl(ceed, copy_event.wait_and_throw());
258*bd882c8aSJames Wright       }
259*bd882c8aSJames Wright     } break;
260*bd882c8aSJames Wright     case CEED_OWN_POINTER:
261*bd882c8aSJames Wright       if (impl->d_array_owned) {
262*bd882c8aSJames Wright         // Wait for all work to finish before freeing memory
263*bd882c8aSJames Wright         CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
264*bd882c8aSJames Wright         CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context));
265*bd882c8aSJames Wright       }
266*bd882c8aSJames Wright       impl->d_array_owned    = array;
267*bd882c8aSJames Wright       impl->d_array_borrowed = NULL;
268*bd882c8aSJames Wright       impl->d_array          = array;
269*bd882c8aSJames Wright       break;
270*bd882c8aSJames Wright     case CEED_USE_POINTER:
271*bd882c8aSJames Wright       if (impl->d_array_owned) {
272*bd882c8aSJames Wright         // Wait for all work to finish before freeing memory
273*bd882c8aSJames Wright         CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
274*bd882c8aSJames Wright         CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context));
275*bd882c8aSJames Wright       }
276*bd882c8aSJames Wright       impl->d_array_owned    = NULL;
277*bd882c8aSJames Wright       impl->d_array_borrowed = array;
278*bd882c8aSJames Wright       impl->d_array          = array;
279*bd882c8aSJames Wright       break;
280*bd882c8aSJames Wright   }
281*bd882c8aSJames Wright 
282*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
283*bd882c8aSJames Wright }
284*bd882c8aSJames Wright 
285*bd882c8aSJames Wright //------------------------------------------------------------------------------
286*bd882c8aSJames Wright // Set the array used by a vector,
287*bd882c8aSJames Wright //   freeing any previously allocated array if applicable
288*bd882c8aSJames Wright //------------------------------------------------------------------------------
289*bd882c8aSJames Wright static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
290*bd882c8aSJames Wright   Ceed ceed;
291*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
292*bd882c8aSJames Wright   CeedVector_Sycl *impl;
293*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
294*bd882c8aSJames Wright 
295*bd882c8aSJames Wright   CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec));
296*bd882c8aSJames Wright   switch (mem_type) {
297*bd882c8aSJames Wright     case CEED_MEM_HOST:
298*bd882c8aSJames Wright       return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array);
299*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
300*bd882c8aSJames Wright       return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array);
301*bd882c8aSJames Wright   }
302*bd882c8aSJames Wright 
303*bd882c8aSJames Wright   return CEED_ERROR_UNSUPPORTED;
304*bd882c8aSJames Wright }
305*bd882c8aSJames Wright 
306*bd882c8aSJames Wright //------------------------------------------------------------------------------
307*bd882c8aSJames Wright // Set host array to value
308*bd882c8aSJames Wright //------------------------------------------------------------------------------
309*bd882c8aSJames Wright static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedInt length, CeedScalar val) {
310*bd882c8aSJames Wright   for (int i = 0; i < length; i++) h_array[i] = val;
311*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
312*bd882c8aSJames Wright }
313*bd882c8aSJames Wright 
314*bd882c8aSJames Wright //------------------------------------------------------------------------------
315*bd882c8aSJames Wright // Set device array to value
316*bd882c8aSJames Wright //------------------------------------------------------------------------------
317*bd882c8aSJames Wright static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length, CeedScalar val) {
318*bd882c8aSJames Wright   // Order queue
319*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
320*bd882c8aSJames Wright   sycl_queue.fill(d_array, val, length, {e});
321*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
322*bd882c8aSJames Wright }
323*bd882c8aSJames Wright 
324*bd882c8aSJames Wright //------------------------------------------------------------------------------
325*bd882c8aSJames Wright // Set a vector to a value,
326*bd882c8aSJames Wright //------------------------------------------------------------------------------
327*bd882c8aSJames Wright static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) {
328*bd882c8aSJames Wright   Ceed ceed;
329*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
330*bd882c8aSJames Wright   CeedVector_Sycl *impl;
331*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
332*bd882c8aSJames Wright   CeedSize length;
333*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(vec, &length));
334*bd882c8aSJames Wright   Ceed_Sycl *data;
335*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
336*bd882c8aSJames Wright 
337*bd882c8aSJames Wright   // Set value for synced device/host array
338*bd882c8aSJames Wright   if (!impl->d_array && !impl->h_array) {
339*bd882c8aSJames Wright     if (impl->d_array_borrowed) {
340*bd882c8aSJames Wright       impl->d_array = impl->d_array_borrowed;
341*bd882c8aSJames Wright     } else if (impl->h_array_borrowed) {
342*bd882c8aSJames Wright       impl->h_array = impl->h_array_borrowed;
343*bd882c8aSJames Wright     } else if (impl->d_array_owned) {
344*bd882c8aSJames Wright       impl->d_array = impl->d_array_owned;
345*bd882c8aSJames Wright     } else if (impl->h_array_owned) {
346*bd882c8aSJames Wright       impl->h_array = impl->h_array_owned;
347*bd882c8aSJames Wright     } else {
348*bd882c8aSJames Wright       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
349*bd882c8aSJames Wright     }
350*bd882c8aSJames Wright   }
351*bd882c8aSJames Wright   if (impl->d_array) {
352*bd882c8aSJames Wright     CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val));
353*bd882c8aSJames Wright     impl->h_array = NULL;
354*bd882c8aSJames Wright   }
355*bd882c8aSJames Wright   if (impl->h_array) {
356*bd882c8aSJames Wright     CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val));
357*bd882c8aSJames Wright     impl->d_array = NULL;
358*bd882c8aSJames Wright   }
359*bd882c8aSJames Wright 
360*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
361*bd882c8aSJames Wright }
362*bd882c8aSJames Wright 
363*bd882c8aSJames Wright //------------------------------------------------------------------------------
364*bd882c8aSJames Wright // Vector Take Array
365*bd882c8aSJames Wright //------------------------------------------------------------------------------
366*bd882c8aSJames Wright static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
367*bd882c8aSJames Wright   Ceed ceed;
368*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
369*bd882c8aSJames Wright   CeedVector_Sycl *impl;
370*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
371*bd882c8aSJames Wright 
372*bd882c8aSJames Wright   Ceed_Sycl *data;
373*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
374*bd882c8aSJames Wright 
375*bd882c8aSJames Wright   // Order queue
376*bd882c8aSJames Wright   data->sycl_queue.ext_oneapi_submit_barrier();
377*bd882c8aSJames Wright 
378*bd882c8aSJames Wright   // Sync array to requested mem_type
379*bd882c8aSJames Wright   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
380*bd882c8aSJames Wright 
381*bd882c8aSJames Wright   // Update pointer
382*bd882c8aSJames Wright   switch (mem_type) {
383*bd882c8aSJames Wright     case CEED_MEM_HOST:
384*bd882c8aSJames Wright       (*array)               = impl->h_array_borrowed;
385*bd882c8aSJames Wright       impl->h_array_borrowed = NULL;
386*bd882c8aSJames Wright       impl->h_array          = NULL;
387*bd882c8aSJames Wright       break;
388*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
389*bd882c8aSJames Wright       (*array)               = impl->d_array_borrowed;
390*bd882c8aSJames Wright       impl->d_array_borrowed = NULL;
391*bd882c8aSJames Wright       impl->d_array          = NULL;
392*bd882c8aSJames Wright       break;
393*bd882c8aSJames Wright   }
394*bd882c8aSJames Wright 
395*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
396*bd882c8aSJames Wright }
397*bd882c8aSJames Wright 
398*bd882c8aSJames Wright //------------------------------------------------------------------------------
399*bd882c8aSJames Wright // Core logic for array syncronization for GetArray.
400*bd882c8aSJames Wright //   If a different memory type is most up to date, this will perform a copy
401*bd882c8aSJames Wright //------------------------------------------------------------------------------
402*bd882c8aSJames Wright static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
403*bd882c8aSJames Wright   Ceed ceed;
404*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
405*bd882c8aSJames Wright   CeedVector_Sycl *impl;
406*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
407*bd882c8aSJames Wright 
408*bd882c8aSJames Wright   // Sync array to requested mem_type
409*bd882c8aSJames Wright   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
410*bd882c8aSJames Wright 
411*bd882c8aSJames Wright   // Update pointer
412*bd882c8aSJames Wright   switch (mem_type) {
413*bd882c8aSJames Wright     case CEED_MEM_HOST:
414*bd882c8aSJames Wright       *array = impl->h_array;
415*bd882c8aSJames Wright       break;
416*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
417*bd882c8aSJames Wright       *array = impl->d_array;
418*bd882c8aSJames Wright       break;
419*bd882c8aSJames Wright   }
420*bd882c8aSJames Wright 
421*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
422*bd882c8aSJames Wright }
423*bd882c8aSJames Wright //------------------------------------------------------------------------------
424*bd882c8aSJames Wright // Get read-only access to a vector via the specified mem_type
425*bd882c8aSJames Wright //------------------------------------------------------------------------------
426*bd882c8aSJames Wright static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
427*bd882c8aSJames Wright   return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array);
428*bd882c8aSJames Wright }
429*bd882c8aSJames Wright 
430*bd882c8aSJames Wright //------------------------------------------------------------------------------
431*bd882c8aSJames Wright // Get read/write access to a vector via the specified mem_type
432*bd882c8aSJames Wright //------------------------------------------------------------------------------
433*bd882c8aSJames Wright static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
434*bd882c8aSJames Wright   CeedVector_Sycl *impl;
435*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
436*bd882c8aSJames Wright 
437*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array));
438*bd882c8aSJames Wright 
439*bd882c8aSJames Wright   CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec));
440*bd882c8aSJames Wright   switch (mem_type) {
441*bd882c8aSJames Wright     case CEED_MEM_HOST:
442*bd882c8aSJames Wright       impl->h_array = *array;
443*bd882c8aSJames Wright       break;
444*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
445*bd882c8aSJames Wright       impl->d_array = *array;
446*bd882c8aSJames Wright       break;
447*bd882c8aSJames Wright   }
448*bd882c8aSJames Wright 
449*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
450*bd882c8aSJames Wright }
451*bd882c8aSJames Wright 
452*bd882c8aSJames Wright //------------------------------------------------------------------------------
453*bd882c8aSJames Wright // Get write access to a vector via the specified mem_type
454*bd882c8aSJames Wright //------------------------------------------------------------------------------
455*bd882c8aSJames Wright static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
456*bd882c8aSJames Wright   CeedVector_Sycl *impl;
457*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
458*bd882c8aSJames Wright 
459*bd882c8aSJames Wright   bool has_array_of_type = true;
460*bd882c8aSJames Wright   CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type));
461*bd882c8aSJames Wright   if (!has_array_of_type) {
462*bd882c8aSJames Wright     // Allocate if array is not yet allocated
463*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
464*bd882c8aSJames Wright   } else {
465*bd882c8aSJames Wright     // Select dirty array
466*bd882c8aSJames Wright     switch (mem_type) {
467*bd882c8aSJames Wright       case CEED_MEM_HOST:
468*bd882c8aSJames Wright         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
469*bd882c8aSJames Wright         else impl->h_array = impl->h_array_owned;
470*bd882c8aSJames Wright         break;
471*bd882c8aSJames Wright       case CEED_MEM_DEVICE:
472*bd882c8aSJames Wright         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
473*bd882c8aSJames Wright         else impl->d_array = impl->d_array_owned;
474*bd882c8aSJames Wright     }
475*bd882c8aSJames Wright   }
476*bd882c8aSJames Wright 
477*bd882c8aSJames Wright   return CeedVectorGetArray_Sycl(vec, mem_type, array);
478*bd882c8aSJames Wright }
479*bd882c8aSJames Wright 
480*bd882c8aSJames Wright //------------------------------------------------------------------------------
481*bd882c8aSJames Wright // Get the norm of a CeedVector
482*bd882c8aSJames Wright //------------------------------------------------------------------------------
483*bd882c8aSJames Wright static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) {
484*bd882c8aSJames Wright   Ceed ceed;
485*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
486*bd882c8aSJames Wright   CeedVector_Sycl *impl;
487*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
488*bd882c8aSJames Wright   CeedSize length;
489*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(vec, &length));
490*bd882c8aSJames Wright   Ceed_Sycl *data;
491*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
492*bd882c8aSJames Wright 
493*bd882c8aSJames Wright   // Compute norm
494*bd882c8aSJames Wright   const CeedScalar *d_array;
495*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
496*bd882c8aSJames Wright 
497*bd882c8aSJames Wright   switch (type) {
498*bd882c8aSJames Wright     case CEED_NORM_1: {
499*bd882c8aSJames Wright       // Order queue
500*bd882c8aSJames Wright       sycl::event e            = data->sycl_queue.ext_oneapi_submit_barrier();
501*bd882c8aSJames Wright       auto        sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}});
502*bd882c8aSJames Wright       data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw();
503*bd882c8aSJames Wright     } break;
504*bd882c8aSJames Wright     case CEED_NORM_2: {
505*bd882c8aSJames Wright       // Order queue
506*bd882c8aSJames Wright       sycl::event e            = data->sycl_queue.ext_oneapi_submit_barrier();
507*bd882c8aSJames Wright       auto        sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}});
508*bd882c8aSJames Wright       data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw();
509*bd882c8aSJames Wright     } break;
510*bd882c8aSJames Wright     case CEED_NORM_MAX: {
511*bd882c8aSJames Wright       // Order queue
512*bd882c8aSJames Wright       sycl::event e            = data->sycl_queue.ext_oneapi_submit_barrier();
513*bd882c8aSJames Wright       auto        maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}});
514*bd882c8aSJames Wright       data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw();
515*bd882c8aSJames Wright     } break;
516*bd882c8aSJames Wright   }
517*bd882c8aSJames Wright   // L2 norm - square root over reduced value
518*bd882c8aSJames Wright   if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm);
519*bd882c8aSJames Wright   else *norm = *impl->reduction_norm;
520*bd882c8aSJames Wright 
521*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
522*bd882c8aSJames Wright 
523*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
524*bd882c8aSJames Wright }
525*bd882c8aSJames Wright 
526*bd882c8aSJames Wright //------------------------------------------------------------------------------
527*bd882c8aSJames Wright // Take reciprocal of a vector on host
528*bd882c8aSJames Wright //------------------------------------------------------------------------------
529*bd882c8aSJames Wright static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedInt length) {
530*bd882c8aSJames Wright   for (int i = 0; i < length; i++) {
531*bd882c8aSJames Wright     if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
532*bd882c8aSJames Wright   }
533*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
534*bd882c8aSJames Wright }
535*bd882c8aSJames Wright 
536*bd882c8aSJames Wright //------------------------------------------------------------------------------
537*bd882c8aSJames Wright // Take reciprocal of a vector on device
538*bd882c8aSJames Wright //------------------------------------------------------------------------------
539*bd882c8aSJames Wright static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length) {
540*bd882c8aSJames Wright   // Order queue
541*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
542*bd882c8aSJames Wright   sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) {
543*bd882c8aSJames Wright     if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i];
544*bd882c8aSJames Wright   });
545*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
546*bd882c8aSJames Wright }
547*bd882c8aSJames Wright 
548*bd882c8aSJames Wright //------------------------------------------------------------------------------
549*bd882c8aSJames Wright // Take reciprocal of a vector
550*bd882c8aSJames Wright //------------------------------------------------------------------------------
551*bd882c8aSJames Wright static int CeedVectorReciprocal_Sycl(CeedVector vec) {
552*bd882c8aSJames Wright   Ceed ceed;
553*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
554*bd882c8aSJames Wright   CeedVector_Sycl *impl;
555*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
556*bd882c8aSJames Wright   CeedSize length;
557*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(vec, &length));
558*bd882c8aSJames Wright   Ceed_Sycl *data;
559*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
560*bd882c8aSJames Wright 
561*bd882c8aSJames Wright   // Set value for synced device/host array
562*bd882c8aSJames Wright   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length));
563*bd882c8aSJames Wright   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length));
564*bd882c8aSJames Wright 
565*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
566*bd882c8aSJames Wright }
567*bd882c8aSJames Wright 
568*bd882c8aSJames Wright //------------------------------------------------------------------------------
569*bd882c8aSJames Wright // Compute x = alpha x on the host
570*bd882c8aSJames Wright //------------------------------------------------------------------------------
571*bd882c8aSJames Wright static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
572*bd882c8aSJames Wright   for (int i = 0; i < length; i++) x_array[i] *= alpha;
573*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
574*bd882c8aSJames Wright }
575*bd882c8aSJames Wright 
576*bd882c8aSJames Wright //------------------------------------------------------------------------------
577*bd882c8aSJames Wright // Compute x = alpha x on device
578*bd882c8aSJames Wright //------------------------------------------------------------------------------
579*bd882c8aSJames Wright static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
580*bd882c8aSJames Wright   // Order queue
581*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
582*bd882c8aSJames Wright   sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; });
583*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
584*bd882c8aSJames Wright }
585*bd882c8aSJames Wright 
586*bd882c8aSJames Wright //------------------------------------------------------------------------------
587*bd882c8aSJames Wright // Compute x = alpha x
588*bd882c8aSJames Wright //------------------------------------------------------------------------------
589*bd882c8aSJames Wright static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) {
590*bd882c8aSJames Wright   Ceed ceed;
591*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
592*bd882c8aSJames Wright   CeedVector_Sycl *x_impl;
593*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(x, &x_impl));
594*bd882c8aSJames Wright   CeedSize length;
595*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(x, &length));
596*bd882c8aSJames Wright   Ceed_Sycl *data;
597*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
598*bd882c8aSJames Wright 
599*bd882c8aSJames Wright   // Set value for synced device/host array
600*bd882c8aSJames Wright   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length));
601*bd882c8aSJames Wright   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length));
602*bd882c8aSJames Wright 
603*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
604*bd882c8aSJames Wright }
605*bd882c8aSJames Wright 
606*bd882c8aSJames Wright //------------------------------------------------------------------------------
607*bd882c8aSJames Wright // Compute y = alpha x + y on the host
608*bd882c8aSJames Wright //------------------------------------------------------------------------------
609*bd882c8aSJames Wright static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
610*bd882c8aSJames Wright   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
611*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
612*bd882c8aSJames Wright }
613*bd882c8aSJames Wright 
614*bd882c8aSJames Wright //------------------------------------------------------------------------------
615*bd882c8aSJames Wright // Compute y = alpha x + y on device
616*bd882c8aSJames Wright //------------------------------------------------------------------------------
617*bd882c8aSJames Wright static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
618*bd882c8aSJames Wright   // Order queue
619*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
620*bd882c8aSJames Wright   sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; });
621*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
622*bd882c8aSJames Wright }
623*bd882c8aSJames Wright 
624*bd882c8aSJames Wright //------------------------------------------------------------------------------
625*bd882c8aSJames Wright // Compute y = alpha x + y
626*bd882c8aSJames Wright //------------------------------------------------------------------------------
627*bd882c8aSJames Wright static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) {
628*bd882c8aSJames Wright   Ceed ceed;
629*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
630*bd882c8aSJames Wright   CeedVector_Sycl *y_impl, *x_impl;
631*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(y, &y_impl));
632*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(x, &x_impl));
633*bd882c8aSJames Wright   CeedSize length;
634*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(y, &length));
635*bd882c8aSJames Wright   Ceed_Sycl *data;
636*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
637*bd882c8aSJames Wright 
638*bd882c8aSJames Wright   // Set value for synced device/host array
639*bd882c8aSJames Wright   if (y_impl->d_array) {
640*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
641*bd882c8aSJames Wright     CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length));
642*bd882c8aSJames Wright   }
643*bd882c8aSJames Wright   if (y_impl->h_array) {
644*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
645*bd882c8aSJames Wright     CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length));
646*bd882c8aSJames Wright   }
647*bd882c8aSJames Wright 
648*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
649*bd882c8aSJames Wright }
650*bd882c8aSJames Wright 
651*bd882c8aSJames Wright //------------------------------------------------------------------------------
652*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on the host
653*bd882c8aSJames Wright //------------------------------------------------------------------------------
654*bd882c8aSJames Wright static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
655*bd882c8aSJames Wright   for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
656*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
657*bd882c8aSJames Wright }
658*bd882c8aSJames Wright 
659*bd882c8aSJames Wright //------------------------------------------------------------------------------
660*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
661*bd882c8aSJames Wright //------------------------------------------------------------------------------
662*bd882c8aSJames Wright static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
663*bd882c8aSJames Wright   // Order queue
664*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
665*bd882c8aSJames Wright   sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; });
666*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
667*bd882c8aSJames Wright }
668*bd882c8aSJames Wright 
669*bd882c8aSJames Wright //------------------------------------------------------------------------------
670*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y
671*bd882c8aSJames Wright //------------------------------------------------------------------------------
672*bd882c8aSJames Wright static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) {
673*bd882c8aSJames Wright   Ceed ceed;
674*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
675*bd882c8aSJames Wright   CeedVector_Sycl *w_impl, *x_impl, *y_impl;
676*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(w, &w_impl));
677*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(x, &x_impl));
678*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(y, &y_impl));
679*bd882c8aSJames Wright   CeedSize length;
680*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetLength(w, &length));
681*bd882c8aSJames Wright   Ceed_Sycl *data;
682*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
683*bd882c8aSJames Wright 
684*bd882c8aSJames Wright   // Set value for synced device/host array
685*bd882c8aSJames Wright   if (!w_impl->d_array && !w_impl->h_array) {
686*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSetValue(w, 0.0));
687*bd882c8aSJames Wright   }
688*bd882c8aSJames Wright   if (w_impl->d_array) {
689*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
690*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
691*bd882c8aSJames Wright     CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
692*bd882c8aSJames Wright   }
693*bd882c8aSJames Wright   if (w_impl->h_array) {
694*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
695*bd882c8aSJames Wright     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
696*bd882c8aSJames Wright     CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
697*bd882c8aSJames Wright   }
698*bd882c8aSJames Wright 
699*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
700*bd882c8aSJames Wright }
701*bd882c8aSJames Wright 
702*bd882c8aSJames Wright //------------------------------------------------------------------------------
703*bd882c8aSJames Wright // Destroy the vector
704*bd882c8aSJames Wright //------------------------------------------------------------------------------
705*bd882c8aSJames Wright static int CeedVectorDestroy_Sycl(const CeedVector vec) {
706*bd882c8aSJames Wright   Ceed ceed;
707*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
708*bd882c8aSJames Wright   CeedVector_Sycl *impl;
709*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetData(vec, &impl));
710*bd882c8aSJames Wright   Ceed_Sycl *data;
711*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
712*bd882c8aSJames Wright 
713*bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
714*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
715*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context));
716*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context));
717*bd882c8aSJames Wright 
718*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl->h_array_owned));
719*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
720*bd882c8aSJames Wright 
721*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
722*bd882c8aSJames Wright }
723*bd882c8aSJames Wright 
724*bd882c8aSJames Wright //------------------------------------------------------------------------------
725*bd882c8aSJames Wright // Create a vector of the specified length (does not allocate memory)
726*bd882c8aSJames Wright //------------------------------------------------------------------------------
727*bd882c8aSJames Wright int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) {
728*bd882c8aSJames Wright   CeedVector_Sycl *impl;
729*bd882c8aSJames Wright   Ceed             ceed;
730*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
731*bd882c8aSJames Wright   Ceed_Sycl *data;
732*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
733*bd882c8aSJames Wright 
734*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
735*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context));
736*bd882c8aSJames Wright 
737*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl));
738*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl));
739*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl));
740*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl));
741*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl));
742*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl));
743*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl));
744*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl));
745*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl));
746*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl));
747*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl));
748*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl));
749*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl));
750*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl));
751*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl));
752*bd882c8aSJames Wright 
753*bd882c8aSJames Wright   CeedCallBackend(CeedVectorSetData(vec, impl));
754*bd882c8aSJames Wright 
755*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
756*bd882c8aSJames Wright }
757