xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-basis.sycl.cpp (revision d07cdbe5b933b2ee94e4361e6ff18d258e72d53d)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright #include <ceed/backend.h>
9bd882c8aSJames Wright #include <ceed/ceed.h>
10bd882c8aSJames Wright #include <ceed/jit-tools.h>
11bd882c8aSJames Wright 
12bd882c8aSJames Wright #include <sycl/sycl.hpp>
13bd882c8aSJames Wright #include <vector>
14bd882c8aSJames Wright 
15bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
16bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
17bd882c8aSJames Wright 
18bd882c8aSJames Wright template <int>
19bd882c8aSJames Wright class CeedBasisSyclInterp;
20bd882c8aSJames Wright template <int>
21bd882c8aSJames Wright class CeedBasisSyclGrad;
22bd882c8aSJames Wright class CeedBasisSyclWeight;
23bd882c8aSJames Wright 
24bd882c8aSJames Wright class CeedBasisSyclInterpNT;
25bd882c8aSJames Wright class CeedBasisSyclGradNT;
26bd882c8aSJames Wright class CeedBasisSyclWeightNT;
27bd882c8aSJames Wright 
28bd882c8aSJames Wright using SpecID = sycl::specialization_id<CeedInt>;
29bd882c8aSJames Wright 
30bd882c8aSJames Wright static constexpr SpecID BASIS_DIM_ID;
31bd882c8aSJames Wright static constexpr SpecID BASIS_NUM_COMP_ID;
32bd882c8aSJames Wright static constexpr SpecID BASIS_P_1D_ID;
33bd882c8aSJames Wright static constexpr SpecID BASIS_Q_1D_ID;
34bd882c8aSJames Wright 
35bd882c8aSJames Wright //------------------------------------------------------------------------------
36bd882c8aSJames Wright // Interpolation kernel - tensor
37bd882c8aSJames Wright //------------------------------------------------------------------------------
380ae60fd3SJeremy L Thompson template <int is_transpose>
39bd882c8aSJames Wright static int CeedBasisApplyInterp_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
40bd882c8aSJames Wright                                      const CeedScalar *u, CeedScalar *v) {
41bd882c8aSJames Wright   const CeedInt     buf_len   = impl->buf_len;
42bd882c8aSJames Wright   const CeedInt     op_len    = impl->op_len;
43bd882c8aSJames Wright   const CeedScalar *interp_1d = impl->d_interp_1d;
44bd882c8aSJames Wright 
45bd882c8aSJames Wright   const sycl::device &sycl_device         = sycl_queue.get_device();
46bd882c8aSJames Wright   const CeedInt       max_work_group_size = 32;
47bd882c8aSJames Wright   const CeedInt       work_group_size     = CeedIntMin(impl->num_qpts, max_work_group_size);
48bd882c8aSJames Wright   sycl::range<1>      local_range(work_group_size);
49bd882c8aSJames Wright   sycl::range<1>      global_range(num_elem * work_group_size);
50bd882c8aSJames Wright   sycl::nd_range<1>   kernel_range(global_range, local_range);
51bd882c8aSJames Wright 
52bd882c8aSJames Wright   // Order queue
53bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
54bd882c8aSJames Wright   sycl_queue.submit([&](sycl::handler &cgh) {
55bd882c8aSJames Wright     cgh.depends_on({e});
56bd882c8aSJames Wright     cgh.use_kernel_bundle(sycl_module);
57bd882c8aSJames Wright 
58bd882c8aSJames Wright     sycl::local_accessor<CeedScalar> s_mem(op_len + 2 * buf_len, cgh);
59bd882c8aSJames Wright 
600ae60fd3SJeremy L Thompson     cgh.parallel_for<CeedBasisSyclInterp<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
61bd882c8aSJames Wright       //-------------------------------------------------------------->
62bd882c8aSJames Wright       // Retrieve spec constant values
63bd882c8aSJames Wright       const CeedInt dim      = kh.get_specialization_constant<BASIS_DIM_ID>();
64bd882c8aSJames Wright       const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
65bd882c8aSJames Wright       const CeedInt P_1d     = kh.get_specialization_constant<BASIS_P_1D_ID>();
66bd882c8aSJames Wright       const CeedInt Q_1d     = kh.get_specialization_constant<BASIS_Q_1D_ID>();
67bd882c8aSJames Wright       //-------------------------------------------------------------->
68bd882c8aSJames Wright       const CeedInt num_nodes     = CeedIntPow(P_1d, dim);
69bd882c8aSJames Wright       const CeedInt num_qpts      = CeedIntPow(Q_1d, dim);
700ae60fd3SJeremy L Thompson       const CeedInt P             = is_transpose ? Q_1d : P_1d;
710ae60fd3SJeremy L Thompson       const CeedInt Q             = is_transpose ? P_1d : Q_1d;
720ae60fd3SJeremy L Thompson       const CeedInt stride_0      = is_transpose ? 1 : P_1d;
730ae60fd3SJeremy L Thompson       const CeedInt stride_1      = is_transpose ? P_1d : 1;
740ae60fd3SJeremy L Thompson       const CeedInt u_stride      = is_transpose ? num_qpts : num_nodes;
750ae60fd3SJeremy L Thompson       const CeedInt v_stride      = is_transpose ? num_nodes : num_qpts;
76bd882c8aSJames Wright       const CeedInt u_comp_stride = num_elem * u_stride;
77bd882c8aSJames Wright       const CeedInt v_comp_stride = num_elem * v_stride;
78bd882c8aSJames Wright       const CeedInt u_size        = u_stride;
79bd882c8aSJames Wright 
80bd882c8aSJames Wright       sycl::group   work_group = work_item.get_group();
81bd882c8aSJames Wright       const CeedInt i          = work_item.get_local_linear_id();
82bd882c8aSJames Wright       const CeedInt group_size = work_group.get_local_linear_range();
83bd882c8aSJames Wright       const CeedInt elem       = work_group.get_group_linear_id();
84bd882c8aSJames Wright 
8533bb61d4SKris Rowe       CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get();
86bd882c8aSJames Wright       CeedScalar *s_buffer_1  = s_interp_1d + Q * P;
87bd882c8aSJames Wright       CeedScalar *s_buffer_2  = s_buffer_1 + buf_len;
88bd882c8aSJames Wright 
89bd882c8aSJames Wright       for (CeedInt k = i; k < P * Q; k += group_size) {
90bd882c8aSJames Wright         s_interp_1d[k] = interp_1d[k];
91bd882c8aSJames Wright       }
92bd882c8aSJames Wright 
93bd882c8aSJames Wright       // Apply basis element by element
94bd882c8aSJames Wright       for (CeedInt comp = 0; comp < num_comp; comp++) {
95bd882c8aSJames Wright         const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
96bd882c8aSJames Wright         CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
97bd882c8aSJames Wright 
98bd882c8aSJames Wright         for (CeedInt k = i; k < u_size; k += group_size) {
99bd882c8aSJames Wright           s_buffer_1[k] = cur_u[k];
100bd882c8aSJames Wright         }
101bd882c8aSJames Wright 
102bd882c8aSJames Wright         CeedInt pre  = u_size;
103bd882c8aSJames Wright         CeedInt post = 1;
104bd882c8aSJames Wright 
105bd882c8aSJames Wright         for (CeedInt d = 0; d < dim; d++) {
106bd882c8aSJames Wright           // Use older version of sycl workgroup barrier for performance reasons
107bd882c8aSJames Wright           // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
108bd882c8aSJames Wright           // sycl::group_barrier(work_group);
109bd882c8aSJames Wright           work_item.barrier(sycl::access::fence_space::local_space);
110bd882c8aSJames Wright 
111bd882c8aSJames Wright           pre /= P;
112bd882c8aSJames Wright           const CeedScalar *in  = d % 2 ? s_buffer_2 : s_buffer_1;
113bd882c8aSJames Wright           CeedScalar       *out = d == dim - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
114bd882c8aSJames Wright 
115bd882c8aSJames Wright           // Contract along middle index
116bd882c8aSJames Wright           const CeedInt writeLen = pre * post * Q;
117bd882c8aSJames Wright           for (CeedInt k = i; k < writeLen; k += group_size) {
118bd882c8aSJames Wright             const CeedInt c = k % post;
119bd882c8aSJames Wright             const CeedInt j = (k / post) % Q;
120bd882c8aSJames Wright             const CeedInt a = k / (post * Q);
121bd882c8aSJames Wright 
122bd882c8aSJames Wright             CeedScalar vk = 0;
123bd882c8aSJames Wright             for (CeedInt b = 0; b < P; b++) {
124bd882c8aSJames Wright               vk += s_interp_1d[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
125bd882c8aSJames Wright             }
126bd882c8aSJames Wright             out[k] = vk;
127bd882c8aSJames Wright           }
128bd882c8aSJames Wright           post *= Q;
129bd882c8aSJames Wright         }
130bd882c8aSJames Wright       }
131bd882c8aSJames Wright     });
132bd882c8aSJames Wright   });
133bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
134bd882c8aSJames Wright }
135bd882c8aSJames Wright 
136bd882c8aSJames Wright //------------------------------------------------------------------------------
137bd882c8aSJames Wright // Gradient kernel - tensor
138bd882c8aSJames Wright //------------------------------------------------------------------------------
1390ae60fd3SJeremy L Thompson template <int is_transpose>
140bd882c8aSJames Wright static int CeedBasisApplyGrad_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
141bd882c8aSJames Wright                                    const CeedScalar *u, CeedScalar *v) {
142bd882c8aSJames Wright   const CeedInt     buf_len   = impl->buf_len;
143bd882c8aSJames Wright   const CeedInt     op_len    = impl->op_len;
144bd882c8aSJames Wright   const CeedScalar *interp_1d = impl->d_interp_1d;
145bd882c8aSJames Wright   const CeedScalar *grad_1d   = impl->d_grad_1d;
146bd882c8aSJames Wright 
147bd882c8aSJames Wright   const sycl::device &sycl_device     = sycl_queue.get_device();
148bd882c8aSJames Wright   const CeedInt       work_group_size = 32;
149bd882c8aSJames Wright   sycl::range<1>      local_range(work_group_size);
150bd882c8aSJames Wright   sycl::range<1>      global_range(num_elem * work_group_size);
151bd882c8aSJames Wright   sycl::nd_range<1>   kernel_range(global_range, local_range);
152bd882c8aSJames Wright 
153bd882c8aSJames Wright   // Order queue
154bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
155bd882c8aSJames Wright   sycl_queue.submit([&](sycl::handler &cgh) {
156bd882c8aSJames Wright     cgh.depends_on({e});
157bd882c8aSJames Wright     cgh.use_kernel_bundle(sycl_module);
158bd882c8aSJames Wright 
159bd882c8aSJames Wright     sycl::local_accessor<CeedScalar> s_mem(2 * (op_len + buf_len), cgh);
160bd882c8aSJames Wright 
1610ae60fd3SJeremy L Thompson     cgh.parallel_for<CeedBasisSyclGrad<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
162bd882c8aSJames Wright       //-------------------------------------------------------------->
163bd882c8aSJames Wright       // Retrieve spec constant values
164bd882c8aSJames Wright       const CeedInt dim      = kh.get_specialization_constant<BASIS_DIM_ID>();
165bd882c8aSJames Wright       const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
166bd882c8aSJames Wright       const CeedInt P_1d     = kh.get_specialization_constant<BASIS_P_1D_ID>();
167bd882c8aSJames Wright       const CeedInt Q_1d     = kh.get_specialization_constant<BASIS_Q_1D_ID>();
168bd882c8aSJames Wright       //-------------------------------------------------------------->
169bd882c8aSJames Wright       const CeedInt num_nodes     = CeedIntPow(P_1d, dim);
170bd882c8aSJames Wright       const CeedInt num_qpts      = CeedIntPow(Q_1d, dim);
1710ae60fd3SJeremy L Thompson       const CeedInt P             = is_transpose ? Q_1d : P_1d;
1720ae60fd3SJeremy L Thompson       const CeedInt Q             = is_transpose ? P_1d : Q_1d;
1730ae60fd3SJeremy L Thompson       const CeedInt stride_0      = is_transpose ? 1 : P_1d;
1740ae60fd3SJeremy L Thompson       const CeedInt stride_1      = is_transpose ? P_1d : 1;
1750ae60fd3SJeremy L Thompson       const CeedInt u_stride      = is_transpose ? num_qpts : num_nodes;
1760ae60fd3SJeremy L Thompson       const CeedInt v_stride      = is_transpose ? num_nodes : num_qpts;
177bd882c8aSJames Wright       const CeedInt u_comp_stride = num_elem * u_stride;
178bd882c8aSJames Wright       const CeedInt v_comp_stride = num_elem * v_stride;
1790ae60fd3SJeremy L Thompson       const CeedInt u_dim_stride  = is_transpose ? num_elem * num_qpts * num_comp : 0;
1800ae60fd3SJeremy L Thompson       const CeedInt v_dim_stride  = is_transpose ? 0 : num_elem * num_qpts * num_comp;
181bd882c8aSJames Wright       sycl::group   work_group    = work_item.get_group();
182bd882c8aSJames Wright       const CeedInt i             = work_item.get_local_linear_id();
183bd882c8aSJames Wright       const CeedInt group_size    = work_group.get_local_linear_range();
184bd882c8aSJames Wright       const CeedInt elem          = work_group.get_group_linear_id();
185bd882c8aSJames Wright 
18633bb61d4SKris Rowe       CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get();
187bd882c8aSJames Wright       CeedScalar *s_grad_1d   = s_interp_1d + P * Q;
188bd882c8aSJames Wright       CeedScalar *s_buffer_1  = s_grad_1d + P * Q;
189bd882c8aSJames Wright       CeedScalar *s_buffer_2  = s_buffer_1 + buf_len;
190bd882c8aSJames Wright 
191bd882c8aSJames Wright       for (CeedInt k = i; k < P * Q; k += group_size) {
192bd882c8aSJames Wright         s_interp_1d[k] = interp_1d[k];
193bd882c8aSJames Wright         s_grad_1d[k]   = grad_1d[k];
194bd882c8aSJames Wright       }
195bd882c8aSJames Wright 
196bd882c8aSJames Wright       // Apply basis element by element
197bd882c8aSJames Wright       for (CeedInt comp = 0; comp < num_comp; comp++) {
198bd882c8aSJames Wright         for (CeedInt dim_1 = 0; dim_1 < dim; dim_1++) {
1990ae60fd3SJeremy L Thompson           CeedInt           pre   = is_transpose ? num_qpts : num_nodes;
200bd882c8aSJames Wright           CeedInt           post  = 1;
201bd882c8aSJames Wright           const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride;
202bd882c8aSJames Wright           CeedScalar       *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride;
203bd882c8aSJames Wright 
204bd882c8aSJames Wright           for (CeedInt dim_2 = 0; dim_2 < dim; dim_2++) {
205bd882c8aSJames Wright             // Use older version of sycl workgroup barrier for performance reasons
206bd882c8aSJames Wright             // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
207bd882c8aSJames Wright             // sycl::group_barrier(work_group);
208bd882c8aSJames Wright             work_item.barrier(sycl::access::fence_space::local_space);
209bd882c8aSJames Wright 
210bd882c8aSJames Wright             pre /= P;
211bd882c8aSJames Wright             const CeedScalar *op  = dim_1 == dim_2 ? s_grad_1d : s_interp_1d;
212bd882c8aSJames Wright             const CeedScalar *in  = (dim_2 == 0 ? cur_u : (dim_2 % 2 ? s_buffer_2 : s_buffer_1));
213bd882c8aSJames Wright             CeedScalar       *out = dim_2 == dim - 1 ? cur_v : (dim_2 % 2 ? s_buffer_1 : s_buffer_2);
214bd882c8aSJames Wright 
215bd882c8aSJames Wright             // Contract along middle index
216bd882c8aSJames Wright             const CeedInt writeLen = pre * post * Q;
217bd882c8aSJames Wright             for (CeedInt k = i; k < writeLen; k += group_size) {
218bd882c8aSJames Wright               const CeedInt c = k % post;
219bd882c8aSJames Wright               const CeedInt j = (k / post) % Q;
220bd882c8aSJames Wright               const CeedInt a = k / (post * Q);
221bd882c8aSJames Wright 
222bd882c8aSJames Wright               CeedScalar v_k = 0;
223bd882c8aSJames Wright               for (CeedInt b = 0; b < P; b++) v_k += op[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
224bd882c8aSJames Wright 
2250ae60fd3SJeremy L Thompson               if (is_transpose && dim_2 == dim - 1) out[k] += v_k;
226bd882c8aSJames Wright               else out[k] = v_k;
227bd882c8aSJames Wright             }
228bd882c8aSJames Wright 
229bd882c8aSJames Wright             post *= Q;
230bd882c8aSJames Wright           }
231bd882c8aSJames Wright         }
232bd882c8aSJames Wright       }
233bd882c8aSJames Wright     });
234bd882c8aSJames Wright   });
235bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
236bd882c8aSJames Wright }
237bd882c8aSJames Wright 
238bd882c8aSJames Wright //------------------------------------------------------------------------------
239bd882c8aSJames Wright // Weight kernel - tensor
240bd882c8aSJames Wright //------------------------------------------------------------------------------
241bd882c8aSJames Wright static int CeedBasisApplyWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasis_Sycl *impl, CeedScalar *w) {
242bd882c8aSJames Wright   const CeedInt     dim         = impl->dim;
243bd882c8aSJames Wright   const CeedInt     Q_1d        = impl->Q_1d;
244bd882c8aSJames Wright   const CeedScalar *q_weight_1d = impl->d_q_weight_1d;
245bd882c8aSJames Wright 
246bd882c8aSJames Wright   const CeedInt  num_quad_x = Q_1d;
247bd882c8aSJames Wright   const CeedInt  num_quad_y = (dim > 1) ? Q_1d : 1;
248bd882c8aSJames Wright   const CeedInt  num_quad_z = (dim > 2) ? Q_1d : 1;
249bd882c8aSJames Wright   sycl::range<3> kernel_range(num_elem * num_quad_z, num_quad_y, num_quad_x);
250bd882c8aSJames Wright 
251bd882c8aSJames Wright   // Order queue
252bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
253bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclWeight>(kernel_range, {e}, [=](sycl::item<3> work_item) {
254bd882c8aSJames Wright     if (dim == 1) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]];
255bd882c8aSJames Wright     if (dim == 2) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]];
256bd882c8aSJames Wright     if (dim == 3) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]] * q_weight_1d[work_item[0] % Q_1d];
257bd882c8aSJames Wright   });
258bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
259bd882c8aSJames Wright }
260bd882c8aSJames Wright 
261bd882c8aSJames Wright //------------------------------------------------------------------------------
262bd882c8aSJames Wright // Basis apply - tensor
263bd882c8aSJames Wright //------------------------------------------------------------------------------
264bd882c8aSJames Wright static int CeedBasisApply_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
265bd882c8aSJames Wright                                CeedVector v) {
266bd882c8aSJames Wright   Ceed              ceed;
2670ae60fd3SJeremy L Thompson   const CeedInt     is_transpose = t_mode == CEED_TRANSPOSE;
268bd882c8aSJames Wright   const CeedScalar *d_u;
269bd882c8aSJames Wright   CeedScalar       *d_v;
2700ae60fd3SJeremy L Thompson   Ceed_Sycl        *data;
2710ae60fd3SJeremy L Thompson   CeedBasis_Sycl   *impl;
2720ae60fd3SJeremy L Thompson 
2730ae60fd3SJeremy L Thompson   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
2740ae60fd3SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
2750ae60fd3SJeremy L Thompson   CeedCallBackend(CeedBasisGetData(basis, &impl));
2760ae60fd3SJeremy L Thompson 
2770ae60fd3SJeremy L Thompson   // Get read/write access to u, v
2780ae60fd3SJeremy L Thompson   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
2790ae60fd3SJeremy L Thompson   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
280bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
281bd882c8aSJames Wright 
282bd882c8aSJames Wright   // Clear v for transpose operation
2830ae60fd3SJeremy L Thompson   if (is_transpose) {
284bd882c8aSJames Wright     CeedSize length;
285bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetLength(v, &length));
286bd882c8aSJames Wright     // Order queue
287bd882c8aSJames Wright     sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
288bd882c8aSJames Wright     data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e});
289bd882c8aSJames Wright   }
290bd882c8aSJames Wright 
291bd882c8aSJames Wright   // Basis action
292bd882c8aSJames Wright   switch (eval_mode) {
293*d07cdbe5SJeremy L Thompson     case CEED_EVAL_INTERP:
2940ae60fd3SJeremy L Thompson       if (is_transpose) {
295*d07cdbe5SJeremy L Thompson         CeedCallBackend(CeedBasisApplyInterp_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
296bd882c8aSJames Wright       } else {
297*d07cdbe5SJeremy L Thompson         CeedCallBackend(CeedBasisApplyInterp_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
298bd882c8aSJames Wright       }
299*d07cdbe5SJeremy L Thompson       break;
300*d07cdbe5SJeremy L Thompson     case CEED_EVAL_GRAD:
3010ae60fd3SJeremy L Thompson       if (is_transpose) {
302*d07cdbe5SJeremy L Thompson         CeedCallBackend(CeedBasisApplyGrad_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
303bd882c8aSJames Wright       } else {
304*d07cdbe5SJeremy L Thompson         CeedCallBackend(CeedBasisApplyGrad_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
305bd882c8aSJames Wright       }
306*d07cdbe5SJeremy L Thompson       break;
307*d07cdbe5SJeremy L Thompson     case CEED_EVAL_WEIGHT:
308bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
309*d07cdbe5SJeremy L Thompson       break;
3100ae60fd3SJeremy L Thompson     case CEED_EVAL_NONE: /* handled separately below */
3110ae60fd3SJeremy L Thompson       break;
312bd882c8aSJames Wright     // LCOV_EXCL_START
313bd882c8aSJames Wright     case CEED_EVAL_DIV:
314bd882c8aSJames Wright     case CEED_EVAL_CURL:
3154e3038a5SJeremy L Thompson       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
316bd882c8aSJames Wright       // LCOV_EXCL_STOP
317bd882c8aSJames Wright   }
318bd882c8aSJames Wright 
3190ae60fd3SJeremy L Thompson   // Restore vectors, cover CEED_EVAL_NONE
320bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
3210ae60fd3SJeremy L Thompson   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
3220ae60fd3SJeremy L Thompson   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
323bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
324bd882c8aSJames Wright }
325bd882c8aSJames Wright 
326bd882c8aSJames Wright //------------------------------------------------------------------------------
327bd882c8aSJames Wright // Interpolation kernel - non-tensor
328bd882c8aSJames Wright //------------------------------------------------------------------------------
3290ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorInterp_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl,
330bd882c8aSJames Wright                                               const CeedScalar *d_U, CeedScalar *d_V) {
331bd882c8aSJames Wright   const CeedInt     num_comp      = impl->num_comp;
3320ae60fd3SJeremy L Thompson   const CeedInt     P             = is_transpose ? impl->num_qpts : impl->num_nodes;
3330ae60fd3SJeremy L Thompson   const CeedInt     Q             = is_transpose ? impl->num_nodes : impl->num_qpts;
3340ae60fd3SJeremy L Thompson   const CeedInt     stride_0      = is_transpose ? 1 : impl->num_nodes;
3350ae60fd3SJeremy L Thompson   const CeedInt     stride_1      = is_transpose ? impl->num_nodes : 1;
336bd882c8aSJames Wright   const CeedInt     u_stride      = P;
337bd882c8aSJames Wright   const CeedInt     v_stride      = Q;
338bd882c8aSJames Wright   const CeedInt     u_comp_stride = u_stride * num_elem;
339bd882c8aSJames Wright   const CeedInt     v_comp_stride = v_stride * num_elem;
340bd882c8aSJames Wright   const CeedInt     u_size        = P;
341bd882c8aSJames Wright   const CeedInt     v_size        = Q;
342bd882c8aSJames Wright   const CeedScalar *d_B           = impl->d_interp;
343bd882c8aSJames Wright 
344bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, v_size);
345bd882c8aSJames Wright 
346bd882c8aSJames Wright   // Order queue
347bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
348bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclInterpNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
349bd882c8aSJames Wright     const CeedInt i    = indx[1];
350bd882c8aSJames Wright     const CeedInt elem = indx[0];
351bd882c8aSJames Wright 
352bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
353bd882c8aSJames Wright       const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride;
354bd882c8aSJames Wright       CeedScalar        V = 0.0;
355bd882c8aSJames Wright 
356bd882c8aSJames Wright       for (CeedInt j = 0; j < u_size; ++j) {
357bd882c8aSJames Wright         V += d_B[i * stride_0 + j * stride_1] * U[j];
358bd882c8aSJames Wright       }
359bd882c8aSJames Wright       d_V[i + elem * v_stride + comp * v_comp_stride] = V;
360bd882c8aSJames Wright     }
361bd882c8aSJames Wright   });
362bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
363bd882c8aSJames Wright }
364bd882c8aSJames Wright 
365bd882c8aSJames Wright //------------------------------------------------------------------------------
366bd882c8aSJames Wright // Gradient kernel - non-tensor
367bd882c8aSJames Wright //------------------------------------------------------------------------------
3680ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorGrad_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl,
369bd882c8aSJames Wright                                             const CeedScalar *d_U, CeedScalar *d_V) {
370bd882c8aSJames Wright   const CeedInt     num_comp      = impl->num_comp;
3710ae60fd3SJeremy L Thompson   const CeedInt     P             = is_transpose ? impl->num_qpts : impl->num_nodes;
3720ae60fd3SJeremy L Thompson   const CeedInt     Q             = is_transpose ? impl->num_nodes : impl->num_qpts;
3730ae60fd3SJeremy L Thompson   const CeedInt     stride_0      = is_transpose ? 1 : impl->num_nodes;
3740ae60fd3SJeremy L Thompson   const CeedInt     stride_1      = is_transpose ? impl->num_nodes : 1;
375bd882c8aSJames Wright   const CeedInt     g_dim_stride  = P * Q;
376bd882c8aSJames Wright   const CeedInt     u_stride      = P;
377bd882c8aSJames Wright   const CeedInt     v_stride      = Q;
378bd882c8aSJames Wright   const CeedInt     u_comp_stride = u_stride * num_elem;
379bd882c8aSJames Wright   const CeedInt     v_comp_stride = v_stride * num_elem;
380bd882c8aSJames Wright   const CeedInt     u_dim_stride  = u_comp_stride * num_comp;
381bd882c8aSJames Wright   const CeedInt     v_dim_stride  = v_comp_stride * num_comp;
382bd882c8aSJames Wright   const CeedInt     u_size        = P;
383bd882c8aSJames Wright   const CeedInt     v_size        = Q;
3840ae60fd3SJeremy L Thompson   const CeedInt     in_dim        = is_transpose ? impl->dim : 1;
3850ae60fd3SJeremy L Thompson   const CeedInt     out_dim       = is_transpose ? 1 : impl->dim;
386bd882c8aSJames Wright   const CeedScalar *d_G           = impl->d_grad;
387bd882c8aSJames Wright 
388bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, v_size);
389bd882c8aSJames Wright 
390bd882c8aSJames Wright   // Order queue
391bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
392bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclGradNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
393bd882c8aSJames Wright     const CeedInt i    = indx[1];
394bd882c8aSJames Wright     const CeedInt elem = indx[0];
395bd882c8aSJames Wright 
396bd882c8aSJames Wright     for (CeedInt comp = 0; comp < num_comp; comp++) {
397bd882c8aSJames Wright       CeedScalar V[3] = {0.0, 0.0, 0.0};
398bd882c8aSJames Wright 
399bd882c8aSJames Wright       for (CeedInt d1 = 0; d1 < in_dim; ++d1) {
400bd882c8aSJames Wright         const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride + d1 * u_dim_stride;
401bd882c8aSJames Wright         const CeedScalar *G = d_G + i * stride_0 + d1 * g_dim_stride;
402bd882c8aSJames Wright 
403bd882c8aSJames Wright         for (CeedInt j = 0; j < u_size; ++j) {
404bd882c8aSJames Wright           const CeedScalar Uj = U[j];
405bd882c8aSJames Wright 
406bd882c8aSJames Wright           for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
407bd882c8aSJames Wright             V[d0] += G[j * stride_1 + d0 * g_dim_stride] * Uj;
408bd882c8aSJames Wright           }
409bd882c8aSJames Wright         }
410bd882c8aSJames Wright       }
411bd882c8aSJames Wright       for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
412bd882c8aSJames Wright         d_V[i + elem * v_stride + comp * v_comp_stride + d0 * v_dim_stride] = V[d0];
413bd882c8aSJames Wright       }
414bd882c8aSJames Wright     }
415bd882c8aSJames Wright   });
416bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
417bd882c8aSJames Wright }
418bd882c8aSJames Wright 
419bd882c8aSJames Wright //------------------------------------------------------------------------------
420bd882c8aSJames Wright // Weight kernel - non-tensor
421bd882c8aSJames Wright //------------------------------------------------------------------------------
422bd882c8aSJames Wright static int CeedBasisApplyNonTensorWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasisNonTensor_Sycl *impl, CeedScalar *d_V) {
423bd882c8aSJames Wright   const CeedInt     num_qpts = impl->num_qpts;
424bd882c8aSJames Wright   const CeedScalar *q_weight = impl->d_q_weight;
425bd882c8aSJames Wright 
426bd882c8aSJames Wright   sycl::range<2> kernel_range(num_elem, num_qpts);
427bd882c8aSJames Wright 
428bd882c8aSJames Wright   // Order queue
429bd882c8aSJames Wright   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
430bd882c8aSJames Wright   sycl_queue.parallel_for<CeedBasisSyclWeightNT>(kernel_range, {e}, [=](sycl::id<2> indx) {
431bd882c8aSJames Wright     const CeedInt i          = indx[1];
432bd882c8aSJames Wright     const CeedInt elem       = indx[0];
433bd882c8aSJames Wright     d_V[i + elem * num_qpts] = q_weight[i];
434bd882c8aSJames Wright   });
435bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
436bd882c8aSJames Wright }
437bd882c8aSJames Wright 
438bd882c8aSJames Wright //------------------------------------------------------------------------------
439bd882c8aSJames Wright // Basis apply - non-tensor
440bd882c8aSJames Wright //------------------------------------------------------------------------------
441bd882c8aSJames Wright static int CeedBasisApplyNonTensor_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
442bd882c8aSJames Wright                                         CeedVector v) {
443bd882c8aSJames Wright   Ceed                     ceed;
4440ae60fd3SJeremy L Thompson   const CeedInt            is_transpose = t_mode == CEED_TRANSPOSE;
445bd882c8aSJames Wright   const CeedScalar        *d_u;
446bd882c8aSJames Wright   CeedScalar              *d_v;
4470ae60fd3SJeremy L Thompson   CeedBasisNonTensor_Sycl *impl;
4480ae60fd3SJeremy L Thompson   Ceed_Sycl               *data;
4490ae60fd3SJeremy L Thompson 
4500ae60fd3SJeremy L Thompson   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
4510ae60fd3SJeremy L Thompson   CeedCallBackend(CeedBasisGetData(basis, &impl));
4520ae60fd3SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
4530ae60fd3SJeremy L Thompson 
4540ae60fd3SJeremy L Thompson   // Get read/write access to u, v
4550ae60fd3SJeremy L Thompson   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
4560ae60fd3SJeremy L Thompson   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
457bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
458bd882c8aSJames Wright 
459bd882c8aSJames Wright   // Clear v for transpose operation
4600ae60fd3SJeremy L Thompson   if (is_transpose) {
461bd882c8aSJames Wright     CeedSize length;
462bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetLength(v, &length));
463bd882c8aSJames Wright     // Order queue
464bd882c8aSJames Wright     sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
465bd882c8aSJames Wright     data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e});
466bd882c8aSJames Wright   }
467bd882c8aSJames Wright 
468bd882c8aSJames Wright   // Apply basis operation
469bd882c8aSJames Wright   switch (eval_mode) {
470*d07cdbe5SJeremy L Thompson     case CEED_EVAL_INTERP:
4710ae60fd3SJeremy L Thompson       CeedCallBackend(CeedBasisApplyNonTensorInterp_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v));
472*d07cdbe5SJeremy L Thompson       break;
473*d07cdbe5SJeremy L Thompson     case CEED_EVAL_GRAD:
4740ae60fd3SJeremy L Thompson       CeedCallBackend(CeedBasisApplyNonTensorGrad_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v));
475*d07cdbe5SJeremy L Thompson       break;
476*d07cdbe5SJeremy L Thompson     case CEED_EVAL_WEIGHT:
477bd882c8aSJames Wright       CeedCallBackend(CeedBasisApplyNonTensorWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
478*d07cdbe5SJeremy L Thompson       break;
4790ae60fd3SJeremy L Thompson     case CEED_EVAL_NONE: /* handled separately below */
4800ae60fd3SJeremy L Thompson       break;
481bd882c8aSJames Wright     // LCOV_EXCL_START
482bd882c8aSJames Wright     case CEED_EVAL_DIV:
483bd882c8aSJames Wright     case CEED_EVAL_CURL:
4849d1bceceSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
485bd882c8aSJames Wright       // LCOV_EXCL_STOP
486bd882c8aSJames Wright   }
487bd882c8aSJames Wright 
4880ae60fd3SJeremy L Thompson   // Restore vectors, cover CEED_EVAL_NONE
489bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
4900ae60fd3SJeremy L Thompson   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
4910ae60fd3SJeremy L Thompson   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
4920ae60fd3SJeremy L Thompson 
493bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
494bd882c8aSJames Wright }
495bd882c8aSJames Wright 
496bd882c8aSJames Wright //------------------------------------------------------------------------------
497bd882c8aSJames Wright // Destroy tensor basis
498bd882c8aSJames Wright //------------------------------------------------------------------------------
499bd882c8aSJames Wright static int CeedBasisDestroy_Sycl(CeedBasis basis) {
500bd882c8aSJames Wright   Ceed ceed;
501bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
502bd882c8aSJames Wright   CeedBasis_Sycl *impl;
503bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
504bd882c8aSJames Wright   Ceed_Sycl *data;
505bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
506bd882c8aSJames Wright 
507bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
508bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
509bd882c8aSJames Wright 
510bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
511bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
512bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
513bd882c8aSJames Wright 
514bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
515bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
516bd882c8aSJames Wright }
517bd882c8aSJames Wright 
518bd882c8aSJames Wright //------------------------------------------------------------------------------
519bd882c8aSJames Wright // Destroy non-tensor basis
520bd882c8aSJames Wright //------------------------------------------------------------------------------
521bd882c8aSJames Wright static int CeedBasisDestroyNonTensor_Sycl(CeedBasis basis) {
522bd882c8aSJames Wright   Ceed ceed;
523bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
524bd882c8aSJames Wright   CeedBasisNonTensor_Sycl *impl;
525bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
526bd882c8aSJames Wright   Ceed_Sycl *data;
527bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
528bd882c8aSJames Wright 
529bd882c8aSJames Wright   // Wait for all work to finish before freeing memory
530bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
531bd882c8aSJames Wright 
532bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_q_weight, data->sycl_context));
533bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp, data->sycl_context));
534bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad, data->sycl_context));
535bd882c8aSJames Wright 
536bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
537bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
538bd882c8aSJames Wright }
539bd882c8aSJames Wright 
540bd882c8aSJames Wright //------------------------------------------------------------------------------
541bd882c8aSJames Wright // Create tensor
542bd882c8aSJames Wright //------------------------------------------------------------------------------
543bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
544bd882c8aSJames Wright                                  const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
545bd882c8aSJames Wright   Ceed ceed;
546bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
547bd882c8aSJames Wright   CeedBasis_Sycl *impl;
548bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
549bd882c8aSJames Wright   Ceed_Sycl *data;
550bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
551bd882c8aSJames Wright 
552bd882c8aSJames Wright   CeedInt num_comp;
553bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
554bd882c8aSJames Wright 
555bd882c8aSJames Wright   const CeedInt num_nodes = CeedIntPow(P_1d, dim);
556bd882c8aSJames Wright   const CeedInt num_qpts  = CeedIntPow(Q_1d, dim);
557bd882c8aSJames Wright 
558bd882c8aSJames Wright   impl->dim       = dim;
559bd882c8aSJames Wright   impl->P_1d      = P_1d;
560bd882c8aSJames Wright   impl->Q_1d      = Q_1d;
561bd882c8aSJames Wright   impl->num_comp  = num_comp;
562bd882c8aSJames Wright   impl->num_nodes = num_nodes;
563bd882c8aSJames Wright   impl->num_qpts  = num_qpts;
564bd882c8aSJames Wright   impl->buf_len   = num_comp * CeedIntMax(num_nodes, num_qpts);
565bd882c8aSJames Wright   impl->op_len    = Q_1d * P_1d;
566bd882c8aSJames Wright 
567bd882c8aSJames Wright   // Order queue
568bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
569bd882c8aSJames Wright 
570bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
571bd882c8aSJames Wright   sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, {e});
572bd882c8aSJames Wright 
573bd882c8aSJames Wright   const CeedInt interp_length = Q_1d * P_1d;
574bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
575bd882c8aSJames Wright   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, {e});
576bd882c8aSJames Wright 
577bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
578bd882c8aSJames Wright   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, {e});
579bd882c8aSJames Wright 
580bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad}));
581bd882c8aSJames Wright 
582bd882c8aSJames Wright   std::vector<sycl::kernel_id> kernel_ids = {sycl::get_kernel_id<CeedBasisSyclInterp<1>>(), sycl::get_kernel_id<CeedBasisSyclInterp<0>>(),
583bd882c8aSJames Wright                                              sycl::get_kernel_id<CeedBasisSyclGrad<1>>(), sycl::get_kernel_id<CeedBasisSyclGrad<0>>()};
584bd882c8aSJames Wright 
585bd882c8aSJames Wright   sycl::kernel_bundle<sycl::bundle_state::input> input_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(data->sycl_context, kernel_ids);
586bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_DIM_ID>(dim);
587bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_NUM_COMP_ID>(num_comp);
588bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_Q_1D_ID>(Q_1d);
589bd882c8aSJames Wright   input_bundle.set_specialization_constant<BASIS_P_1D_ID>(P_1d);
590bd882c8aSJames Wright 
591bd882c8aSJames Wright   CeedCallSycl(ceed, impl->sycl_module = new SyclModule_t(sycl::build(input_bundle)));
592bd882c8aSJames Wright 
593bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
594bd882c8aSJames Wright 
595bd882c8aSJames Wright   // Register backend functions
596bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApply_Sycl));
597bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl));
598bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
599bd882c8aSJames Wright }
600bd882c8aSJames Wright 
601bd882c8aSJames Wright //------------------------------------------------------------------------------
602bd882c8aSJames Wright // Create non-tensor
603bd882c8aSJames Wright //------------------------------------------------------------------------------
604bd882c8aSJames Wright int CeedBasisCreateH1_Sycl(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
605dd64fc84SJeremy L Thompson                            const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
606bd882c8aSJames Wright   Ceed ceed;
607bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
608bd882c8aSJames Wright   CeedBasisNonTensor_Sycl *impl;
609bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
610bd882c8aSJames Wright   Ceed_Sycl *data;
611bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
612bd882c8aSJames Wright 
613bd882c8aSJames Wright   CeedInt num_comp;
614bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
615bd882c8aSJames Wright 
616bd882c8aSJames Wright   impl->dim       = dim;
617bd882c8aSJames Wright   impl->num_comp  = num_comp;
618bd882c8aSJames Wright   impl->num_nodes = num_nodes;
619bd882c8aSJames Wright   impl->num_qpts  = num_qpts;
620bd882c8aSJames Wright 
621bd882c8aSJames Wright   // Order queue
622bd882c8aSJames Wright   sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
623bd882c8aSJames Wright 
624bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_q_weight = sycl::malloc_device<CeedScalar>(num_qpts, data->sycl_device, data->sycl_context));
625bd882c8aSJames Wright   sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight, impl->d_q_weight, num_qpts, {e});
626bd882c8aSJames Wright 
627bd882c8aSJames Wright   const CeedInt interp_length = num_qpts * num_nodes;
628bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
629bd882c8aSJames Wright   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp, impl->d_interp, interp_length, {e});
630bd882c8aSJames Wright 
631bd882c8aSJames Wright   const CeedInt grad_length = num_qpts * num_nodes * dim;
632bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad = sycl::malloc_device<CeedScalar>(grad_length, data->sycl_device, data->sycl_context));
633bd882c8aSJames Wright   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad, impl->d_grad, grad_length, {e});
634bd882c8aSJames Wright 
635bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad}));
636bd882c8aSJames Wright 
637bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
638bd882c8aSJames Wright 
639bd882c8aSJames Wright   // Register backend functions
640bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Sycl));
641bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Sycl));
642bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
643bd882c8aSJames Wright }
644ff1e7120SSebastian Grimberg 
645bd882c8aSJames Wright //------------------------------------------------------------------------------
646