xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-restriction.sycl.cpp (revision 9bc663991d6482bcb1d60b1f116148f11db83fa1)
// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other
// CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
// files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
8bd882c8aSJames Wright 
9bd882c8aSJames Wright #include <ceed/backend.h>
10bd882c8aSJames Wright #include <ceed/ceed.h>
11bd882c8aSJames Wright #include <ceed/jit-tools.h>
12bd882c8aSJames Wright 
13bd882c8aSJames Wright #include <string>
14bd882c8aSJames Wright #include <sycl/sycl.hpp>
15bd882c8aSJames Wright 
16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
18bd882c8aSJames Wright 
// Kernel name types passed to sycl::queue::parallel_for below (declared, never defined)
// -- L-vector -> E-vector (no transpose)
class CeedElemRestrSyclStridedNT;
class CeedElemRestrSyclOffsetNT;
// -- E-vector -> L-vector (transpose)
class CeedElemRestrSyclStridedT;
class CeedElemRestrSyclOffsetT;

24bd882c8aSJames Wright //------------------------------------------------------------------------------
25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided
26bd882c8aSJames Wright //------------------------------------------------------------------------------
27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
28bd882c8aSJames Wright                                                       CeedScalar *v) {
29bd882c8aSJames Wright   const CeedInt  elem_size    = impl->elem_size;
30bd882c8aSJames Wright   const CeedInt  num_elem     = impl->num_elem;
31bd882c8aSJames Wright   const CeedInt  num_comp     = impl->num_comp;
32bd882c8aSJames Wright   const CeedInt  stride_nodes = impl->strides[0];
33bd882c8aSJames Wright   const CeedInt  stride_comp  = impl->strides[1];
34bd882c8aSJames Wright   const CeedInt  stride_elem  = impl->strides[2];
35bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
36bd882c8aSJames Wright 
371f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
381f4b1b45SUmesh Unnikrishnan 
391f4b1b45SUmesh Unnikrishnan   if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
401f4b1b45SUmesh Unnikrishnan   sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, e, [=](sycl::id<1> node) {
41bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
42bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
43bd882c8aSJames Wright 
44bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
45bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem];
46bd882c8aSJames Wright     }
47bd882c8aSJames Wright   });
48bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
49bd882c8aSJames Wright }
50bd882c8aSJames Wright 
51bd882c8aSJames Wright //------------------------------------------------------------------------------
52bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided
53bd882c8aSJames Wright //------------------------------------------------------------------------------
54bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
55bd882c8aSJames Wright                                                      CeedScalar *v) {
56bd882c8aSJames Wright   const CeedInt  elem_size   = impl->elem_size;
57bd882c8aSJames Wright   const CeedInt  num_elem    = impl->num_elem;
58bd882c8aSJames Wright   const CeedInt  num_comp    = impl->num_comp;
59bd882c8aSJames Wright   const CeedInt  comp_stride = impl->comp_stride;
60f59ebe5eSJeremy L Thompson   const CeedInt *indices     = impl->d_offsets;
61bd882c8aSJames Wright 
62bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
63bd882c8aSJames Wright 
641f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
651f4b1b45SUmesh Unnikrishnan 
661f4b1b45SUmesh Unnikrishnan   if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
671f4b1b45SUmesh Unnikrishnan   sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, e, [=](sycl::id<1> node) {
68bd882c8aSJames Wright     const CeedInt ind      = indices[node];
69bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
70bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
71bd882c8aSJames Wright 
72bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
73bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride];
74bd882c8aSJames Wright     }
75bd882c8aSJames Wright   });
76bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
77bd882c8aSJames Wright }
78bd882c8aSJames Wright 
79bd882c8aSJames Wright //------------------------------------------------------------------------------
80bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided
81bd882c8aSJames Wright //------------------------------------------------------------------------------
82bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
83bd882c8aSJames Wright                                                     CeedScalar *v) {
84bd882c8aSJames Wright   const CeedInt elem_size    = impl->elem_size;
85bd882c8aSJames Wright   const CeedInt num_elem     = impl->num_elem;
86bd882c8aSJames Wright   const CeedInt num_comp     = impl->num_comp;
87bd882c8aSJames Wright   const CeedInt stride_nodes = impl->strides[0];
88bd882c8aSJames Wright   const CeedInt stride_comp  = impl->strides[1];
89bd882c8aSJames Wright   const CeedInt stride_elem  = impl->strides[2];
90bd882c8aSJames Wright 
91bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
92bd882c8aSJames Wright 
931f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
941f4b1b45SUmesh Unnikrishnan 
951f4b1b45SUmesh Unnikrishnan   if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
961f4b1b45SUmesh Unnikrishnan   sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, e, [=](sycl::id<1> node) {
97bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
98bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
99bd882c8aSJames Wright 
100bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
101bd882c8aSJames Wright       v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
102bd882c8aSJames Wright     }
103bd882c8aSJames Wright   });
104bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
105bd882c8aSJames Wright }
106bd882c8aSJames Wright 
107bd882c8aSJames Wright //------------------------------------------------------------------------------
108bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided
109bd882c8aSJames Wright //------------------------------------------------------------------------------
110bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
111bd882c8aSJames Wright                                                    CeedScalar *v) {
112bd882c8aSJames Wright   const CeedInt  num_nodes     = impl->num_nodes;
113bd882c8aSJames Wright   const CeedInt  elem_size     = impl->elem_size;
114bd882c8aSJames Wright   const CeedInt  num_elem      = impl->num_elem;
115bd882c8aSJames Wright   const CeedInt  num_comp      = impl->num_comp;
116bd882c8aSJames Wright   const CeedInt  comp_stride   = impl->comp_stride;
117bd882c8aSJames Wright   const CeedInt *l_vec_indices = impl->d_l_vec_indices;
118bd882c8aSJames Wright   const CeedInt *t_offsets     = impl->d_t_offsets;
119bd882c8aSJames Wright   const CeedInt *t_indices     = impl->d_t_indices;
120bd882c8aSJames Wright 
121bd882c8aSJames Wright   sycl::range<1> kernel_range(num_nodes * num_comp);
122bd882c8aSJames Wright 
1231f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
1241f4b1b45SUmesh Unnikrishnan 
1251f4b1b45SUmesh Unnikrishnan   if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
1261f4b1b45SUmesh Unnikrishnan   sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, e, [=](sycl::id<1> id) {
127bd882c8aSJames Wright     const CeedInt node    = id % num_nodes;
128bd882c8aSJames Wright     const CeedInt comp    = id / num_nodes;
129bd882c8aSJames Wright     const CeedInt ind     = l_vec_indices[node];
130bd882c8aSJames Wright     const CeedInt range_1 = t_offsets[node];
131bd882c8aSJames Wright     const CeedInt range_N = t_offsets[node + 1];
132bd882c8aSJames Wright     CeedScalar    value   = 0.0;
133bd882c8aSJames Wright 
134bd882c8aSJames Wright     for (CeedInt j = range_1; j < range_N; j++) {
135bd882c8aSJames Wright       const CeedInt t_ind    = t_indices[j];
136bd882c8aSJames Wright       CeedInt       loc_node = t_ind % elem_size;
137bd882c8aSJames Wright       CeedInt       elem     = t_ind / elem_size;
138bd882c8aSJames Wright 
139bd882c8aSJames Wright       value += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
140bd882c8aSJames Wright     }
141bd882c8aSJames Wright     v[ind + comp * comp_stride] += value;
142bd882c8aSJames Wright   });
143bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
144bd882c8aSJames Wright }
145bd882c8aSJames Wright 
146bd882c8aSJames Wright //------------------------------------------------------------------------------
147bd882c8aSJames Wright // Apply restriction
148bd882c8aSJames Wright //------------------------------------------------------------------------------
149dce49693SSebastian Grimberg static int CeedElemRestrictionApply_Sycl(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
150bd882c8aSJames Wright   Ceed                      ceed;
151bd882c8aSJames Wright   Ceed_Sycl                *data;
152dd64fc84SJeremy L Thompson   const CeedScalar         *d_u;
153dd64fc84SJeremy L Thompson   CeedScalar               *d_v;
154dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
155dd64fc84SJeremy L Thompson 
156dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
157dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
158bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
159bd882c8aSJames Wright 
160bd882c8aSJames Wright   // Get vectors
161bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
162bd882c8aSJames Wright   if (t_mode == CEED_TRANSPOSE) {
163bd882c8aSJames Wright     // Sum into for transpose mode, e-vec to l-vec
164bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
165bd882c8aSJames Wright   } else {
166bd882c8aSJames Wright     // Overwrite for notranspose mode, l-vec to e-vec
167bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
168bd882c8aSJames Wright   }
169bd882c8aSJames Wright 
170bd882c8aSJames Wright   // Restrict
171bd882c8aSJames Wright   if (t_mode == CEED_NOTRANSPOSE) {
172bd882c8aSJames Wright     // L-vector -> E-vector
173f59ebe5eSJeremy L Thompson     if (impl->d_offsets) {
174bd882c8aSJames Wright       // -- Offsets provided
175bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
176bd882c8aSJames Wright     } else {
177bd882c8aSJames Wright       // -- Strided restriction
178bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
179bd882c8aSJames Wright     }
180bd882c8aSJames Wright   } else {
181bd882c8aSJames Wright     // E-vector -> L-vector
182f59ebe5eSJeremy L Thompson     if (impl->d_offsets) {
183bd882c8aSJames Wright       // -- Offsets provided
184bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
185bd882c8aSJames Wright     } else {
186bd882c8aSJames Wright       // -- Strided restriction
187bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
188bd882c8aSJames Wright     }
189bd882c8aSJames Wright   }
190bd882c8aSJames Wright   // Wait for queues to be completed. NOTE: This may not be necessary
191bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
192bd882c8aSJames Wright 
193bd882c8aSJames Wright   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
194bd882c8aSJames Wright 
195bd882c8aSJames Wright   // Restore arrays
196bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
197bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
198*9bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
199bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
200bd882c8aSJames Wright }
201bd882c8aSJames Wright 
202bd882c8aSJames Wright //------------------------------------------------------------------------------
203bd882c8aSJames Wright // Get offsets
204bd882c8aSJames Wright //------------------------------------------------------------------------------
205dce49693SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction rstr, CeedMemType m_type, const CeedInt **offsets) {
206bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
207dd64fc84SJeremy L Thompson 
208dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
209bd882c8aSJames Wright 
210bd882c8aSJames Wright   switch (m_type) {
211bd882c8aSJames Wright     case CEED_MEM_HOST:
212f59ebe5eSJeremy L Thompson       *offsets = impl->h_offsets;
213bd882c8aSJames Wright       break;
214bd882c8aSJames Wright     case CEED_MEM_DEVICE:
215f59ebe5eSJeremy L Thompson       *offsets = impl->d_offsets;
216bd882c8aSJames Wright       break;
217bd882c8aSJames Wright   }
218bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
219bd882c8aSJames Wright }
220bd882c8aSJames Wright 
221bd882c8aSJames Wright //------------------------------------------------------------------------------
222bd882c8aSJames Wright // Destroy restriction
223bd882c8aSJames Wright //------------------------------------------------------------------------------
224dce49693SSebastian Grimberg static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction rstr) {
225bd882c8aSJames Wright   Ceed                      ceed;
226bd882c8aSJames Wright   Ceed_Sycl                *data;
227dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
228dd64fc84SJeremy L Thompson 
229dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
230dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
231bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
232bd882c8aSJames Wright 
233bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
234bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
235bd882c8aSJames Wright 
236f59ebe5eSJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_offsets_owned));
237f59ebe5eSJeremy L Thompson   CeedCallSycl(ceed, sycl::free(impl->d_offsets_owned, data->sycl_context));
238bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context));
239bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context));
240bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context));
241bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
242*9bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
243bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
244bd882c8aSJames Wright }
245bd882c8aSJames Wright 
246bd882c8aSJames Wright //------------------------------------------------------------------------------
247bd882c8aSJames Wright // Create transpose offsets and indices
248bd882c8aSJames Wright //------------------------------------------------------------------------------
249dce49693SSebastian Grimberg static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction rstr, const CeedInt *indices) {
250bd882c8aSJames Wright   Ceed                      ceed;
251dd64fc84SJeremy L Thompson   Ceed_Sycl                *data;
252dd64fc84SJeremy L Thompson   bool                     *is_node;
253bd882c8aSJames Wright   CeedSize                  l_size;
254dd64fc84SJeremy L Thompson   CeedInt                   num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
255dd64fc84SJeremy L Thompson   CeedElemRestriction_Sycl *impl;
256dd64fc84SJeremy L Thompson 
257dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
258dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
259dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
260dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
261dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
262dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
263bd882c8aSJames Wright 
264bd882c8aSJames Wright   // Count num_nodes
265bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &is_node));
266bd882c8aSJames Wright   const CeedInt size_indices = num_elem * elem_size;
267dd64fc84SJeremy L Thompson 
268bd882c8aSJames Wright   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
269bd882c8aSJames Wright   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
270bd882c8aSJames Wright   impl->num_nodes = num_nodes;
271bd882c8aSJames Wright 
272bd882c8aSJames Wright   // L-vector offsets array
273bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
274bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
275dd64fc84SJeremy L Thompson   for (CeedInt i = 0, j = 0; i < l_size; i++) {
276bd882c8aSJames Wright     if (is_node[i]) {
277bd882c8aSJames Wright       l_vec_indices[j] = i;
278bd882c8aSJames Wright       ind_to_offset[i] = j++;
279bd882c8aSJames Wright     }
280bd882c8aSJames Wright   }
281bd882c8aSJames Wright   CeedCallBackend(CeedFree(&is_node));
282bd882c8aSJames Wright 
283bd882c8aSJames Wright   // Compute transpose offsets and indices
284bd882c8aSJames Wright   const CeedInt size_offsets = num_nodes + 1;
285dd64fc84SJeremy L Thompson 
286bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
287bd882c8aSJames Wright   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
288bd882c8aSJames Wright   // Count node multiplicity
289bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
290bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
291bd882c8aSJames Wright   }
292bd882c8aSJames Wright   // Convert to running sum
293bd882c8aSJames Wright   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
294bd882c8aSJames Wright   // List all E-vec indices associated with L-vec node
295bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
296bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) {
297bd882c8aSJames Wright       const CeedInt lid                          = elem_size * e + i;
298bd882c8aSJames Wright       const CeedInt gid                          = indices[lid];
299bd882c8aSJames Wright       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
300bd882c8aSJames Wright     }
301bd882c8aSJames Wright   }
302bd882c8aSJames Wright   // Reset running sum
303bd882c8aSJames Wright   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
304bd882c8aSJames Wright   t_offsets[0] = 0;
305bd882c8aSJames Wright 
306bd882c8aSJames Wright   // Copy data to device
307bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
308bd882c8aSJames Wright 
3091f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
3101f4b1b45SUmesh Unnikrishnan 
3111f4b1b45SUmesh Unnikrishnan   if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};
312bd882c8aSJames Wright 
313bd882c8aSJames Wright   // -- L-vector indices
314bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context));
3151f4b1b45SUmesh Unnikrishnan   sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, e);
316bd882c8aSJames Wright   // -- Transpose offsets
317bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context));
3181f4b1b45SUmesh Unnikrishnan   sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, e);
319bd882c8aSJames Wright   // -- Transpose indices
320bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context));
3211f4b1b45SUmesh Unnikrishnan   sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, e);
322bd882c8aSJames Wright 
323bd882c8aSJames Wright   // Wait for all copies to complete and handle exceptions
324bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices}));
325bd882c8aSJames Wright 
326bd882c8aSJames Wright   // Cleanup
327bd882c8aSJames Wright   CeedCallBackend(CeedFree(&ind_to_offset));
328bd882c8aSJames Wright   CeedCallBackend(CeedFree(&l_vec_indices));
329bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_offsets));
330bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_indices));
331*9bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
332bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
333bd882c8aSJames Wright }
334bd882c8aSJames Wright 
335bd882c8aSJames Wright //------------------------------------------------------------------------------
336bd882c8aSJames Wright // Create restriction
337bd882c8aSJames Wright //------------------------------------------------------------------------------
338f59ebe5eSJeremy L Thompson int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients,
339dce49693SSebastian Grimberg                                    const CeedInt8 *curl_orients, CeedElemRestriction rstr) {
340bd882c8aSJames Wright   Ceed                      ceed;
341bd882c8aSJames Wright   Ceed_Sycl                *data;
342dd64fc84SJeremy L Thompson   bool                      is_strided;
343dd64fc84SJeremy L Thompson   CeedInt                   num_elem, num_comp, elem_size, comp_stride = 1;
344dd64fc84SJeremy L Thompson   CeedRestrictionType       rstr_type;
345bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
346dd64fc84SJeremy L Thompson 
347dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
348dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
349dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
350dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
351dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
352dce49693SSebastian Grimberg   const CeedInt size       = num_elem * elem_size;
353bd882c8aSJames Wright   CeedInt       strides[3] = {1, size, elem_size};
354bd882c8aSJames Wright 
355dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
35600125730SSebastian Grimberg   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
35700125730SSebastian Grimberg             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
35800125730SSebastian Grimberg 
359bd882c8aSJames Wright   // Stride data
360dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionIsStrided(rstr, &is_strided));
361bd882c8aSJames Wright   if (is_strided) {
362bd882c8aSJames Wright     bool has_backend_strides;
363dd64fc84SJeremy L Thompson 
364dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
365bd882c8aSJames Wright     if (!has_backend_strides) {
36656c48462SJeremy L Thompson       CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides));
367bd882c8aSJames Wright     }
368bd882c8aSJames Wright   } else {
369dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
370bd882c8aSJames Wright   }
371bd882c8aSJames Wright 
372dce49693SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
373bd882c8aSJames Wright   impl->num_nodes   = size;
374bd882c8aSJames Wright   impl->num_elem    = num_elem;
375bd882c8aSJames Wright   impl->num_comp    = num_comp;
376bd882c8aSJames Wright   impl->elem_size   = elem_size;
377bd882c8aSJames Wright   impl->comp_stride = comp_stride;
378bd882c8aSJames Wright   impl->strides[0]  = strides[0];
379bd882c8aSJames Wright   impl->strides[1]  = strides[1];
380bd882c8aSJames Wright   impl->strides[2]  = strides[2];
381dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetData(rstr, impl));
38222eb1385SJeremy L Thompson 
38322eb1385SJeremy L Thompson   // Set layouts
38422eb1385SJeremy L Thompson   {
38522eb1385SJeremy L Thompson     bool    has_backend_strides;
38622eb1385SJeremy L Thompson     CeedInt layout[3] = {1, size, elem_size};
38722eb1385SJeremy L Thompson 
388dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout));
38922eb1385SJeremy L Thompson     if (rstr_type == CEED_RESTRICTION_STRIDED) {
39022eb1385SJeremy L Thompson       CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
39122eb1385SJeremy L Thompson       if (has_backend_strides) {
39222eb1385SJeremy L Thompson         CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout));
39322eb1385SJeremy L Thompson       }
39422eb1385SJeremy L Thompson     }
39522eb1385SJeremy L Thompson   }
396bd882c8aSJames Wright 
397bd882c8aSJames Wright   // Set up device indices/offset arrays
3989d1bceceSJames Wright   switch (mem_type) {
39942b3fd1bSJeremy L Thompson     case CEED_MEM_HOST: {
400bd882c8aSJames Wright       switch (copy_mode) {
401f59ebe5eSJeremy L Thompson         case CEED_COPY_VALUES:
402f59ebe5eSJeremy L Thompson           if (offsets != NULL) {
403f59ebe5eSJeremy L Thompson             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned));
404f59ebe5eSJeremy L Thompson             memcpy(impl->h_offsets_owned, offsets, elem_size * num_elem * sizeof(CeedInt));
405f59ebe5eSJeremy L Thompson             impl->h_offsets_borrowed = NULL;
406f59ebe5eSJeremy L Thompson             impl->h_offsets          = impl->h_offsets_owned;
407f59ebe5eSJeremy L Thompson           }
408f59ebe5eSJeremy L Thompson           break;
409bd882c8aSJames Wright         case CEED_OWN_POINTER:
410f59ebe5eSJeremy L Thompson           impl->h_offsets_owned    = (CeedInt *)offsets;
411f59ebe5eSJeremy L Thompson           impl->h_offsets_borrowed = NULL;
412f59ebe5eSJeremy L Thompson           impl->h_offsets          = impl->h_offsets_owned;
413bd882c8aSJames Wright           break;
414bd882c8aSJames Wright         case CEED_USE_POINTER:
415f59ebe5eSJeremy L Thompson           impl->h_offsets_owned    = NULL;
416f59ebe5eSJeremy L Thompson           impl->h_offsets_borrowed = (CeedInt *)offsets;
417f59ebe5eSJeremy L Thompson           impl->h_offsets          = impl->h_offsets_borrowed;
418bd882c8aSJames Wright           break;
419bd882c8aSJames Wright       }
420f59ebe5eSJeremy L Thompson       if (offsets != NULL) {
421f59ebe5eSJeremy L Thompson         CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
422bd882c8aSJames Wright         // Copy from host to device
423f59ebe5eSJeremy L Thompson         // -- Order queue
424f59ebe5eSJeremy L Thompson         sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
425f59ebe5eSJeremy L Thompson         sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->h_offsets, impl->d_offsets_owned, size, {e});
426f59ebe5eSJeremy L Thompson         // -- Wait for copy to finish and handle exceptions
427bd882c8aSJames Wright         CeedCallSycl(ceed, copy_event.wait_and_throw());
428f59ebe5eSJeremy L Thompson         impl->d_offsets = impl->d_offsets_owned;
429f59ebe5eSJeremy L Thompson         CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets));
430bd882c8aSJames Wright       }
43142b3fd1bSJeremy L Thompson     } break;
4329d1bceceSJames Wright     case CEED_MEM_DEVICE: {
433bd882c8aSJames Wright       switch (copy_mode) {
434bd882c8aSJames Wright         case CEED_COPY_VALUES:
435f59ebe5eSJeremy L Thompson           if (offsets != NULL) {
436f59ebe5eSJeremy L Thompson             CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
437bd882c8aSJames Wright             // Copy from device to device
438f59ebe5eSJeremy L Thompson             // -- Order queue
439bd882c8aSJames Wright             sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
440f59ebe5eSJeremy L Thompson             sycl::event copy_event = data->sycl_queue.copy<CeedInt>(offsets, impl->d_offsets_owned, size, {e});
441f59ebe5eSJeremy L Thompson             // -- Wait for copy to finish and handle exceptions
442bd882c8aSJames Wright             CeedCallSycl(ceed, copy_event.wait_and_throw());
443f59ebe5eSJeremy L Thompson             impl->d_offsets = impl->d_offsets_owned;
444bd882c8aSJames Wright           }
445bd882c8aSJames Wright           break;
446bd882c8aSJames Wright         case CEED_OWN_POINTER:
447f59ebe5eSJeremy L Thompson           impl->d_offsets_owned    = (CeedInt *)offsets;
448f59ebe5eSJeremy L Thompson           impl->d_offsets_borrowed = NULL;
449f59ebe5eSJeremy L Thompson           impl->d_offsets          = impl->d_offsets_owned;
450bd882c8aSJames Wright           break;
451bd882c8aSJames Wright         case CEED_USE_POINTER:
452f59ebe5eSJeremy L Thompson           impl->d_offsets_owned    = NULL;
453f59ebe5eSJeremy L Thompson           impl->d_offsets_borrowed = (CeedInt *)offsets;
454f59ebe5eSJeremy L Thompson           impl->d_offsets          = impl->d_offsets_borrowed;
455bd882c8aSJames Wright       }
456f59ebe5eSJeremy L Thompson       if (offsets != NULL) {
457f59ebe5eSJeremy L Thompson         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned));
458bd882c8aSJames Wright         // Copy from device to host
459f59ebe5eSJeremy L Thompson         // -- Order queue
460f59ebe5eSJeremy L Thompson         sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
461f59ebe5eSJeremy L Thompson         sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_offsets, impl->h_offsets_owned, elem_size * num_elem, {e});
462f59ebe5eSJeremy L Thompson         // -- Wait for copy to finish and handle exceptions
463bd882c8aSJames Wright         CeedCallSycl(ceed, copy_event.wait_and_throw());
464f59ebe5eSJeremy L Thompson         impl->h_offsets = impl->h_offsets_owned;
465f59ebe5eSJeremy L Thompson         CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets));
466bd882c8aSJames Wright       }
467bd882c8aSJames Wright     }
4689d1bceceSJames Wright   }
469bd882c8aSJames Wright 
470bd882c8aSJames Wright   // Register backend functions
471dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Sycl));
472dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApply_Sycl));
473dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApply_Sycl));
474dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl));
475dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Sycl));
476*9bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
477bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
478bd882c8aSJames Wright }
479