xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-restriction.sycl.cpp (revision 22eb13854768cd7db9fa223351f183dc3d3dc7a1)
1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3bd882c8aSJames Wright // files for details.
4bd882c8aSJames Wright //
5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
6bd882c8aSJames Wright //
7bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
8bd882c8aSJames Wright 
9bd882c8aSJames Wright #include <ceed/backend.h>
10bd882c8aSJames Wright #include <ceed/ceed.h>
11bd882c8aSJames Wright #include <ceed/jit-tools.h>
12bd882c8aSJames Wright 
13bd882c8aSJames Wright #include <string>
14bd882c8aSJames Wright #include <sycl/sycl.hpp>
15bd882c8aSJames Wright 
16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
18bd882c8aSJames Wright 
19bd882c8aSJames Wright class CeedElemRestrSyclStridedNT;
20bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT;
21bd882c8aSJames Wright class CeedElemRestrSyclStridedT;
22bd882c8aSJames Wright class CeedElemRestrSyclOffsetT;
23bd882c8aSJames Wright 
24bd882c8aSJames Wright //------------------------------------------------------------------------------
25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided
26bd882c8aSJames Wright //------------------------------------------------------------------------------
27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
28bd882c8aSJames Wright                                                       CeedScalar *v) {
29bd882c8aSJames Wright   const CeedInt  elem_size    = impl->elem_size;
30bd882c8aSJames Wright   const CeedInt  num_elem     = impl->num_elem;
31bd882c8aSJames Wright   const CeedInt  num_comp     = impl->num_comp;
32bd882c8aSJames Wright   const CeedInt  stride_nodes = impl->strides[0];
33bd882c8aSJames Wright   const CeedInt  stride_comp  = impl->strides[1];
34bd882c8aSJames Wright   const CeedInt  stride_elem  = impl->strides[2];
35bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
36bd882c8aSJames Wright 
37bd882c8aSJames Wright   // Order queue
38bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
39bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) {
40bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
41bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
42bd882c8aSJames Wright 
43bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
44bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem];
45bd882c8aSJames Wright     }
46bd882c8aSJames Wright   });
47bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
48bd882c8aSJames Wright }
49bd882c8aSJames Wright 
50bd882c8aSJames Wright //------------------------------------------------------------------------------
51bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided
52bd882c8aSJames Wright //------------------------------------------------------------------------------
53bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
54bd882c8aSJames Wright                                                      CeedScalar *v) {
55bd882c8aSJames Wright   const CeedInt  elem_size   = impl->elem_size;
56bd882c8aSJames Wright   const CeedInt  num_elem    = impl->num_elem;
57bd882c8aSJames Wright   const CeedInt  num_comp    = impl->num_comp;
58bd882c8aSJames Wright   const CeedInt  comp_stride = impl->comp_stride;
59bd882c8aSJames Wright   const CeedInt *indices     = impl->d_ind;
60bd882c8aSJames Wright 
61bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
62bd882c8aSJames Wright 
63bd882c8aSJames Wright   // Order queue
64bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
65bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) {
66bd882c8aSJames Wright     const CeedInt ind      = indices[node];
67bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
68bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
69bd882c8aSJames Wright 
70bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
71bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride];
72bd882c8aSJames Wright     }
73bd882c8aSJames Wright   });
74bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
75bd882c8aSJames Wright }
76bd882c8aSJames Wright 
77bd882c8aSJames Wright //------------------------------------------------------------------------------
78bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided
79bd882c8aSJames Wright //------------------------------------------------------------------------------
80bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
81bd882c8aSJames Wright                                                     CeedScalar *v) {
82bd882c8aSJames Wright   const CeedInt elem_size    = impl->elem_size;
83bd882c8aSJames Wright   const CeedInt num_elem     = impl->num_elem;
84bd882c8aSJames Wright   const CeedInt num_comp     = impl->num_comp;
85bd882c8aSJames Wright   const CeedInt stride_nodes = impl->strides[0];
86bd882c8aSJames Wright   const CeedInt stride_comp  = impl->strides[1];
87bd882c8aSJames Wright   const CeedInt stride_elem  = impl->strides[2];
88bd882c8aSJames Wright 
89bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
90bd882c8aSJames Wright 
91bd882c8aSJames Wright   // Order queue
92bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
93bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) {
94bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
95bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
96bd882c8aSJames Wright 
97bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
98bd882c8aSJames Wright       v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
99bd882c8aSJames Wright     }
100bd882c8aSJames Wright   });
101bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
102bd882c8aSJames Wright }
103bd882c8aSJames Wright 
104bd882c8aSJames Wright //------------------------------------------------------------------------------
105bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided
106bd882c8aSJames Wright //------------------------------------------------------------------------------
107bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
108bd882c8aSJames Wright                                                    CeedScalar *v) {
109bd882c8aSJames Wright   const CeedInt  num_nodes     = impl->num_nodes;
110bd882c8aSJames Wright   const CeedInt  elem_size     = impl->elem_size;
111bd882c8aSJames Wright   const CeedInt  num_elem      = impl->num_elem;
112bd882c8aSJames Wright   const CeedInt  num_comp      = impl->num_comp;
113bd882c8aSJames Wright   const CeedInt  comp_stride   = impl->comp_stride;
114bd882c8aSJames Wright   const CeedInt *l_vec_indices = impl->d_l_vec_indices;
115bd882c8aSJames Wright   const CeedInt *t_offsets     = impl->d_t_offsets;
116bd882c8aSJames Wright   const CeedInt *t_indices     = impl->d_t_indices;
117bd882c8aSJames Wright 
118bd882c8aSJames Wright   sycl::range<1> kernel_range(num_nodes * num_comp);
119bd882c8aSJames Wright 
120bd882c8aSJames Wright   // Order queue
121bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
122bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) {
123bd882c8aSJames Wright     const CeedInt node    = id % num_nodes;
124bd882c8aSJames Wright     const CeedInt comp    = id / num_nodes;
125bd882c8aSJames Wright     const CeedInt ind     = l_vec_indices[node];
126bd882c8aSJames Wright     const CeedInt range_1 = t_offsets[node];
127bd882c8aSJames Wright     const CeedInt range_N = t_offsets[node + 1];
128bd882c8aSJames Wright     CeedScalar    value   = 0.0;
129bd882c8aSJames Wright 
130bd882c8aSJames Wright     for (CeedInt j = range_1; j < range_N; j++) {
131bd882c8aSJames Wright       const CeedInt t_ind    = t_indices[j];
132bd882c8aSJames Wright       CeedInt       loc_node = t_ind % elem_size;
133bd882c8aSJames Wright       CeedInt       elem     = t_ind / elem_size;
134bd882c8aSJames Wright 
135bd882c8aSJames Wright       value += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
136bd882c8aSJames Wright     }
137bd882c8aSJames Wright     v[ind + comp * comp_stride] += value;
138bd882c8aSJames Wright   });
139bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
140bd882c8aSJames Wright }
141bd882c8aSJames Wright 
142bd882c8aSJames Wright //------------------------------------------------------------------------------
143bd882c8aSJames Wright // Apply restriction
144bd882c8aSJames Wright //------------------------------------------------------------------------------
145dce49693SSebastian Grimberg static int CeedElemRestrictionApply_Sycl(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
146bd882c8aSJames Wright   Ceed                      ceed;
147bd882c8aSJames Wright   Ceed_Sycl                *data;
148dd64fc84SJeremy L Thompson   const CeedScalar         *d_u;
149dd64fc84SJeremy L Thompson   CeedScalar               *d_v;
150dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
151dd64fc84SJeremy L Thompson 
152dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
153dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
154bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
155bd882c8aSJames Wright 
156bd882c8aSJames Wright   // Get vectors
157bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
158bd882c8aSJames Wright   if (t_mode == CEED_TRANSPOSE) {
159bd882c8aSJames Wright     // Sum into for transpose mode, e-vec to l-vec
160bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
161bd882c8aSJames Wright   } else {
162bd882c8aSJames Wright     // Overwrite for notranspose mode, l-vec to e-vec
163bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
164bd882c8aSJames Wright   }
165bd882c8aSJames Wright 
166bd882c8aSJames Wright   // Restrict
167bd882c8aSJames Wright   if (t_mode == CEED_NOTRANSPOSE) {
168bd882c8aSJames Wright     // L-vector -> E-vector
169bd882c8aSJames Wright     if (impl->d_ind) {
170bd882c8aSJames Wright       // -- Offsets provided
171bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
172bd882c8aSJames Wright     } else {
173bd882c8aSJames Wright       // -- Strided restriction
174bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
175bd882c8aSJames Wright     }
176bd882c8aSJames Wright   } else {
177bd882c8aSJames Wright     // E-vector -> L-vector
178bd882c8aSJames Wright     if (impl->d_ind) {
179bd882c8aSJames Wright       // -- Offsets provided
180bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
181bd882c8aSJames Wright     } else {
182bd882c8aSJames Wright       // -- Strided restriction
183bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
184bd882c8aSJames Wright     }
185bd882c8aSJames Wright   }
186bd882c8aSJames Wright   // Wait for queues to be completed. NOTE: This may not be necessary
187bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
188bd882c8aSJames Wright 
189bd882c8aSJames Wright   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
190bd882c8aSJames Wright 
191bd882c8aSJames Wright   // Restore arrays
192bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
193bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
194bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
195bd882c8aSJames Wright }
196bd882c8aSJames Wright 
197bd882c8aSJames Wright //------------------------------------------------------------------------------
198bd882c8aSJames Wright // Get offsets
199bd882c8aSJames Wright //------------------------------------------------------------------------------
200dce49693SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction rstr, CeedMemType m_type, const CeedInt **offsets) {
201bd882c8aSJames Wright   Ceed                      ceed;
202bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
203dd64fc84SJeremy L Thompson 
204dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
205dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
206bd882c8aSJames Wright 
207bd882c8aSJames Wright   switch (m_type) {
208bd882c8aSJames Wright     case CEED_MEM_HOST:
209bd882c8aSJames Wright       *offsets = impl->h_ind;
210bd882c8aSJames Wright       break;
211bd882c8aSJames Wright     case CEED_MEM_DEVICE:
212bd882c8aSJames Wright       *offsets = impl->d_ind;
213bd882c8aSJames Wright       break;
214bd882c8aSJames Wright   }
215bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
216bd882c8aSJames Wright }
217bd882c8aSJames Wright 
218bd882c8aSJames Wright //------------------------------------------------------------------------------
219bd882c8aSJames Wright // Destroy restriction
220bd882c8aSJames Wright //------------------------------------------------------------------------------
221dce49693SSebastian Grimberg static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction rstr) {
222bd882c8aSJames Wright   Ceed                      ceed;
223bd882c8aSJames Wright   Ceed_Sycl                *data;
224dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
225dd64fc84SJeremy L Thompson 
226dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
227dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
228bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
229bd882c8aSJames Wright 
230bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
231bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
232bd882c8aSJames Wright 
233bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
234bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context));
235bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context));
236bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context));
237bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context));
238bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
239bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
240bd882c8aSJames Wright }
241bd882c8aSJames Wright 
242bd882c8aSJames Wright //------------------------------------------------------------------------------
243bd882c8aSJames Wright // Create transpose offsets and indices
244bd882c8aSJames Wright //------------------------------------------------------------------------------
245dce49693SSebastian Grimberg static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction rstr, const CeedInt *indices) {
246bd882c8aSJames Wright   Ceed                      ceed;
247dd64fc84SJeremy L Thompson   Ceed_Sycl                *data;
248dd64fc84SJeremy L Thompson   bool                     *is_node;
249bd882c8aSJames Wright   CeedSize                  l_size;
250dd64fc84SJeremy L Thompson   CeedInt                   num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
251dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
252dd64fc84SJeremy L Thompson 
253dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
254dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
255dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
256dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
257dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
258dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
259bd882c8aSJames Wright 
260bd882c8aSJames Wright   // Count num_nodes
261bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &is_node));
262bd882c8aSJames Wright   const CeedInt size_indices = num_elem * elem_size;
263dd64fc84SJeremy L Thompson 
264bd882c8aSJames Wright   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
265bd882c8aSJames Wright   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
266bd882c8aSJames Wright   impl->num_nodes = num_nodes;
267bd882c8aSJames Wright 
268bd882c8aSJames Wright   // L-vector offsets array
269bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
270bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
271dd64fc84SJeremy L Thompson   for (CeedInt i = 0, j = 0; i < l_size; i++) {
272bd882c8aSJames Wright     if (is_node[i]) {
273bd882c8aSJames Wright       l_vec_indices[j] = i;
274bd882c8aSJames Wright       ind_to_offset[i] = j++;
275bd882c8aSJames Wright     }
276bd882c8aSJames Wright   }
277bd882c8aSJames Wright   CeedCallBackend(CeedFree(&is_node));
278bd882c8aSJames Wright 
279bd882c8aSJames Wright   // Compute transpose offsets and indices
280bd882c8aSJames Wright   const CeedInt size_offsets = num_nodes + 1;
281dd64fc84SJeremy L Thompson 
282bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
283bd882c8aSJames Wright   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
284bd882c8aSJames Wright   // Count node multiplicity
285bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
286bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
287bd882c8aSJames Wright   }
288bd882c8aSJames Wright   // Convert to running sum
289bd882c8aSJames Wright   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
290bd882c8aSJames Wright   // List all E-vec indices associated with L-vec node
291bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
292bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) {
293bd882c8aSJames Wright       const CeedInt lid                          = elem_size * e + i;
294bd882c8aSJames Wright       const CeedInt gid                          = indices[lid];
295bd882c8aSJames Wright       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
296bd882c8aSJames Wright     }
297bd882c8aSJames Wright   }
298bd882c8aSJames Wright   // Reset running sum
299bd882c8aSJames Wright   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
300bd882c8aSJames Wright   t_offsets[0] = 0;
301bd882c8aSJames Wright 
302bd882c8aSJames Wright   // Copy data to device
303bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
304bd882c8aSJames Wright 
305bd882c8aSJames Wright   // Order queue
306bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
307bd882c8aSJames Wright 
308bd882c8aSJames Wright   // -- L-vector indices
309bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context));
310bd882c8aSJames Wright   sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e});
311bd882c8aSJames Wright   // -- Transpose offsets
312bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context));
313bd882c8aSJames Wright   sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e});
314bd882c8aSJames Wright   // -- Transpose indices
315bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context));
316bd882c8aSJames Wright   sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e});
317bd882c8aSJames Wright 
318bd882c8aSJames Wright   // Wait for all copies to complete and handle exceptions
319bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices}));
320bd882c8aSJames Wright 
321bd882c8aSJames Wright   // Cleanup
322bd882c8aSJames Wright   CeedCallBackend(CeedFree(&ind_to_offset));
323bd882c8aSJames Wright   CeedCallBackend(CeedFree(&l_vec_indices));
324bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_offsets));
325bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_indices));
326bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
327bd882c8aSJames Wright }
328bd882c8aSJames Wright 
329bd882c8aSJames Wright //------------------------------------------------------------------------------
330bd882c8aSJames Wright // Create restriction
331bd882c8aSJames Wright //------------------------------------------------------------------------------
33200125730SSebastian Grimberg int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
333dce49693SSebastian Grimberg                                    const CeedInt8 *curl_orients, CeedElemRestriction rstr) {
334bd882c8aSJames Wright   Ceed                      ceed;
335bd882c8aSJames Wright   Ceed_Sycl                *data;
336dd64fc84SJeremy L Thompson   bool                      is_strided;
337dd64fc84SJeremy L Thompson   CeedInt                   num_elem, num_comp, elem_size, comp_stride = 1;
338dd64fc84SJeremy L Thompson   CeedRestrictionType       rstr_type;
339bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
340dd64fc84SJeremy L Thompson 
341dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
342dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
343dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
344dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
345dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
346dce49693SSebastian Grimberg   const CeedInt size       = num_elem * elem_size;
347bd882c8aSJames Wright   CeedInt       strides[3] = {1, size, elem_size};
348bd882c8aSJames Wright 
349dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
35000125730SSebastian Grimberg   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
35100125730SSebastian Grimberg             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
35200125730SSebastian Grimberg 
353bd882c8aSJames Wright   // Stride data
354dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionIsStrided(rstr, &is_strided));
355bd882c8aSJames Wright   if (is_strided) {
356bd882c8aSJames Wright     bool has_backend_strides;
357dd64fc84SJeremy L Thompson 
358dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
359bd882c8aSJames Wright     if (!has_backend_strides) {
360dce49693SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides));
361bd882c8aSJames Wright     }
362bd882c8aSJames Wright   } else {
363dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
364bd882c8aSJames Wright   }
365bd882c8aSJames Wright 
366dce49693SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
367bd882c8aSJames Wright   impl->h_ind           = NULL;
368bd882c8aSJames Wright   impl->h_ind_allocated = NULL;
369bd882c8aSJames Wright   impl->d_ind           = NULL;
370bd882c8aSJames Wright   impl->d_ind_allocated = NULL;
371bd882c8aSJames Wright   impl->d_t_indices     = NULL;
372bd882c8aSJames Wright   impl->d_t_offsets     = NULL;
373bd882c8aSJames Wright   impl->num_nodes       = size;
374bd882c8aSJames Wright   impl->num_elem        = num_elem;
375bd882c8aSJames Wright   impl->num_comp        = num_comp;
376bd882c8aSJames Wright   impl->elem_size       = elem_size;
377bd882c8aSJames Wright   impl->comp_stride     = comp_stride;
378bd882c8aSJames Wright   impl->strides[0]      = strides[0];
379bd882c8aSJames Wright   impl->strides[1]      = strides[1];
380bd882c8aSJames Wright   impl->strides[2]      = strides[2];
381dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetData(rstr, impl));
382*22eb1385SJeremy L Thompson 
383*22eb1385SJeremy L Thompson   // Set layouts
384*22eb1385SJeremy L Thompson   {
385*22eb1385SJeremy L Thompson     bool    has_backend_strides;
386*22eb1385SJeremy L Thompson     CeedInt layout[3] = {1, size, elem_size};
387*22eb1385SJeremy L Thompson 
388dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout));
389*22eb1385SJeremy L Thompson     if (rstr_type == CEED_RESTRICTION_STRIDED) {
390*22eb1385SJeremy L Thompson       CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
391*22eb1385SJeremy L Thompson       if (has_backend_strides) {
392*22eb1385SJeremy L Thompson         CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout));
393*22eb1385SJeremy L Thompson       }
394*22eb1385SJeremy L Thompson     }
395*22eb1385SJeremy L Thompson   }
396bd882c8aSJames Wright 
397bd882c8aSJames Wright   // Set up device indices/offset arrays
398dd64fc84SJeremy L Thompson   if (mem_type == CEED_MEM_HOST) {
399bd882c8aSJames Wright     switch (copy_mode) {
400bd882c8aSJames Wright       case CEED_OWN_POINTER:
401bd882c8aSJames Wright         impl->h_ind_allocated = (CeedInt *)indices;
402bd882c8aSJames Wright         impl->h_ind           = (CeedInt *)indices;
403bd882c8aSJames Wright         break;
404bd882c8aSJames Wright       case CEED_USE_POINTER:
405bd882c8aSJames Wright         impl->h_ind = (CeedInt *)indices;
406bd882c8aSJames Wright         break;
407bd882c8aSJames Wright       case CEED_COPY_VALUES:
408bd882c8aSJames Wright         if (indices != NULL) {
409bd882c8aSJames Wright           CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
410bd882c8aSJames Wright           memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
411bd882c8aSJames Wright           impl->h_ind = impl->h_ind_allocated;
412bd882c8aSJames Wright         }
413bd882c8aSJames Wright         break;
414bd882c8aSJames Wright     }
415bd882c8aSJames Wright     if (indices != NULL) {
416bd882c8aSJames Wright       CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
417bd882c8aSJames Wright       impl->d_ind_allocated = impl->d_ind;  // We own the device memory
418bd882c8aSJames Wright       // Order queue
419bd882c8aSJames Wright       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
420bd882c8aSJames Wright       // Copy from host to device
421bd882c8aSJames Wright       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
422bd882c8aSJames Wright       // Wait for copy to finish and handle exceptions
423bd882c8aSJames Wright       CeedCallSycl(ceed, copy_event.wait_and_throw());
424dce49693SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, indices));
425bd882c8aSJames Wright     }
426dd64fc84SJeremy L Thompson   } else if (mem_type == CEED_MEM_DEVICE) {
427bd882c8aSJames Wright     switch (copy_mode) {
428bd882c8aSJames Wright       case CEED_COPY_VALUES:
429bd882c8aSJames Wright         if (indices != NULL) {
430bd882c8aSJames Wright           CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
431bd882c8aSJames Wright           impl->d_ind_allocated = impl->d_ind;  // We own the device memory
432bd882c8aSJames Wright                                                 // Copy from device to device
433bd882c8aSJames Wright           // Order queue
434bd882c8aSJames Wright           sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
435bd882c8aSJames Wright           sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
436bd882c8aSJames Wright           // Wait for copy to finish and handle exceptions
437bd882c8aSJames Wright           CeedCallSycl(ceed, copy_event.wait_and_throw());
438bd882c8aSJames Wright         }
439bd882c8aSJames Wright         break;
440bd882c8aSJames Wright       case CEED_OWN_POINTER:
441bd882c8aSJames Wright         impl->d_ind           = (CeedInt *)indices;
442bd882c8aSJames Wright         impl->d_ind_allocated = impl->d_ind;
443bd882c8aSJames Wright         break;
444bd882c8aSJames Wright       case CEED_USE_POINTER:
445bd882c8aSJames Wright         impl->d_ind = (CeedInt *)indices;
446bd882c8aSJames Wright     }
447bd882c8aSJames Wright     if (indices != NULL) {
448bd882c8aSJames Wright       CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
449bd882c8aSJames Wright       // Order queue
450bd882c8aSJames Wright       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
451bd882c8aSJames Wright       // Copy from device to host
452bd882c8aSJames Wright       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e});
453bd882c8aSJames Wright       CeedCallSycl(ceed, copy_event.wait_and_throw());
454bd882c8aSJames Wright       impl->h_ind = impl->h_ind_allocated;
455dce49693SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, indices));
456bd882c8aSJames Wright     }
457bd882c8aSJames Wright   } else {
458bd882c8aSJames Wright     // LCOV_EXCL_START
459bd882c8aSJames Wright     return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
460bd882c8aSJames Wright     // LCOV_EXCL_STOP
461bd882c8aSJames Wright   }
462bd882c8aSJames Wright 
463bd882c8aSJames Wright   // Register backend functions
464dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Sycl));
465dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApply_Sycl));
466dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApply_Sycl));
467dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl));
468dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Sycl));
469bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
470bd882c8aSJames Wright }
471