xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-restriction.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
2*bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3*bd882c8aSJames Wright // files for details.
4*bd882c8aSJames Wright //
5*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
6*bd882c8aSJames Wright //
7*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
8*bd882c8aSJames Wright 
9*bd882c8aSJames Wright #include <ceed/backend.h>
10*bd882c8aSJames Wright #include <ceed/ceed.h>
11*bd882c8aSJames Wright #include <ceed/jit-tools.h>
12*bd882c8aSJames Wright 
13*bd882c8aSJames Wright #include <string>
14*bd882c8aSJames Wright #include <sycl/sycl.hpp>
15*bd882c8aSJames Wright 
16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
17*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
18*bd882c8aSJames Wright 
// Kernel name types for the SYCL restriction kernels below.
// Suffix NT is used by the no-transpose (L-vector -> E-vector) kernels,
// suffix T by the transpose (E-vector -> L-vector) kernels.
class CeedElemRestrSyclStridedNT;
class CeedElemRestrSyclStridedT;
class CeedElemRestrSyclOffsetNT;
class CeedElemRestrSyclOffsetT;
23*bd882c8aSJames Wright 
24*bd882c8aSJames Wright //------------------------------------------------------------------------------
25*bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided
26*bd882c8aSJames Wright //------------------------------------------------------------------------------
27*bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
28*bd882c8aSJames Wright                                                       CeedScalar *v) {
29*bd882c8aSJames Wright   const CeedInt  elem_size    = impl->elem_size;
30*bd882c8aSJames Wright   const CeedInt  num_elem     = impl->num_elem;
31*bd882c8aSJames Wright   const CeedInt  num_comp     = impl->num_comp;
32*bd882c8aSJames Wright   const CeedInt  stride_nodes = impl->strides[0];
33*bd882c8aSJames Wright   const CeedInt  stride_comp  = impl->strides[1];
34*bd882c8aSJames Wright   const CeedInt  stride_elem  = impl->strides[2];
35*bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
36*bd882c8aSJames Wright 
37*bd882c8aSJames Wright   // Order queue
38*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
39*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) {
40*bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
41*bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
42*bd882c8aSJames Wright 
43*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
44*bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem];
45*bd882c8aSJames Wright     }
46*bd882c8aSJames Wright   });
47*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
48*bd882c8aSJames Wright }
49*bd882c8aSJames Wright 
50*bd882c8aSJames Wright //------------------------------------------------------------------------------
51*bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided
52*bd882c8aSJames Wright //------------------------------------------------------------------------------
53*bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
54*bd882c8aSJames Wright                                                      CeedScalar *v) {
55*bd882c8aSJames Wright   const CeedInt elem_size   = impl->elem_size;
56*bd882c8aSJames Wright   const CeedInt num_elem    = impl->num_elem;
57*bd882c8aSJames Wright   const CeedInt num_comp    = impl->num_comp;
58*bd882c8aSJames Wright   const CeedInt comp_stride = impl->comp_stride;
59*bd882c8aSJames Wright 
60*bd882c8aSJames Wright   const CeedInt *indices = impl->d_ind;
61*bd882c8aSJames Wright 
62*bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
63*bd882c8aSJames Wright 
64*bd882c8aSJames Wright   // Order queue
65*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
66*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) {
67*bd882c8aSJames Wright     const CeedInt ind      = indices[node];
68*bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
69*bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
70*bd882c8aSJames Wright 
71*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
72*bd882c8aSJames Wright       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride];
73*bd882c8aSJames Wright     }
74*bd882c8aSJames Wright   });
75*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
76*bd882c8aSJames Wright }
77*bd882c8aSJames Wright 
78*bd882c8aSJames Wright //------------------------------------------------------------------------------
79*bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided
80*bd882c8aSJames Wright //------------------------------------------------------------------------------
81*bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
82*bd882c8aSJames Wright                                                     CeedScalar *v) {
83*bd882c8aSJames Wright   const CeedInt elem_size    = impl->elem_size;
84*bd882c8aSJames Wright   const CeedInt num_elem     = impl->num_elem;
85*bd882c8aSJames Wright   const CeedInt num_comp     = impl->num_comp;
86*bd882c8aSJames Wright   const CeedInt stride_nodes = impl->strides[0];
87*bd882c8aSJames Wright   const CeedInt stride_comp  = impl->strides[1];
88*bd882c8aSJames Wright   const CeedInt stride_elem  = impl->strides[2];
89*bd882c8aSJames Wright 
90*bd882c8aSJames Wright   sycl::range<1> kernel_range(num_elem * elem_size);
91*bd882c8aSJames Wright 
92*bd882c8aSJames Wright   // Order queue
93*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
94*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) {
95*bd882c8aSJames Wright     const CeedInt loc_node = node % elem_size;
96*bd882c8aSJames Wright     const CeedInt elem     = node / elem_size;
97*bd882c8aSJames Wright 
98*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
99*bd882c8aSJames Wright       v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
100*bd882c8aSJames Wright     }
101*bd882c8aSJames Wright   });
102*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
103*bd882c8aSJames Wright }
104*bd882c8aSJames Wright 
105*bd882c8aSJames Wright //------------------------------------------------------------------------------
106*bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided
107*bd882c8aSJames Wright //------------------------------------------------------------------------------
108*bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
109*bd882c8aSJames Wright                                                    CeedScalar *v) {
110*bd882c8aSJames Wright   const CeedInt num_nodes   = impl->num_nodes;
111*bd882c8aSJames Wright   const CeedInt elem_size   = impl->elem_size;
112*bd882c8aSJames Wright   const CeedInt num_elem    = impl->num_elem;
113*bd882c8aSJames Wright   const CeedInt num_comp    = impl->num_comp;
114*bd882c8aSJames Wright   const CeedInt comp_stride = impl->comp_stride;
115*bd882c8aSJames Wright 
116*bd882c8aSJames Wright   const CeedInt *l_vec_indices = impl->d_l_vec_indices;
117*bd882c8aSJames Wright   const CeedInt *t_offsets     = impl->d_t_offsets;
118*bd882c8aSJames Wright   const CeedInt *t_indices     = impl->d_t_indices;
119*bd882c8aSJames Wright 
120*bd882c8aSJames Wright   sycl::range<1> kernel_range(num_nodes * num_comp);
121*bd882c8aSJames Wright 
122*bd882c8aSJames Wright   // Order queue
123*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
124*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) {
125*bd882c8aSJames Wright     const CeedInt node    = id % num_nodes;
126*bd882c8aSJames Wright     const CeedInt comp    = id / num_nodes;
127*bd882c8aSJames Wright     const CeedInt ind     = l_vec_indices[node];
128*bd882c8aSJames Wright     const CeedInt range_1 = t_offsets[node];
129*bd882c8aSJames Wright     const CeedInt range_N = t_offsets[node + 1];
130*bd882c8aSJames Wright 
131*bd882c8aSJames Wright     CeedScalar value = 0.0;
132*bd882c8aSJames Wright 
133*bd882c8aSJames Wright     for (CeedInt j = range_1; j < range_N; j++) {
134*bd882c8aSJames Wright       const CeedInt t_ind    = t_indices[j];
135*bd882c8aSJames Wright       CeedInt       loc_node = t_ind % elem_size;
136*bd882c8aSJames Wright       CeedInt       elem     = t_ind / elem_size;
137*bd882c8aSJames Wright 
138*bd882c8aSJames Wright       value += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
139*bd882c8aSJames Wright     }
140*bd882c8aSJames Wright     v[ind + comp * comp_stride] += value;
141*bd882c8aSJames Wright   });
142*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
143*bd882c8aSJames Wright }
144*bd882c8aSJames Wright 
145*bd882c8aSJames Wright //------------------------------------------------------------------------------
146*bd882c8aSJames Wright // Apply restriction
147*bd882c8aSJames Wright //------------------------------------------------------------------------------
148*bd882c8aSJames Wright static int CeedElemRestrictionApply_Sycl(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
149*bd882c8aSJames Wright   Ceed ceed;
150*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
151*bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
152*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
153*bd882c8aSJames Wright   Ceed_Sycl *data;
154*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
155*bd882c8aSJames Wright 
156*bd882c8aSJames Wright   // Get vectors
157*bd882c8aSJames Wright   const CeedScalar *d_u;
158*bd882c8aSJames Wright   CeedScalar       *d_v;
159*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
160*bd882c8aSJames Wright   if (t_mode == CEED_TRANSPOSE) {
161*bd882c8aSJames Wright     // Sum into for transpose mode, e-vec to l-vec
162*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
163*bd882c8aSJames Wright   } else {
164*bd882c8aSJames Wright     // Overwrite for notranspose mode, l-vec to e-vec
165*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
166*bd882c8aSJames Wright   }
167*bd882c8aSJames Wright 
168*bd882c8aSJames Wright   // Restrict
169*bd882c8aSJames Wright   if (t_mode == CEED_NOTRANSPOSE) {
170*bd882c8aSJames Wright     // L-vector -> E-vector
171*bd882c8aSJames Wright     if (impl->d_ind) {
172*bd882c8aSJames Wright       // -- Offsets provided
173*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
174*bd882c8aSJames Wright     } else {
175*bd882c8aSJames Wright       // -- Strided restriction
176*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
177*bd882c8aSJames Wright     }
178*bd882c8aSJames Wright   } else {
179*bd882c8aSJames Wright     // E-vector -> L-vector
180*bd882c8aSJames Wright     if (impl->d_ind) {
181*bd882c8aSJames Wright       // -- Offsets provided
182*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
183*bd882c8aSJames Wright     } else {
184*bd882c8aSJames Wright       // -- Strided restriction
185*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
186*bd882c8aSJames Wright     }
187*bd882c8aSJames Wright   }
188*bd882c8aSJames Wright   // Wait for queues to be completed. NOTE: This may not be necessary
189*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
190*bd882c8aSJames Wright 
191*bd882c8aSJames Wright   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
192*bd882c8aSJames Wright 
193*bd882c8aSJames Wright   // Restore arrays
194*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
195*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
196*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
197*bd882c8aSJames Wright }
198*bd882c8aSJames Wright 
199*bd882c8aSJames Wright //------------------------------------------------------------------------------
200*bd882c8aSJames Wright // Get offsets
201*bd882c8aSJames Wright //------------------------------------------------------------------------------
202*bd882c8aSJames Wright static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction r, CeedMemType m_type, const CeedInt **offsets) {
203*bd882c8aSJames Wright   Ceed ceed;
204*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
205*bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
206*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
207*bd882c8aSJames Wright 
208*bd882c8aSJames Wright   switch (m_type) {
209*bd882c8aSJames Wright     case CEED_MEM_HOST:
210*bd882c8aSJames Wright       *offsets = impl->h_ind;
211*bd882c8aSJames Wright       break;
212*bd882c8aSJames Wright     case CEED_MEM_DEVICE:
213*bd882c8aSJames Wright       *offsets = impl->d_ind;
214*bd882c8aSJames Wright       break;
215*bd882c8aSJames Wright   }
216*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
217*bd882c8aSJames Wright }
218*bd882c8aSJames Wright 
219*bd882c8aSJames Wright //------------------------------------------------------------------------------
220*bd882c8aSJames Wright // Destroy restriction
221*bd882c8aSJames Wright //------------------------------------------------------------------------------
222*bd882c8aSJames Wright static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction r) {
223*bd882c8aSJames Wright   Ceed ceed;
224*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
225*bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
226*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
227*bd882c8aSJames Wright   Ceed_Sycl *data;
228*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
229*bd882c8aSJames Wright 
230*bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
231*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
232*bd882c8aSJames Wright 
233*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
234*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context));
235*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context));
236*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context));
237*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context));
238*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
239*bd882c8aSJames Wright 
240*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
241*bd882c8aSJames Wright }
242*bd882c8aSJames Wright 
243*bd882c8aSJames Wright //------------------------------------------------------------------------------
244*bd882c8aSJames Wright // Create transpose offsets and indices
245*bd882c8aSJames Wright //------------------------------------------------------------------------------
246*bd882c8aSJames Wright static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction r, const CeedInt *indices) {
247*bd882c8aSJames Wright   Ceed ceed;
248*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
249*bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
250*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
251*bd882c8aSJames Wright   CeedSize l_size;
252*bd882c8aSJames Wright   CeedInt  num_elem, elem_size, num_comp;
253*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
254*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
255*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
256*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
257*bd882c8aSJames Wright 
258*bd882c8aSJames Wright   // Count num_nodes
259*bd882c8aSJames Wright   bool *is_node;
260*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &is_node));
261*bd882c8aSJames Wright   const CeedInt size_indices = num_elem * elem_size;
262*bd882c8aSJames Wright   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
263*bd882c8aSJames Wright   CeedInt num_nodes = 0;
264*bd882c8aSJames Wright   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
265*bd882c8aSJames Wright   impl->num_nodes = num_nodes;
266*bd882c8aSJames Wright 
267*bd882c8aSJames Wright   // L-vector offsets array
268*bd882c8aSJames Wright   CeedInt *ind_to_offset, *l_vec_indices;
269*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
270*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
271*bd882c8aSJames Wright   CeedInt j = 0;
272*bd882c8aSJames Wright   for (CeedInt i = 0; i < l_size; i++) {
273*bd882c8aSJames Wright     if (is_node[i]) {
274*bd882c8aSJames Wright       l_vec_indices[j] = i;
275*bd882c8aSJames Wright       ind_to_offset[i] = j++;
276*bd882c8aSJames Wright     }
277*bd882c8aSJames Wright   }
278*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&is_node));
279*bd882c8aSJames Wright 
280*bd882c8aSJames Wright   // Compute transpose offsets and indices
281*bd882c8aSJames Wright   const CeedInt size_offsets = num_nodes + 1;
282*bd882c8aSJames Wright   CeedInt      *t_offsets;
283*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
284*bd882c8aSJames Wright   CeedInt *t_indices;
285*bd882c8aSJames Wright   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
286*bd882c8aSJames Wright   // Count node multiplicity
287*bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
288*bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
289*bd882c8aSJames Wright   }
290*bd882c8aSJames Wright   // Convert to running sum
291*bd882c8aSJames Wright   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
292*bd882c8aSJames Wright   // List all E-vec indices associated with L-vec node
293*bd882c8aSJames Wright   for (CeedInt e = 0; e < num_elem; ++e) {
294*bd882c8aSJames Wright     for (CeedInt i = 0; i < elem_size; ++i) {
295*bd882c8aSJames Wright       const CeedInt lid                          = elem_size * e + i;
296*bd882c8aSJames Wright       const CeedInt gid                          = indices[lid];
297*bd882c8aSJames Wright       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
298*bd882c8aSJames Wright     }
299*bd882c8aSJames Wright   }
300*bd882c8aSJames Wright   // Reset running sum
301*bd882c8aSJames Wright   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
302*bd882c8aSJames Wright   t_offsets[0] = 0;
303*bd882c8aSJames Wright 
304*bd882c8aSJames Wright   // Copy data to device
305*bd882c8aSJames Wright   Ceed_Sycl *data;
306*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
307*bd882c8aSJames Wright 
308*bd882c8aSJames Wright   // Order queue
309*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
310*bd882c8aSJames Wright 
311*bd882c8aSJames Wright   // -- L-vector indices
312*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context));
313*bd882c8aSJames Wright   sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e});
314*bd882c8aSJames Wright   // -- Transpose offsets
315*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context));
316*bd882c8aSJames Wright   sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e});
317*bd882c8aSJames Wright   // -- Transpose indices
318*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context));
319*bd882c8aSJames Wright   sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e});
320*bd882c8aSJames Wright 
321*bd882c8aSJames Wright   // Wait for all copies to complete and handle exceptions
322*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices}));
323*bd882c8aSJames Wright 
324*bd882c8aSJames Wright   // Cleanup
325*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&ind_to_offset));
326*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&l_vec_indices));
327*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_offsets));
328*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&t_indices));
329*bd882c8aSJames Wright 
330*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
331*bd882c8aSJames Wright }
332*bd882c8aSJames Wright 
333*bd882c8aSJames Wright //------------------------------------------------------------------------------
334*bd882c8aSJames Wright // Create restriction
335*bd882c8aSJames Wright //------------------------------------------------------------------------------
336*bd882c8aSJames Wright int CeedElemRestrictionCreate_Sycl(CeedMemType m_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r) {
337*bd882c8aSJames Wright   Ceed ceed;
338*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
339*bd882c8aSJames Wright   Ceed_Sycl *data;
340*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
341*bd882c8aSJames Wright   CeedElemRestriction_Sycl *impl;
342*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
343*bd882c8aSJames Wright   CeedInt num_elem, num_comp, elem_size;
344*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
345*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
346*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
347*bd882c8aSJames Wright   CeedInt size        = num_elem * elem_size;
348*bd882c8aSJames Wright   CeedInt strides[3]  = {1, size, elem_size};
349*bd882c8aSJames Wright   CeedInt comp_stride = 1;
350*bd882c8aSJames Wright 
351*bd882c8aSJames Wright   // Stride data
352*bd882c8aSJames Wright   bool is_strided;
353*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
354*bd882c8aSJames Wright   if (is_strided) {
355*bd882c8aSJames Wright     bool has_backend_strides;
356*bd882c8aSJames Wright     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
357*bd882c8aSJames Wright     if (!has_backend_strides) {
358*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
359*bd882c8aSJames Wright     }
360*bd882c8aSJames Wright   } else {
361*bd882c8aSJames Wright     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
362*bd882c8aSJames Wright   }
363*bd882c8aSJames Wright 
364*bd882c8aSJames Wright   impl->h_ind           = NULL;
365*bd882c8aSJames Wright   impl->h_ind_allocated = NULL;
366*bd882c8aSJames Wright   impl->d_ind           = NULL;
367*bd882c8aSJames Wright   impl->d_ind_allocated = NULL;
368*bd882c8aSJames Wright   impl->d_t_indices     = NULL;
369*bd882c8aSJames Wright   impl->d_t_offsets     = NULL;
370*bd882c8aSJames Wright   impl->num_nodes       = size;
371*bd882c8aSJames Wright   impl->num_elem        = num_elem;
372*bd882c8aSJames Wright   impl->num_comp        = num_comp;
373*bd882c8aSJames Wright   impl->elem_size       = elem_size;
374*bd882c8aSJames Wright   impl->comp_stride     = comp_stride;
375*bd882c8aSJames Wright   impl->strides[0]      = strides[0];
376*bd882c8aSJames Wright   impl->strides[1]      = strides[1];
377*bd882c8aSJames Wright   impl->strides[2]      = strides[2];
378*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
379*bd882c8aSJames Wright   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
380*bd882c8aSJames Wright   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
381*bd882c8aSJames Wright 
382*bd882c8aSJames Wright   // Set up device indices/offset arrays
383*bd882c8aSJames Wright   if (m_type == CEED_MEM_HOST) {
384*bd882c8aSJames Wright     switch (copy_mode) {
385*bd882c8aSJames Wright       case CEED_OWN_POINTER:
386*bd882c8aSJames Wright         impl->h_ind_allocated = (CeedInt *)indices;
387*bd882c8aSJames Wright         impl->h_ind           = (CeedInt *)indices;
388*bd882c8aSJames Wright         break;
389*bd882c8aSJames Wright       case CEED_USE_POINTER:
390*bd882c8aSJames Wright         impl->h_ind = (CeedInt *)indices;
391*bd882c8aSJames Wright         break;
392*bd882c8aSJames Wright       case CEED_COPY_VALUES:
393*bd882c8aSJames Wright         if (indices != NULL) {
394*bd882c8aSJames Wright           CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
395*bd882c8aSJames Wright           memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
396*bd882c8aSJames Wright           impl->h_ind = impl->h_ind_allocated;
397*bd882c8aSJames Wright         }
398*bd882c8aSJames Wright         break;
399*bd882c8aSJames Wright     }
400*bd882c8aSJames Wright     if (indices != NULL) {
401*bd882c8aSJames Wright       CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
402*bd882c8aSJames Wright       impl->d_ind_allocated = impl->d_ind;  // We own the device memory
403*bd882c8aSJames Wright       // Order queue
404*bd882c8aSJames Wright       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
405*bd882c8aSJames Wright       // Copy from host to device
406*bd882c8aSJames Wright       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
407*bd882c8aSJames Wright       // Wait for copy to finish and handle exceptions
408*bd882c8aSJames Wright       CeedCallSycl(ceed, copy_event.wait_and_throw());
409*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices));
410*bd882c8aSJames Wright     }
411*bd882c8aSJames Wright   } else if (m_type == CEED_MEM_DEVICE) {
412*bd882c8aSJames Wright     switch (copy_mode) {
413*bd882c8aSJames Wright       case CEED_COPY_VALUES:
414*bd882c8aSJames Wright         if (indices != NULL) {
415*bd882c8aSJames Wright           CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
416*bd882c8aSJames Wright           impl->d_ind_allocated = impl->d_ind;  // We own the device memory
417*bd882c8aSJames Wright                                                 // Copy from device to device
418*bd882c8aSJames Wright           // Order queue
419*bd882c8aSJames Wright           sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
420*bd882c8aSJames Wright           sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
421*bd882c8aSJames Wright           // Wait for copy to finish and handle exceptions
422*bd882c8aSJames Wright           CeedCallSycl(ceed, copy_event.wait_and_throw());
423*bd882c8aSJames Wright         }
424*bd882c8aSJames Wright         break;
425*bd882c8aSJames Wright       case CEED_OWN_POINTER:
426*bd882c8aSJames Wright         impl->d_ind           = (CeedInt *)indices;
427*bd882c8aSJames Wright         impl->d_ind_allocated = impl->d_ind;
428*bd882c8aSJames Wright         break;
429*bd882c8aSJames Wright       case CEED_USE_POINTER:
430*bd882c8aSJames Wright         impl->d_ind = (CeedInt *)indices;
431*bd882c8aSJames Wright     }
432*bd882c8aSJames Wright     if (indices != NULL) {
433*bd882c8aSJames Wright       CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
434*bd882c8aSJames Wright       // Order queue
435*bd882c8aSJames Wright       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
436*bd882c8aSJames Wright       // Copy from device to host
437*bd882c8aSJames Wright       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e});
438*bd882c8aSJames Wright       CeedCallSycl(ceed, copy_event.wait_and_throw());
439*bd882c8aSJames Wright       impl->h_ind = impl->h_ind_allocated;
440*bd882c8aSJames Wright       CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices));
441*bd882c8aSJames Wright     }
442*bd882c8aSJames Wright   } else {
443*bd882c8aSJames Wright     // LCOV_EXCL_START
444*bd882c8aSJames Wright     return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
445*bd882c8aSJames Wright     // LCOV_EXCL_STOP
446*bd882c8aSJames Wright   }
447*bd882c8aSJames Wright 
448*bd882c8aSJames Wright   // Register backend functions
449*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Sycl));
450*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Sycl));
451*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl));
452*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Sycl));
453*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
454*bd882c8aSJames Wright }
455