xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref-basis.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright #include <ceed/backend.h>
9*bd882c8aSJames Wright #include <ceed/ceed.h>
10*bd882c8aSJames Wright #include <ceed/jit-tools.h>
11*bd882c8aSJames Wright 
12*bd882c8aSJames Wright #include <sycl/sycl.hpp>
13*bd882c8aSJames Wright #include <vector>
14*bd882c8aSJames Wright 
15*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
16*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
17*bd882c8aSJames Wright 
18*bd882c8aSJames Wright template <int>
19*bd882c8aSJames Wright class CeedBasisSyclInterp;
20*bd882c8aSJames Wright template <int>
21*bd882c8aSJames Wright class CeedBasisSyclGrad;
22*bd882c8aSJames Wright class CeedBasisSyclWeight;
23*bd882c8aSJames Wright 
24*bd882c8aSJames Wright class CeedBasisSyclInterpNT;
25*bd882c8aSJames Wright class CeedBasisSyclGradNT;
26*bd882c8aSJames Wright class CeedBasisSyclWeightNT;
27*bd882c8aSJames Wright 
28*bd882c8aSJames Wright using SpecID = sycl::specialization_id<CeedInt>;
29*bd882c8aSJames Wright 
30*bd882c8aSJames Wright static constexpr SpecID BASIS_DIM_ID;
31*bd882c8aSJames Wright static constexpr SpecID BASIS_NUM_COMP_ID;
32*bd882c8aSJames Wright static constexpr SpecID BASIS_P_1D_ID;
33*bd882c8aSJames Wright static constexpr SpecID BASIS_Q_1D_ID;
34*bd882c8aSJames Wright 
35*bd882c8aSJames Wright //------------------------------------------------------------------------------
36*bd882c8aSJames Wright // Interpolation kernel - tensor
37*bd882c8aSJames Wright //------------------------------------------------------------------------------
38*bd882c8aSJames Wright template <int transpose>
39*bd882c8aSJames Wright static int CeedBasisApplyInterp_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
40*bd882c8aSJames Wright                                      const CeedScalar *u, CeedScalar *v) {
41*bd882c8aSJames Wright   const CeedInt     buf_len   = impl->buf_len;
42*bd882c8aSJames Wright   const CeedInt     op_len    = impl->op_len;
43*bd882c8aSJames Wright   const CeedScalar *interp_1d = impl->d_interp_1d;
44*bd882c8aSJames Wright 
45*bd882c8aSJames Wright   const sycl::device &sycl_device         = sycl_queue.get_device();
46*bd882c8aSJames Wright   const CeedInt       max_work_group_size = 32;
47*bd882c8aSJames Wright   const CeedInt       work_group_size     = CeedIntMin(impl->num_qpts, max_work_group_size);
48*bd882c8aSJames Wright   sycl::range<1>      local_range(work_group_size);
49*bd882c8aSJames Wright   sycl::range<1>      global_range(num_elem * work_group_size);
50*bd882c8aSJames Wright   sycl::nd_range<1>   kernel_range(global_range, local_range);
51*bd882c8aSJames Wright 
52*bd882c8aSJames Wright   // Order queue
53*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
54*bd882c8aSJames Wright   sycl_queue.submit([&](sycl::handler &cgh) {
55*bd882c8aSJames Wright     cgh.depends_on({e});
56*bd882c8aSJames Wright     cgh.use_kernel_bundle(sycl_module);
57*bd882c8aSJames Wright 
58*bd882c8aSJames Wright     sycl::local_accessor<CeedScalar> s_mem(op_len + 2 * buf_len, cgh);
59*bd882c8aSJames Wright 
60*bd882c8aSJames Wright     cgh.parallel_for<CeedBasisSyclInterp<transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
61*bd882c8aSJames Wright       //-------------------------------------------------------------->
62*bd882c8aSJames Wright       // Retrieve spec constant values
63*bd882c8aSJames Wright       const CeedInt dim      = kh.get_specialization_constant<BASIS_DIM_ID>();
64*bd882c8aSJames Wright       const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
65*bd882c8aSJames Wright       const CeedInt P_1d     = kh.get_specialization_constant<BASIS_P_1D_ID>();
66*bd882c8aSJames Wright       const CeedInt Q_1d     = kh.get_specialization_constant<BASIS_Q_1D_ID>();
67*bd882c8aSJames Wright       //-------------------------------------------------------------->
68*bd882c8aSJames Wright       const CeedInt num_nodes     = CeedIntPow(P_1d, dim);
69*bd882c8aSJames Wright       const CeedInt num_qpts      = CeedIntPow(Q_1d, dim);
70*bd882c8aSJames Wright       const CeedInt P             = transpose ? Q_1d : P_1d;
71*bd882c8aSJames Wright       const CeedInt Q             = transpose ? P_1d : Q_1d;
72*bd882c8aSJames Wright       const CeedInt stride_0      = transpose ? 1 : P_1d;
73*bd882c8aSJames Wright       const CeedInt stride_1      = transpose ? P_1d : 1;
74*bd882c8aSJames Wright       const CeedInt u_stride      = transpose ? num_qpts : num_nodes;
75*bd882c8aSJames Wright       const CeedInt v_stride      = transpose ? num_nodes : num_qpts;
76*bd882c8aSJames Wright       const CeedInt u_comp_stride = num_elem * u_stride;
77*bd882c8aSJames Wright       const CeedInt v_comp_stride = num_elem * v_stride;
78*bd882c8aSJames Wright       const CeedInt u_size        = u_stride;
79*bd882c8aSJames Wright 
80*bd882c8aSJames Wright       sycl::group   work_group = work_item.get_group();
81*bd882c8aSJames Wright       const CeedInt i          = work_item.get_local_linear_id();
82*bd882c8aSJames Wright       const CeedInt group_size = work_group.get_local_linear_range();
83*bd882c8aSJames Wright       const CeedInt elem       = work_group.get_group_linear_id();
84*bd882c8aSJames Wright 
85*bd882c8aSJames Wright       CeedScalar *s_interp_1d = s_mem.get_pointer();
86*bd882c8aSJames Wright       CeedScalar *s_buffer_1  = s_interp_1d + Q * P;
87*bd882c8aSJames Wright       CeedScalar *s_buffer_2  = s_buffer_1 + buf_len;
88*bd882c8aSJames Wright 
89*bd882c8aSJames Wright       for (CeedInt k = i; k < P * Q; k += group_size) {
90*bd882c8aSJames Wright         s_interp_1d[k] = interp_1d[k];
91*bd882c8aSJames Wright       }
92*bd882c8aSJames Wright 
93*bd882c8aSJames Wright       // Apply basis element by element
94*bd882c8aSJames Wright       for (CeedInt comp = 0; comp < num_comp; comp++) {
95*bd882c8aSJames Wright         const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
96*bd882c8aSJames Wright         CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
97*bd882c8aSJames Wright 
98*bd882c8aSJames Wright         for (CeedInt k = i; k < u_size; k += group_size) {
99*bd882c8aSJames Wright           s_buffer_1[k] = cur_u[k];
100*bd882c8aSJames Wright         }
101*bd882c8aSJames Wright 
102*bd882c8aSJames Wright         CeedInt pre  = u_size;
103*bd882c8aSJames Wright         CeedInt post = 1;
104*bd882c8aSJames Wright 
105*bd882c8aSJames Wright         for (CeedInt d = 0; d < dim; d++) {
106*bd882c8aSJames Wright           // Use older version of sycl workgroup barrier for performance reasons
107*bd882c8aSJames Wright           // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
108*bd882c8aSJames Wright           // sycl::group_barrier(work_group);
109*bd882c8aSJames Wright           work_item.barrier(sycl::access::fence_space::local_space);
110*bd882c8aSJames Wright 
111*bd882c8aSJames Wright           pre /= P;
112*bd882c8aSJames Wright           const CeedScalar *in  = d % 2 ? s_buffer_2 : s_buffer_1;
113*bd882c8aSJames Wright           CeedScalar       *out = d == dim - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
114*bd882c8aSJames Wright 
115*bd882c8aSJames Wright           // Contract along middle index
116*bd882c8aSJames Wright           const CeedInt writeLen = pre * post * Q;
117*bd882c8aSJames Wright           for (CeedInt k = i; k < writeLen; k += group_size) {
118*bd882c8aSJames Wright             const CeedInt c = k % post;
119*bd882c8aSJames Wright             const CeedInt j = (k / post) % Q;
120*bd882c8aSJames Wright             const CeedInt a = k / (post * Q);
121*bd882c8aSJames Wright 
122*bd882c8aSJames Wright             CeedScalar vk = 0;
123*bd882c8aSJames Wright             for (CeedInt b = 0; b < P; b++) {
124*bd882c8aSJames Wright               vk += s_interp_1d[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
125*bd882c8aSJames Wright             }
126*bd882c8aSJames Wright             out[k] = vk;
127*bd882c8aSJames Wright           }
128*bd882c8aSJames Wright           post *= Q;
129*bd882c8aSJames Wright         }
130*bd882c8aSJames Wright       }
131*bd882c8aSJames Wright     });
132*bd882c8aSJames Wright   });
133*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
134*bd882c8aSJames Wright }
135*bd882c8aSJames Wright 
136*bd882c8aSJames Wright //------------------------------------------------------------------------------
137*bd882c8aSJames Wright // Gradient kernel - tensor
138*bd882c8aSJames Wright //------------------------------------------------------------------------------
139*bd882c8aSJames Wright template <int transpose>
140*bd882c8aSJames Wright static int CeedBasisApplyGrad_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
141*bd882c8aSJames Wright                                    const CeedScalar *u, CeedScalar *v) {
142*bd882c8aSJames Wright   const CeedInt     buf_len   = impl->buf_len;
143*bd882c8aSJames Wright   const CeedInt     op_len    = impl->op_len;
144*bd882c8aSJames Wright   const CeedScalar *interp_1d = impl->d_interp_1d;
145*bd882c8aSJames Wright   const CeedScalar *grad_1d   = impl->d_grad_1d;
146*bd882c8aSJames Wright 
147*bd882c8aSJames Wright   const sycl::device &sycl_device     = sycl_queue.get_device();
148*bd882c8aSJames Wright   const CeedInt       work_group_size = 32;
149*bd882c8aSJames Wright   sycl::range<1>      local_range(work_group_size);
150*bd882c8aSJames Wright   sycl::range<1>      global_range(num_elem * work_group_size);
151*bd882c8aSJames Wright   sycl::nd_range<1>   kernel_range(global_range, local_range);
152*bd882c8aSJames Wright 
153*bd882c8aSJames Wright   // Order queue
154*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
155*bd882c8aSJames Wright   sycl_queue.submit([&](sycl::handler &cgh) {
156*bd882c8aSJames Wright     cgh.depends_on({e});
157*bd882c8aSJames Wright     cgh.use_kernel_bundle(sycl_module);
158*bd882c8aSJames Wright 
159*bd882c8aSJames Wright     sycl::local_accessor<CeedScalar> s_mem(2 * (op_len + buf_len), cgh);
160*bd882c8aSJames Wright 
161*bd882c8aSJames Wright     cgh.parallel_for<CeedBasisSyclGrad<transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
162*bd882c8aSJames Wright       //-------------------------------------------------------------->
163*bd882c8aSJames Wright       // Retrieve spec constant values
164*bd882c8aSJames Wright       const CeedInt dim      = kh.get_specialization_constant<BASIS_DIM_ID>();
165*bd882c8aSJames Wright       const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
166*bd882c8aSJames Wright       const CeedInt P_1d     = kh.get_specialization_constant<BASIS_P_1D_ID>();
167*bd882c8aSJames Wright       const CeedInt Q_1d     = kh.get_specialization_constant<BASIS_Q_1D_ID>();
168*bd882c8aSJames Wright       //-------------------------------------------------------------->
169*bd882c8aSJames Wright       const CeedInt num_nodes     = CeedIntPow(P_1d, dim);
170*bd882c8aSJames Wright       const CeedInt num_qpts      = CeedIntPow(Q_1d, dim);
171*bd882c8aSJames Wright       const CeedInt P             = transpose ? Q_1d : P_1d;
172*bd882c8aSJames Wright       const CeedInt Q             = transpose ? P_1d : Q_1d;
173*bd882c8aSJames Wright       const CeedInt stride_0      = transpose ? 1 : P_1d;
174*bd882c8aSJames Wright       const CeedInt stride_1      = transpose ? P_1d : 1;
175*bd882c8aSJames Wright       const CeedInt u_stride      = transpose ? num_qpts : num_nodes;
176*bd882c8aSJames Wright       const CeedInt v_stride      = transpose ? num_nodes : num_qpts;
177*bd882c8aSJames Wright       const CeedInt u_comp_stride = num_elem * u_stride;
178*bd882c8aSJames Wright       const CeedInt v_comp_stride = num_elem * v_stride;
179*bd882c8aSJames Wright       const CeedInt u_dim_stride  = transpose ? num_elem * num_qpts * num_comp : 0;
180*bd882c8aSJames Wright       const CeedInt v_dim_stride  = transpose ? 0 : num_elem * num_qpts * num_comp;
181*bd882c8aSJames Wright 
182*bd882c8aSJames Wright       sycl::group   work_group = work_item.get_group();
183*bd882c8aSJames Wright       const CeedInt i          = work_item.get_local_linear_id();
184*bd882c8aSJames Wright       const CeedInt group_size = work_group.get_local_linear_range();
185*bd882c8aSJames Wright       const CeedInt elem       = work_group.get_group_linear_id();
186*bd882c8aSJames Wright 
187*bd882c8aSJames Wright       CeedScalar *s_interp_1d = s_mem.get_pointer();
188*bd882c8aSJames Wright       CeedScalar *s_grad_1d   = s_interp_1d + P * Q;
189*bd882c8aSJames Wright       CeedScalar *s_buffer_1  = s_grad_1d + P * Q;
190*bd882c8aSJames Wright       CeedScalar *s_buffer_2  = s_buffer_1 + buf_len;
191*bd882c8aSJames Wright 
192*bd882c8aSJames Wright       for (CeedInt k = i; k < P * Q; k += group_size) {
193*bd882c8aSJames Wright         s_interp_1d[k] = interp_1d[k];
194*bd882c8aSJames Wright         s_grad_1d[k]   = grad_1d[k];
195*bd882c8aSJames Wright       }
196*bd882c8aSJames Wright 
197*bd882c8aSJames Wright       // Apply basis element by element
198*bd882c8aSJames Wright       for (CeedInt comp = 0; comp < num_comp; comp++) {
199*bd882c8aSJames Wright         for (CeedInt dim_1 = 0; dim_1 < dim; dim_1++) {
200*bd882c8aSJames Wright           CeedInt           pre   = transpose ? num_qpts : num_nodes;
201*bd882c8aSJames Wright           CeedInt           post  = 1;
202*bd882c8aSJames Wright           const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride;
203*bd882c8aSJames Wright           CeedScalar       *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride;
204*bd882c8aSJames Wright 
205*bd882c8aSJames Wright           for (CeedInt dim_2 = 0; dim_2 < dim; dim_2++) {
206*bd882c8aSJames Wright             // Use older version of sycl workgroup barrier for performance reasons
207*bd882c8aSJames Wright             // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
208*bd882c8aSJames Wright             // sycl::group_barrier(work_group);
209*bd882c8aSJames Wright             work_item.barrier(sycl::access::fence_space::local_space);
210*bd882c8aSJames Wright 
211*bd882c8aSJames Wright             pre /= P;
212*bd882c8aSJames Wright             const CeedScalar *op  = dim_1 == dim_2 ? s_grad_1d : s_interp_1d;
213*bd882c8aSJames Wright             const CeedScalar *in  = (dim_2 == 0 ? cur_u : (dim_2 % 2 ? s_buffer_2 : s_buffer_1));
214*bd882c8aSJames Wright             CeedScalar       *out = dim_2 == dim - 1 ? cur_v : (dim_2 % 2 ? s_buffer_1 : s_buffer_2);
215*bd882c8aSJames Wright 
216*bd882c8aSJames Wright             // Contract along middle index
217*bd882c8aSJames Wright             const CeedInt writeLen = pre * post * Q;
218*bd882c8aSJames Wright             for (CeedInt k = i; k < writeLen; k += group_size) {
219*bd882c8aSJames Wright               const CeedInt c = k % post;
220*bd882c8aSJames Wright               const CeedInt j = (k / post) % Q;
221*bd882c8aSJames Wright               const CeedInt a = k / (post * Q);
222*bd882c8aSJames Wright 
223*bd882c8aSJames Wright               CeedScalar v_k = 0;
224*bd882c8aSJames Wright               for (CeedInt b = 0; b < P; b++) v_k += op[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
225*bd882c8aSJames Wright 
226*bd882c8aSJames Wright               if (transpose && dim_2 == dim - 1) out[k] += v_k;
227*bd882c8aSJames Wright               else out[k] = v_k;
228*bd882c8aSJames Wright             }
229*bd882c8aSJames Wright 
230*bd882c8aSJames Wright             post *= Q;
231*bd882c8aSJames Wright           }
232*bd882c8aSJames Wright         }
233*bd882c8aSJames Wright       }
234*bd882c8aSJames Wright     });
235*bd882c8aSJames Wright   });
236*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
237*bd882c8aSJames Wright }
238*bd882c8aSJames Wright 
239*bd882c8aSJames Wright //------------------------------------------------------------------------------
240*bd882c8aSJames Wright // Weight kernel - tensor
241*bd882c8aSJames Wright //------------------------------------------------------------------------------
242*bd882c8aSJames Wright static int CeedBasisApplyWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasis_Sycl *impl, CeedScalar *w) {
243*bd882c8aSJames Wright   const CeedInt     dim         = impl->dim;
244*bd882c8aSJames Wright   const CeedInt     Q_1d        = impl->Q_1d;
245*bd882c8aSJames Wright   const CeedScalar *q_weight_1d = impl->d_q_weight_1d;
246*bd882c8aSJames Wright 
247*bd882c8aSJames Wright   const CeedInt  num_quad_x = Q_1d;
248*bd882c8aSJames Wright   const CeedInt  num_quad_y = (dim > 1) ? Q_1d : 1;
249*bd882c8aSJames Wright   const CeedInt  num_quad_z = (dim > 2) ? Q_1d : 1;
250*bd882c8aSJames Wright   sycl::range<3> kernel_range(num_elem * num_quad_z, num_quad_y, num_quad_x);
251*bd882c8aSJames Wright 
252*bd882c8aSJames Wright   // Order queue
253*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
254*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclWeight>(kernel_range, {e}, [=](sycl::item<3> work_item) {
255*bd882c8aSJames Wright     if (dim == 1) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]];
256*bd882c8aSJames Wright     if (dim == 2) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]];
257*bd882c8aSJames Wright     if (dim == 3) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]] * q_weight_1d[work_item[0] % Q_1d];
258*bd882c8aSJames Wright   });
259*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
260*bd882c8aSJames Wright }
261*bd882c8aSJames Wright 
262*bd882c8aSJames Wright //------------------------------------------------------------------------------
263*bd882c8aSJames Wright // Basis apply - tensor
264*bd882c8aSJames Wright //------------------------------------------------------------------------------
265*bd882c8aSJames Wright static int CeedBasisApply_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
266*bd882c8aSJames Wright                                CeedVector v) {
267*bd882c8aSJames Wright   Ceed ceed;
268*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
269*bd882c8aSJames Wright   Ceed_Sycl *data;
270*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
271*bd882c8aSJames Wright   CeedBasis_Sycl *impl;
272*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
273*bd882c8aSJames Wright 
274*bd882c8aSJames Wright   const CeedInt transpose = t_mode == CEED_TRANSPOSE;
275*bd882c8aSJames Wright 
276*bd882c8aSJames Wright   // Read vectors
277*bd882c8aSJames Wright   const CeedScalar *d_u;
278*bd882c8aSJames Wright   CeedScalar       *d_v;
279*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
280*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
281*bd882c8aSJames Wright   }
282*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
283*bd882c8aSJames Wright 
284*bd882c8aSJames Wright   // Clear v for transpose operation
285*bd882c8aSJames Wright   if (t_mode == CEED_TRANSPOSE) {
286*bd882c8aSJames Wright     CeedSize length;
287*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetLength(v, &length));
288*bd882c8aSJames Wright     // Order queue
289*bd882c8aSJames Wright     sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
290*bd882c8aSJames Wright     data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e});
291*bd882c8aSJames Wright   }
292*bd882c8aSJames Wright 
293*bd882c8aSJames Wright   // Basis action
294*bd882c8aSJames Wright   switch (eval_mode) {
295*bd882c8aSJames Wright     case CEED_EVAL_INTERP: {
296*bd882c8aSJames Wright       if (transpose) {
297*bd882c8aSJames Wright         CeedCallBackend(CeedBasisApplyInterp_Sycl<CEED_TRANSPOSE>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
298*bd882c8aSJames Wright       } else {
299*bd882c8aSJames Wright         CeedCallBackend(CeedBasisApplyInterp_Sycl<CEED_NOTRANSPOSE>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
300*bd882c8aSJames Wright       }
301*bd882c8aSJames Wright     } break;
302*bd882c8aSJames Wright     case CEED_EVAL_GRAD: {
303*bd882c8aSJames Wright       if (transpose) {
304*bd882c8aSJames Wright         CeedCallBackend(CeedBasisApplyGrad_Sycl<1>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
305*bd882c8aSJames Wright       } else {
306*bd882c8aSJames Wright         CeedCallBackend(CeedBasisApplyGrad_Sycl<0>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
307*bd882c8aSJames Wright       }
308*bd882c8aSJames Wright     } break;
309*bd882c8aSJames Wright     case CEED_EVAL_WEIGHT: {
310*bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
311*bd882c8aSJames Wright     } break;
312*bd882c8aSJames Wright     // LCOV_EXCL_START
313*bd882c8aSJames Wright     // Evaluate the divergence to/from the quadrature points
314*bd882c8aSJames Wright     case CEED_EVAL_DIV:
315*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
316*bd882c8aSJames Wright     // Evaluate the curl to/from the quadrature points
317*bd882c8aSJames Wright     case CEED_EVAL_CURL:
318*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
319*bd882c8aSJames Wright     // Take no action, BasisApply should not have been called
320*bd882c8aSJames Wright     case CEED_EVAL_NONE:
321*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
322*bd882c8aSJames Wright       // LCOV_EXCL_STOP
323*bd882c8aSJames Wright   }
324*bd882c8aSJames Wright 
325*bd882c8aSJames Wright   // Restore vectors
326*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
327*bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
328*bd882c8aSJames Wright   }
329*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
330*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
331*bd882c8aSJames Wright }
332*bd882c8aSJames Wright 
333*bd882c8aSJames Wright //------------------------------------------------------------------------------
334*bd882c8aSJames Wright // Interpolation kernel - non-tensor
335*bd882c8aSJames Wright //------------------------------------------------------------------------------
336*bd882c8aSJames Wright static int CeedBasisApplyNonTensorInterp_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt transpose, const CeedBasisNonTensor_Sycl *impl,
337*bd882c8aSJames Wright                                               const CeedScalar *d_U, CeedScalar *d_V) {
338*bd882c8aSJames Wright   const CeedInt     num_comp      = impl->num_comp;
339*bd882c8aSJames Wright   const CeedInt     P             = transpose ? impl->num_qpts : impl->num_nodes;
340*bd882c8aSJames Wright   const CeedInt     Q             = transpose ? impl->num_nodes : impl->num_qpts;
341*bd882c8aSJames Wright   const CeedInt     stride_0      = transpose ? 1 : impl->num_nodes;
342*bd882c8aSJames Wright   const CeedInt     stride_1      = transpose ? impl->num_nodes : 1;
343*bd882c8aSJames Wright   const CeedInt     u_stride      = P;
344*bd882c8aSJames Wright   const CeedInt     v_stride      = Q;
345*bd882c8aSJames Wright   const CeedInt     u_comp_stride = u_stride * num_elem;
346*bd882c8aSJames Wright   const CeedInt     v_comp_stride = v_stride * num_elem;
347*bd882c8aSJames Wright   const CeedInt     u_size        = P;
348*bd882c8aSJames Wright   const CeedInt     v_size        = Q;
349*bd882c8aSJames Wright   const CeedScalar *d_B           = impl->d_interp;
350*bd882c8aSJames Wright 
351*bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, v_size);
352*bd882c8aSJames Wright 
353*bd882c8aSJames Wright   // Order queue
354*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
355*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclInterpNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
356*bd882c8aSJames Wright     const CeedInt i    = indx[1];
357*bd882c8aSJames Wright     const CeedInt elem = indx[0];
358*bd882c8aSJames Wright 
359*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
360*bd882c8aSJames Wright       const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride;
361*bd882c8aSJames Wright       CeedScalar        V = 0.0;
362*bd882c8aSJames Wright 
363*bd882c8aSJames Wright       for (CeedInt j = 0; j < u_size; ++j) {
364*bd882c8aSJames Wright         V += d_B[i * stride_0 + j * stride_1] * U[j];
365*bd882c8aSJames Wright       }
366*bd882c8aSJames Wright       d_V[i + elem * v_stride + comp * v_comp_stride] = V;
367*bd882c8aSJames Wright     }
368*bd882c8aSJames Wright   });
369*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
370*bd882c8aSJames Wright }
371*bd882c8aSJames Wright 
372*bd882c8aSJames Wright //------------------------------------------------------------------------------
373*bd882c8aSJames Wright // Gradient kernel - non-tensor
374*bd882c8aSJames Wright //------------------------------------------------------------------------------
375*bd882c8aSJames Wright static int CeedBasisApplyNonTensorGrad_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt transpose, const CeedBasisNonTensor_Sycl *impl,
376*bd882c8aSJames Wright                                             const CeedScalar *d_U, CeedScalar *d_V) {
377*bd882c8aSJames Wright   const CeedInt     num_comp      = impl->num_comp;
378*bd882c8aSJames Wright   const CeedInt     P             = transpose ? impl->num_qpts : impl->num_nodes;
379*bd882c8aSJames Wright   const CeedInt     Q             = transpose ? impl->num_nodes : impl->num_qpts;
380*bd882c8aSJames Wright   const CeedInt     stride_0      = transpose ? 1 : impl->num_nodes;
381*bd882c8aSJames Wright   const CeedInt     stride_1      = transpose ? impl->num_nodes : 1;
382*bd882c8aSJames Wright   const CeedInt     g_dim_stride  = P * Q;
383*bd882c8aSJames Wright   const CeedInt     u_stride      = P;
384*bd882c8aSJames Wright   const CeedInt     v_stride      = Q;
385*bd882c8aSJames Wright   const CeedInt     u_comp_stride = u_stride * num_elem;
386*bd882c8aSJames Wright   const CeedInt     v_comp_stride = v_stride * num_elem;
387*bd882c8aSJames Wright   const CeedInt     u_dim_stride  = u_comp_stride * num_comp;
388*bd882c8aSJames Wright   const CeedInt     v_dim_stride  = v_comp_stride * num_comp;
389*bd882c8aSJames Wright   const CeedInt     u_size        = P;
390*bd882c8aSJames Wright   const CeedInt     v_size        = Q;
391*bd882c8aSJames Wright   const CeedInt     in_dim        = transpose ? impl->dim : 1;
392*bd882c8aSJames Wright   const CeedInt     out_dim       = transpose ? 1 : impl->dim;
393*bd882c8aSJames Wright   const CeedScalar *d_G           = impl->d_grad;
394*bd882c8aSJames Wright 
395*bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, v_size);
396*bd882c8aSJames Wright 
397*bd882c8aSJames Wright   // Order queue
398*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
399*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclGradNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
400*bd882c8aSJames Wright     const CeedInt i    = indx[1];
401*bd882c8aSJames Wright     const CeedInt elem = indx[0];
402*bd882c8aSJames Wright 
403*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
404*bd882c8aSJames Wright       CeedScalar V[3] = {0.0, 0.0, 0.0};
405*bd882c8aSJames Wright 
406*bd882c8aSJames Wright       for (CeedInt d1 = 0; d1 < in_dim; ++d1) {
407*bd882c8aSJames Wright         const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride + d1 * u_dim_stride;
408*bd882c8aSJames Wright         const CeedScalar *G = d_G + i * stride_0 + d1 * g_dim_stride;
409*bd882c8aSJames Wright 
410*bd882c8aSJames Wright         for (CeedInt j = 0; j < u_size; ++j) {
411*bd882c8aSJames Wright           const CeedScalar Uj = U[j];
412*bd882c8aSJames Wright 
413*bd882c8aSJames Wright           for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
414*bd882c8aSJames Wright             V[d0] += G[j * stride_1 + d0 * g_dim_stride] * Uj;
415*bd882c8aSJames Wright           }
416*bd882c8aSJames Wright         }
417*bd882c8aSJames Wright       }
418*bd882c8aSJames Wright       for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
419*bd882c8aSJames Wright         d_V[i + elem * v_stride + comp * v_comp_stride + d0 * v_dim_stride] = V[d0];
420*bd882c8aSJames Wright       }
421*bd882c8aSJames Wright     }
422*bd882c8aSJames Wright   });
423*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
424*bd882c8aSJames Wright }
425*bd882c8aSJames Wright 
426*bd882c8aSJames Wright //------------------------------------------------------------------------------
427*bd882c8aSJames Wright // Weight kernel - non-tensor
428*bd882c8aSJames Wright //------------------------------------------------------------------------------
429*bd882c8aSJames Wright static int CeedBasisApplyNonTensorWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasisNonTensor_Sycl *impl, CeedScalar *d_V) {
430*bd882c8aSJames Wright   const CeedInt     num_qpts = impl->num_qpts;
431*bd882c8aSJames Wright   const CeedScalar *q_weight = impl->d_q_weight;
432*bd882c8aSJames Wright 
433*bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, num_qpts);
434*bd882c8aSJames Wright 
435*bd882c8aSJames Wright   // Order queue
436*bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
437*bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclWeightNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
438*bd882c8aSJames Wright     const CeedInt i          = indx[1];
439*bd882c8aSJames Wright     const CeedInt elem       = indx[0];
440*bd882c8aSJames Wright     d_V[i + elem * num_qpts] = q_weight[i];
441*bd882c8aSJames Wright   });
442*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
443*bd882c8aSJames Wright }
444*bd882c8aSJames Wright 
445*bd882c8aSJames Wright //------------------------------------------------------------------------------
446*bd882c8aSJames Wright // Basis apply - non-tensor
447*bd882c8aSJames Wright //------------------------------------------------------------------------------
448*bd882c8aSJames Wright static int CeedBasisApplyNonTensor_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
449*bd882c8aSJames Wright                                         CeedVector v) {
450*bd882c8aSJames Wright   Ceed ceed;
451*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
452*bd882c8aSJames Wright   CeedBasisNonTensor_Sycl *impl;
453*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
454*bd882c8aSJames Wright   Ceed_Sycl *data;
455*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
456*bd882c8aSJames Wright 
457*bd882c8aSJames Wright   const CeedInt transpose = t_mode == CEED_TRANSPOSE;
458*bd882c8aSJames Wright 
459*bd882c8aSJames Wright   // Read vectors
460*bd882c8aSJames Wright   const CeedScalar *d_u;
461*bd882c8aSJames Wright   CeedScalar       *d_v;
462*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
463*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
464*bd882c8aSJames Wright   }
465*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
466*bd882c8aSJames Wright 
467*bd882c8aSJames Wright   // Clear v for transpose operation
468*bd882c8aSJames Wright   if (transpose) {
469*bd882c8aSJames Wright     CeedSize length;
470*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetLength(v, &length));
471*bd882c8aSJames Wright     // Order queue
472*bd882c8aSJames Wright     sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
473*bd882c8aSJames Wright     data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e});
474*bd882c8aSJames Wright   }
475*bd882c8aSJames Wright 
476*bd882c8aSJames Wright   // Apply basis operation
477*bd882c8aSJames Wright   switch (eval_mode) {
478*bd882c8aSJames Wright     case CEED_EVAL_INTERP: {
479*bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyNonTensorInterp_Sycl(data->sycl_queue, num_elem, transpose, impl, d_u, d_v));
480*bd882c8aSJames Wright     } break;
481*bd882c8aSJames Wright     case CEED_EVAL_GRAD: {
482*bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyNonTensorGrad_Sycl(data->sycl_queue, num_elem, transpose, impl, d_u, d_v));
483*bd882c8aSJames Wright     } break;
484*bd882c8aSJames Wright     case CEED_EVAL_WEIGHT: {
485*bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyNonTensorWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
486*bd882c8aSJames Wright     } break;
487*bd882c8aSJames Wright     // LCOV_EXCL_START
488*bd882c8aSJames Wright     // Evaluate the divergence to/from the quadrature points
489*bd882c8aSJames Wright     case CEED_EVAL_DIV:
490*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
491*bd882c8aSJames Wright     // Evaluate the curl to/from the quadrature points
492*bd882c8aSJames Wright     case CEED_EVAL_CURL:
493*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
494*bd882c8aSJames Wright     // Take no action, BasisApply should not have been called
495*bd882c8aSJames Wright     case CEED_EVAL_NONE:
496*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
497*bd882c8aSJames Wright       // LCOV_EXCL_STOP
498*bd882c8aSJames Wright   }
499*bd882c8aSJames Wright 
500*bd882c8aSJames Wright   // Restore vectors
501*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
502*bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
503*bd882c8aSJames Wright   }
504*bd882c8aSJames Wright 
505*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
506*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
507*bd882c8aSJames Wright }
508*bd882c8aSJames Wright 
509*bd882c8aSJames Wright //------------------------------------------------------------------------------
510*bd882c8aSJames Wright // Destroy tensor basis
511*bd882c8aSJames Wright //------------------------------------------------------------------------------
512*bd882c8aSJames Wright static int CeedBasisDestroy_Sycl(CeedBasis basis) {
513*bd882c8aSJames Wright   Ceed ceed;
514*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
515*bd882c8aSJames Wright   CeedBasis_Sycl *impl;
516*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
517*bd882c8aSJames Wright   Ceed_Sycl *data;
518*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
519*bd882c8aSJames Wright 
520*bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
521*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
522*bd882c8aSJames Wright 
523*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
524*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
525*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
526*bd882c8aSJames Wright 
527*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
528*bd882c8aSJames Wright 
529*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
530*bd882c8aSJames Wright }
531*bd882c8aSJames Wright 
532*bd882c8aSJames Wright //------------------------------------------------------------------------------
533*bd882c8aSJames Wright // Destroy non-tensor basis
534*bd882c8aSJames Wright //------------------------------------------------------------------------------
535*bd882c8aSJames Wright static int CeedBasisDestroyNonTensor_Sycl(CeedBasis basis) {
536*bd882c8aSJames Wright   Ceed ceed;
537*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
538*bd882c8aSJames Wright   CeedBasisNonTensor_Sycl *impl;
539*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
540*bd882c8aSJames Wright   Ceed_Sycl *data;
541*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
542*bd882c8aSJames Wright 
543*bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
544*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
545*bd882c8aSJames Wright 
546*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_q_weight, data->sycl_context));
547*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp, data->sycl_context));
548*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad, data->sycl_context));
549*bd882c8aSJames Wright 
550*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
551*bd882c8aSJames Wright 
552*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
553*bd882c8aSJames Wright }
554*bd882c8aSJames Wright 
555*bd882c8aSJames Wright //------------------------------------------------------------------------------
556*bd882c8aSJames Wright // Create tensor
557*bd882c8aSJames Wright //------------------------------------------------------------------------------
558*bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
559*bd882c8aSJames Wright                                  const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
560*bd882c8aSJames Wright   Ceed ceed;
561*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
562*bd882c8aSJames Wright   CeedBasis_Sycl *impl;
563*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
564*bd882c8aSJames Wright   Ceed_Sycl *data;
565*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
566*bd882c8aSJames Wright 
567*bd882c8aSJames Wright   CeedInt num_comp;
568*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
569*bd882c8aSJames Wright 
570*bd882c8aSJames Wright   const CeedInt num_nodes = CeedIntPow(P_1d, dim);
571*bd882c8aSJames Wright   const CeedInt num_qpts  = CeedIntPow(Q_1d, dim);
572*bd882c8aSJames Wright 
573*bd882c8aSJames Wright   impl->dim       = dim;
574*bd882c8aSJames Wright   impl->P_1d      = P_1d;
575*bd882c8aSJames Wright   impl->Q_1d      = Q_1d;
576*bd882c8aSJames Wright   impl->num_comp  = num_comp;
577*bd882c8aSJames Wright   impl->num_nodes = num_nodes;
578*bd882c8aSJames Wright   impl->num_qpts  = num_qpts;
579*bd882c8aSJames Wright   impl->buf_len   = num_comp * CeedIntMax(num_nodes, num_qpts);
580*bd882c8aSJames Wright   impl->op_len    = Q_1d * P_1d;
581*bd882c8aSJames Wright 
582*bd882c8aSJames Wright   // Order queue
583*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
584*bd882c8aSJames Wright 
585*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
586*bd882c8aSJames Wright   sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, {e});
587*bd882c8aSJames Wright 
588*bd882c8aSJames Wright   const CeedInt interp_length = Q_1d * P_1d;
589*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
590*bd882c8aSJames Wright   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, {e});
591*bd882c8aSJames Wright 
592*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
593*bd882c8aSJames Wright   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, {e});
594*bd882c8aSJames Wright 
595*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad}));
596*bd882c8aSJames Wright 
597*bd882c8aSJames Wright   std::vector<sycl::kernel_id> kernel_ids = {sycl::get_kernel_id<CeedBasisSyclInterp<1>>(), sycl::get_kernel_id<CeedBasisSyclInterp<0>>(),
598*bd882c8aSJames Wright                                              sycl::get_kernel_id<CeedBasisSyclGrad<1>>(), sycl::get_kernel_id<CeedBasisSyclGrad<0>>()};
599*bd882c8aSJames Wright 
600*bd882c8aSJames Wright   sycl::kernel_bundle<sycl::bundle_state::input> input_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(data->sycl_context, kernel_ids);
601*bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_DIM_ID>(dim);
602*bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_NUM_COMP_ID>(num_comp);
603*bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_Q_1D_ID>(Q_1d);
604*bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_P_1D_ID>(P_1d);
605*bd882c8aSJames Wright 
606*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->sycl_module = new SyclModule_t(sycl::build(input_bundle)));
607*bd882c8aSJames Wright 
608*bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
609*bd882c8aSJames Wright 
610*bd882c8aSJames Wright   // Register backend functions
611*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApply_Sycl));
612*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl));
613*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
614*bd882c8aSJames Wright }
615*bd882c8aSJames Wright 
616*bd882c8aSJames Wright //------------------------------------------------------------------------------
617*bd882c8aSJames Wright // Create non-tensor
618*bd882c8aSJames Wright //------------------------------------------------------------------------------
619*bd882c8aSJames Wright int CeedBasisCreateH1_Sycl(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
620*bd882c8aSJames Wright                            const CeedScalar *qref, const CeedScalar *q_weight, CeedBasis basis) {
621*bd882c8aSJames Wright   Ceed ceed;
622*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
623*bd882c8aSJames Wright   CeedBasisNonTensor_Sycl *impl;
624*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
625*bd882c8aSJames Wright   Ceed_Sycl *data;
626*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
627*bd882c8aSJames Wright 
628*bd882c8aSJames Wright   CeedInt num_comp;
629*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
630*bd882c8aSJames Wright 
631*bd882c8aSJames Wright   impl->dim       = dim;
632*bd882c8aSJames Wright   impl->num_comp  = num_comp;
633*bd882c8aSJames Wright   impl->num_nodes = num_nodes;
634*bd882c8aSJames Wright   impl->num_qpts  = num_qpts;
635*bd882c8aSJames Wright 
636*bd882c8aSJames Wright   // Order queue
637*bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
638*bd882c8aSJames Wright 
639*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_q_weight = sycl::malloc_device<CeedScalar>(num_qpts, data->sycl_device, data->sycl_context));
640*bd882c8aSJames Wright   sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight, impl->d_q_weight, num_qpts, {e});
641*bd882c8aSJames Wright 
642*bd882c8aSJames Wright   const CeedInt interp_length = num_qpts * num_nodes;
643*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
644*bd882c8aSJames Wright   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp, impl->d_interp, interp_length, {e});
645*bd882c8aSJames Wright 
646*bd882c8aSJames Wright   const CeedInt grad_length = num_qpts * num_nodes * dim;
647*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad = sycl::malloc_device<CeedScalar>(grad_length, data->sycl_device, data->sycl_context));
648*bd882c8aSJames Wright   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad, impl->d_grad, grad_length, {e});
649*bd882c8aSJames Wright 
650*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad}));
651*bd882c8aSJames Wright 
652*bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
653*bd882c8aSJames Wright 
654*bd882c8aSJames Wright   // Register backend functions
655*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Sycl));
656*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Sycl));
657*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
658*bd882c8aSJames Wright }
659*bd882c8aSJames Wright //------------------------------------------------------------------------------
660