1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed
7bd882c8aSJames Wright
8bd882c8aSJames Wright #include <ceed/backend.h>
9bd882c8aSJames Wright #include <ceed/ceed.h>
10bd882c8aSJames Wright #include <ceed/jit-tools.h>
11bd882c8aSJames Wright
12bd882c8aSJames Wright #include <sycl/sycl.hpp>
13bd882c8aSJames Wright #include <vector>
14bd882c8aSJames Wright
15bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
16bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
17bd882c8aSJames Wright
18bd882c8aSJames Wright template <int>
19bd882c8aSJames Wright class CeedBasisSyclInterp;
20bd882c8aSJames Wright template <int>
21bd882c8aSJames Wright class CeedBasisSyclGrad;
22bd882c8aSJames Wright class CeedBasisSyclWeight;
23bd882c8aSJames Wright
24bd882c8aSJames Wright class CeedBasisSyclInterpNT;
25bd882c8aSJames Wright class CeedBasisSyclGradNT;
26bd882c8aSJames Wright class CeedBasisSyclWeightNT;
27bd882c8aSJames Wright
28bd882c8aSJames Wright using SpecID = sycl::specialization_id<CeedInt>;
29bd882c8aSJames Wright
30bd882c8aSJames Wright static constexpr SpecID BASIS_DIM_ID;
31bd882c8aSJames Wright static constexpr SpecID BASIS_NUM_COMP_ID;
32bd882c8aSJames Wright static constexpr SpecID BASIS_P_1D_ID;
33bd882c8aSJames Wright static constexpr SpecID BASIS_Q_1D_ID;
34bd882c8aSJames Wright
35bd882c8aSJames Wright //------------------------------------------------------------------------------
36bd882c8aSJames Wright // Interpolation kernel - tensor
37bd882c8aSJames Wright //------------------------------------------------------------------------------
380ae60fd3SJeremy L Thompson template <int is_transpose>
CeedBasisApplyInterp_Sycl(sycl::queue & sycl_queue,const SyclModule_t & sycl_module,CeedInt num_elem,const CeedBasis_Sycl * impl,const CeedScalar * u,CeedScalar * v)39bd882c8aSJames Wright static int CeedBasisApplyInterp_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
40bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) {
41bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len;
42bd882c8aSJames Wright const CeedInt op_len = impl->op_len;
43bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d;
44bd882c8aSJames Wright
45bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device();
46bd882c8aSJames Wright const CeedInt max_work_group_size = 32;
47bd882c8aSJames Wright const CeedInt work_group_size = CeedIntMin(impl->num_qpts, max_work_group_size);
48bd882c8aSJames Wright sycl::range<1> local_range(work_group_size);
49bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size);
50bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range);
51bd882c8aSJames Wright
521f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
531f4b1b45SUmesh Unnikrishnan
541f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
551f4b1b45SUmesh Unnikrishnan
56bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) {
571f4b1b45SUmesh Unnikrishnan cgh.depends_on(e);
58bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module);
59bd882c8aSJames Wright
60bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(op_len + 2 * buf_len, cgh);
61bd882c8aSJames Wright
620ae60fd3SJeremy L Thompson cgh.parallel_for<CeedBasisSyclInterp<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
63bd882c8aSJames Wright //-------------------------------------------------------------->
64bd882c8aSJames Wright // Retrieve spec constant values
65bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>();
66bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
67bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>();
68bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>();
69bd882c8aSJames Wright //-------------------------------------------------------------->
70bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim);
71bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim);
720ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? Q_1d : P_1d;
730ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? P_1d : Q_1d;
740ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : P_1d;
750ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? P_1d : 1;
760ae60fd3SJeremy L Thompson const CeedInt u_stride = is_transpose ? num_qpts : num_nodes;
770ae60fd3SJeremy L Thompson const CeedInt v_stride = is_transpose ? num_nodes : num_qpts;
78bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride;
79bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride;
80bd882c8aSJames Wright const CeedInt u_size = u_stride;
81bd882c8aSJames Wright
82bd882c8aSJames Wright sycl::group work_group = work_item.get_group();
83bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id();
84bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range();
85bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id();
86bd882c8aSJames Wright
8733bb61d4SKris Rowe CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get();
88bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_interp_1d + Q * P;
89bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len;
90bd882c8aSJames Wright
91bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) {
92bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k];
93bd882c8aSJames Wright }
94bd882c8aSJames Wright
95bd882c8aSJames Wright // Apply basis element by element
96bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) {
97bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
98bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;
99bd882c8aSJames Wright
100bd882c8aSJames Wright for (CeedInt k = i; k < u_size; k += group_size) {
101bd882c8aSJames Wright s_buffer_1[k] = cur_u[k];
102bd882c8aSJames Wright }
103bd882c8aSJames Wright
104bd882c8aSJames Wright CeedInt pre = u_size;
105bd882c8aSJames Wright CeedInt post = 1;
106bd882c8aSJames Wright
107bd882c8aSJames Wright for (CeedInt d = 0; d < dim; d++) {
108bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons
109bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
110bd882c8aSJames Wright // sycl::group_barrier(work_group);
111bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space);
112bd882c8aSJames Wright
113bd882c8aSJames Wright pre /= P;
114bd882c8aSJames Wright const CeedScalar *in = d % 2 ? s_buffer_2 : s_buffer_1;
115bd882c8aSJames Wright CeedScalar *out = d == dim - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
116bd882c8aSJames Wright
117bd882c8aSJames Wright // Contract along middle index
118bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q;
119bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) {
120bd882c8aSJames Wright const CeedInt c = k % post;
121bd882c8aSJames Wright const CeedInt j = (k / post) % Q;
122bd882c8aSJames Wright const CeedInt a = k / (post * Q);
123bd882c8aSJames Wright
124bd882c8aSJames Wright CeedScalar vk = 0;
125bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) {
126bd882c8aSJames Wright vk += s_interp_1d[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
127bd882c8aSJames Wright }
128bd882c8aSJames Wright out[k] = vk;
129bd882c8aSJames Wright }
130bd882c8aSJames Wright post *= Q;
131bd882c8aSJames Wright }
132bd882c8aSJames Wright }
133bd882c8aSJames Wright });
134bd882c8aSJames Wright });
135bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
136bd882c8aSJames Wright }
137bd882c8aSJames Wright
138bd882c8aSJames Wright //------------------------------------------------------------------------------
139bd882c8aSJames Wright // Gradient kernel - tensor
140bd882c8aSJames Wright //------------------------------------------------------------------------------
1410ae60fd3SJeremy L Thompson template <int is_transpose>
CeedBasisApplyGrad_Sycl(sycl::queue & sycl_queue,const SyclModule_t & sycl_module,CeedInt num_elem,const CeedBasis_Sycl * impl,const CeedScalar * u,CeedScalar * v)142bd882c8aSJames Wright static int CeedBasisApplyGrad_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl,
143bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) {
144bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len;
145bd882c8aSJames Wright const CeedInt op_len = impl->op_len;
146bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d;
147bd882c8aSJames Wright const CeedScalar *grad_1d = impl->d_grad_1d;
148bd882c8aSJames Wright
149bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device();
150bd882c8aSJames Wright const CeedInt work_group_size = 32;
151bd882c8aSJames Wright sycl::range<1> local_range(work_group_size);
152bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size);
153bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range);
154bd882c8aSJames Wright
1551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
1561f4b1b45SUmesh Unnikrishnan
1571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
1581f4b1b45SUmesh Unnikrishnan
159bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) {
1601f4b1b45SUmesh Unnikrishnan cgh.depends_on(e);
161bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module);
162bd882c8aSJames Wright
163bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(2 * (op_len + buf_len), cgh);
164bd882c8aSJames Wright
1650ae60fd3SJeremy L Thompson cgh.parallel_for<CeedBasisSyclGrad<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) {
166bd882c8aSJames Wright //-------------------------------------------------------------->
167bd882c8aSJames Wright // Retrieve spec constant values
168bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>();
169bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>();
170bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>();
171bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>();
172bd882c8aSJames Wright //-------------------------------------------------------------->
173bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim);
174bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim);
1750ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? Q_1d : P_1d;
1760ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? P_1d : Q_1d;
1770ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : P_1d;
1780ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? P_1d : 1;
1790ae60fd3SJeremy L Thompson const CeedInt u_stride = is_transpose ? num_qpts : num_nodes;
1800ae60fd3SJeremy L Thompson const CeedInt v_stride = is_transpose ? num_nodes : num_qpts;
181bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride;
182bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride;
1830ae60fd3SJeremy L Thompson const CeedInt u_dim_stride = is_transpose ? num_elem * num_qpts * num_comp : 0;
1840ae60fd3SJeremy L Thompson const CeedInt v_dim_stride = is_transpose ? 0 : num_elem * num_qpts * num_comp;
185bd882c8aSJames Wright sycl::group work_group = work_item.get_group();
186bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id();
187bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range();
188bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id();
189bd882c8aSJames Wright
19033bb61d4SKris Rowe CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get();
191bd882c8aSJames Wright CeedScalar *s_grad_1d = s_interp_1d + P * Q;
192bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_grad_1d + P * Q;
193bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len;
194bd882c8aSJames Wright
195bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) {
196bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k];
197bd882c8aSJames Wright s_grad_1d[k] = grad_1d[k];
198bd882c8aSJames Wright }
199bd882c8aSJames Wright
200bd882c8aSJames Wright // Apply basis element by element
201bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) {
202bd882c8aSJames Wright for (CeedInt dim_1 = 0; dim_1 < dim; dim_1++) {
2030ae60fd3SJeremy L Thompson CeedInt pre = is_transpose ? num_qpts : num_nodes;
204bd882c8aSJames Wright CeedInt post = 1;
205bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride;
206bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride;
207bd882c8aSJames Wright
208bd882c8aSJames Wright for (CeedInt dim_2 = 0; dim_2 < dim; dim_2++) {
209bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons
210bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed
211bd882c8aSJames Wright // sycl::group_barrier(work_group);
212bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space);
213bd882c8aSJames Wright
214bd882c8aSJames Wright pre /= P;
215bd882c8aSJames Wright const CeedScalar *op = dim_1 == dim_2 ? s_grad_1d : s_interp_1d;
216bd882c8aSJames Wright const CeedScalar *in = (dim_2 == 0 ? cur_u : (dim_2 % 2 ? s_buffer_2 : s_buffer_1));
217bd882c8aSJames Wright CeedScalar *out = dim_2 == dim - 1 ? cur_v : (dim_2 % 2 ? s_buffer_1 : s_buffer_2);
218bd882c8aSJames Wright
219bd882c8aSJames Wright // Contract along middle index
220bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q;
221bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) {
222bd882c8aSJames Wright const CeedInt c = k % post;
223bd882c8aSJames Wright const CeedInt j = (k / post) % Q;
224bd882c8aSJames Wright const CeedInt a = k / (post * Q);
225bd882c8aSJames Wright
226bd882c8aSJames Wright CeedScalar v_k = 0;
227bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) v_k += op[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c];
228bd882c8aSJames Wright
2290ae60fd3SJeremy L Thompson if (is_transpose && dim_2 == dim - 1) out[k] += v_k;
230bd882c8aSJames Wright else out[k] = v_k;
231bd882c8aSJames Wright }
232bd882c8aSJames Wright
233bd882c8aSJames Wright post *= Q;
234bd882c8aSJames Wright }
235bd882c8aSJames Wright }
236bd882c8aSJames Wright }
237bd882c8aSJames Wright });
238bd882c8aSJames Wright });
239bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
240bd882c8aSJames Wright }
241bd882c8aSJames Wright
242bd882c8aSJames Wright //------------------------------------------------------------------------------
243bd882c8aSJames Wright // Weight kernel - tensor
244bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyWeight_Sycl(sycl::queue & sycl_queue,CeedInt num_elem,const CeedBasis_Sycl * impl,CeedScalar * w)245bd882c8aSJames Wright static int CeedBasisApplyWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasis_Sycl *impl, CeedScalar *w) {
246bd882c8aSJames Wright const CeedInt dim = impl->dim;
247bd882c8aSJames Wright const CeedInt Q_1d = impl->Q_1d;
248bd882c8aSJames Wright const CeedScalar *q_weight_1d = impl->d_q_weight_1d;
249bd882c8aSJames Wright
250bd882c8aSJames Wright const CeedInt num_quad_x = Q_1d;
251bd882c8aSJames Wright const CeedInt num_quad_y = (dim > 1) ? Q_1d : 1;
252bd882c8aSJames Wright const CeedInt num_quad_z = (dim > 2) ? Q_1d : 1;
253bd882c8aSJames Wright sycl::range<3> kernel_range(num_elem * num_quad_z, num_quad_y, num_quad_x);
254bd882c8aSJames Wright
2551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
2561f4b1b45SUmesh Unnikrishnan
2571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
2581f4b1b45SUmesh Unnikrishnan
2591f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclWeight>(kernel_range, e, [=](sycl::item<3> work_item) {
260bd882c8aSJames Wright if (dim == 1) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]];
261bd882c8aSJames Wright if (dim == 2) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]];
262bd882c8aSJames Wright if (dim == 3) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]] * q_weight_1d[work_item[0] % Q_1d];
263bd882c8aSJames Wright });
264bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
265bd882c8aSJames Wright }
266bd882c8aSJames Wright
267bd882c8aSJames Wright //------------------------------------------------------------------------------
268bd882c8aSJames Wright // Basis apply - tensor
269bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApply_Sycl(CeedBasis basis,const CeedInt num_elem,CeedTransposeMode t_mode,CeedEvalMode eval_mode,CeedVector u,CeedVector v)270bd882c8aSJames Wright static int CeedBasisApply_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
271bd882c8aSJames Wright CeedVector v) {
272bd882c8aSJames Wright Ceed ceed;
2730ae60fd3SJeremy L Thompson const CeedInt is_transpose = t_mode == CEED_TRANSPOSE;
274bd882c8aSJames Wright const CeedScalar *d_u;
275bd882c8aSJames Wright CeedScalar *d_v;
2760ae60fd3SJeremy L Thompson Ceed_Sycl *data;
2770ae60fd3SJeremy L Thompson CeedBasis_Sycl *impl;
2780ae60fd3SJeremy L Thompson
2790ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
2800ae60fd3SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data));
2810ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetData(basis, &impl));
2820ae60fd3SJeremy L Thompson
2830ae60fd3SJeremy L Thompson // Get read/write access to u, v
2840ae60fd3SJeremy L Thompson if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
2850ae60fd3SJeremy L Thompson else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
286bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
287bd882c8aSJames Wright
288bd882c8aSJames Wright // Clear v for transpose operation
2890ae60fd3SJeremy L Thompson if (is_transpose) {
290bd882c8aSJames Wright CeedSize length;
291bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length));
2921f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
2931f4b1b45SUmesh Unnikrishnan
2941f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};
2951f4b1b45SUmesh Unnikrishnan data->sycl_queue.fill<CeedScalar>(d_v, 0, length, e);
296bd882c8aSJames Wright }
297bd882c8aSJames Wright
298bd882c8aSJames Wright // Basis action
299bd882c8aSJames Wright switch (eval_mode) {
300d07cdbe5SJeremy L Thompson case CEED_EVAL_INTERP:
3010ae60fd3SJeremy L Thompson if (is_transpose) {
302d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyInterp_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
303bd882c8aSJames Wright } else {
304d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyInterp_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
305bd882c8aSJames Wright }
306d07cdbe5SJeremy L Thompson break;
307d07cdbe5SJeremy L Thompson case CEED_EVAL_GRAD:
3080ae60fd3SJeremy L Thompson if (is_transpose) {
309d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyGrad_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
310bd882c8aSJames Wright } else {
311d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyGrad_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v));
312bd882c8aSJames Wright }
313d07cdbe5SJeremy L Thompson break;
314d07cdbe5SJeremy L Thompson case CEED_EVAL_WEIGHT:
315097cc795SJames Wright CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[eval_mode]);
316bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
317d07cdbe5SJeremy L Thompson break;
3180ae60fd3SJeremy L Thompson case CEED_EVAL_NONE: /* handled separately below */
3190ae60fd3SJeremy L Thompson break;
320bd882c8aSJames Wright // LCOV_EXCL_START
321bd882c8aSJames Wright case CEED_EVAL_DIV:
322bd882c8aSJames Wright case CEED_EVAL_CURL:
3234e3038a5SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
324bd882c8aSJames Wright // LCOV_EXCL_STOP
325bd882c8aSJames Wright }
326bd882c8aSJames Wright
3270ae60fd3SJeremy L Thompson // Restore vectors, cover CEED_EVAL_NONE
328bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
3290ae60fd3SJeremy L Thompson if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
3300ae60fd3SJeremy L Thompson if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
3319bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
332bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
333bd882c8aSJames Wright }
334bd882c8aSJames Wright
335bd882c8aSJames Wright //------------------------------------------------------------------------------
336bd882c8aSJames Wright // Interpolation kernel - non-tensor
337bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyNonTensorInterp_Sycl(sycl::queue & sycl_queue,CeedInt num_elem,CeedInt is_transpose,const CeedBasisNonTensor_Sycl * impl,const CeedScalar * d_U,CeedScalar * d_V)3380ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorInterp_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl,
339bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) {
340bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp;
3410ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? impl->num_qpts : impl->num_nodes;
3420ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? impl->num_nodes : impl->num_qpts;
3430ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : impl->num_nodes;
3440ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? impl->num_nodes : 1;
345bd882c8aSJames Wright const CeedInt u_stride = P;
346bd882c8aSJames Wright const CeedInt v_stride = Q;
347bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem;
348bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem;
349bd882c8aSJames Wright const CeedInt u_size = P;
350bd882c8aSJames Wright const CeedInt v_size = Q;
351bd882c8aSJames Wright const CeedScalar *d_B = impl->d_interp;
352bd882c8aSJames Wright
353bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size);
354bd882c8aSJames Wright
3551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
3561f4b1b45SUmesh Unnikrishnan
3571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
3581f4b1b45SUmesh Unnikrishnan
3591f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclInterpNT>(kernel_range, e, [=](sycl::id<2> indx) {
360bd882c8aSJames Wright const CeedInt i = indx[1];
361bd882c8aSJames Wright const CeedInt elem = indx[0];
362bd882c8aSJames Wright
363bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) {
364bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride;
365bd882c8aSJames Wright CeedScalar V = 0.0;
366bd882c8aSJames Wright
367bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) {
368bd882c8aSJames Wright V += d_B[i * stride_0 + j * stride_1] * U[j];
369bd882c8aSJames Wright }
370bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride] = V;
371bd882c8aSJames Wright }
372bd882c8aSJames Wright });
373bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
374bd882c8aSJames Wright }
375bd882c8aSJames Wright
376bd882c8aSJames Wright //------------------------------------------------------------------------------
377bd882c8aSJames Wright // Gradient kernel - non-tensor
378bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyNonTensorGrad_Sycl(sycl::queue & sycl_queue,CeedInt num_elem,CeedInt is_transpose,const CeedBasisNonTensor_Sycl * impl,const CeedScalar * d_U,CeedScalar * d_V)3790ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorGrad_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl,
380bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) {
381bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp;
3820ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? impl->num_qpts : impl->num_nodes;
3830ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? impl->num_nodes : impl->num_qpts;
3840ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : impl->num_nodes;
3850ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? impl->num_nodes : 1;
386bd882c8aSJames Wright const CeedInt g_dim_stride = P * Q;
387bd882c8aSJames Wright const CeedInt u_stride = P;
388bd882c8aSJames Wright const CeedInt v_stride = Q;
389bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem;
390bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem;
391bd882c8aSJames Wright const CeedInt u_dim_stride = u_comp_stride * num_comp;
392bd882c8aSJames Wright const CeedInt v_dim_stride = v_comp_stride * num_comp;
393bd882c8aSJames Wright const CeedInt u_size = P;
394bd882c8aSJames Wright const CeedInt v_size = Q;
3950ae60fd3SJeremy L Thompson const CeedInt in_dim = is_transpose ? impl->dim : 1;
3960ae60fd3SJeremy L Thompson const CeedInt out_dim = is_transpose ? 1 : impl->dim;
397bd882c8aSJames Wright const CeedScalar *d_G = impl->d_grad;
398bd882c8aSJames Wright
399bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size);
400bd882c8aSJames Wright
4011f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
4021f4b1b45SUmesh Unnikrishnan
4031f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
4041f4b1b45SUmesh Unnikrishnan
4051f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclGradNT>(kernel_range, e, [=](sycl::id<2> indx) {
406bd882c8aSJames Wright const CeedInt i = indx[1];
407bd882c8aSJames Wright const CeedInt elem = indx[0];
408bd882c8aSJames Wright
409bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) {
410bd882c8aSJames Wright CeedScalar V[3] = {0.0, 0.0, 0.0};
411bd882c8aSJames Wright
412bd882c8aSJames Wright for (CeedInt d1 = 0; d1 < in_dim; ++d1) {
413bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride + d1 * u_dim_stride;
414bd882c8aSJames Wright const CeedScalar *G = d_G + i * stride_0 + d1 * g_dim_stride;
415bd882c8aSJames Wright
416bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) {
417bd882c8aSJames Wright const CeedScalar Uj = U[j];
418bd882c8aSJames Wright
419bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
420bd882c8aSJames Wright V[d0] += G[j * stride_1 + d0 * g_dim_stride] * Uj;
421bd882c8aSJames Wright }
422bd882c8aSJames Wright }
423bd882c8aSJames Wright }
424bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) {
425bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride + d0 * v_dim_stride] = V[d0];
426bd882c8aSJames Wright }
427bd882c8aSJames Wright }
428bd882c8aSJames Wright });
429bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
430bd882c8aSJames Wright }
431bd882c8aSJames Wright
432bd882c8aSJames Wright //------------------------------------------------------------------------------
433bd882c8aSJames Wright // Weight kernel - non-tensor
434bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyNonTensorWeight_Sycl(sycl::queue & sycl_queue,CeedInt num_elem,const CeedBasisNonTensor_Sycl * impl,CeedScalar * d_V)435bd882c8aSJames Wright static int CeedBasisApplyNonTensorWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasisNonTensor_Sycl *impl, CeedScalar *d_V) {
436bd882c8aSJames Wright const CeedInt num_qpts = impl->num_qpts;
437bd882c8aSJames Wright const CeedScalar *q_weight = impl->d_q_weight;
438bd882c8aSJames Wright
439bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, num_qpts);
440bd882c8aSJames Wright
4411f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
4421f4b1b45SUmesh Unnikrishnan
4431f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()};
4441f4b1b45SUmesh Unnikrishnan
4451f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclWeightNT>(kernel_range, e, [=](sycl::id<2> indx) {
446bd882c8aSJames Wright const CeedInt i = indx[1];
447bd882c8aSJames Wright const CeedInt elem = indx[0];
448bd882c8aSJames Wright d_V[i + elem * num_qpts] = q_weight[i];
449bd882c8aSJames Wright });
450bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
451bd882c8aSJames Wright }
452bd882c8aSJames Wright
453bd882c8aSJames Wright //------------------------------------------------------------------------------
454bd882c8aSJames Wright // Basis apply - non-tensor
455bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyNonTensor_Sycl(CeedBasis basis,const CeedInt num_elem,CeedTransposeMode t_mode,CeedEvalMode eval_mode,CeedVector u,CeedVector v)456bd882c8aSJames Wright static int CeedBasisApplyNonTensor_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
457bd882c8aSJames Wright CeedVector v) {
458bd882c8aSJames Wright Ceed ceed;
4590ae60fd3SJeremy L Thompson const CeedInt is_transpose = t_mode == CEED_TRANSPOSE;
460bd882c8aSJames Wright const CeedScalar *d_u;
461bd882c8aSJames Wright CeedScalar *d_v;
4620ae60fd3SJeremy L Thompson CeedBasisNonTensor_Sycl *impl;
4630ae60fd3SJeremy L Thompson Ceed_Sycl *data;
4640ae60fd3SJeremy L Thompson
4650ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
4660ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetData(basis, &impl));
4670ae60fd3SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data));
4680ae60fd3SJeremy L Thompson
4690ae60fd3SJeremy L Thompson // Get read/write access to u, v
4700ae60fd3SJeremy L Thompson if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
4710ae60fd3SJeremy L Thompson else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
472bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
473bd882c8aSJames Wright
474bd882c8aSJames Wright // Clear v for transpose operation
4750ae60fd3SJeremy L Thompson if (is_transpose) {
476bd882c8aSJames Wright CeedSize length;
477bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length));
478bd882c8aSJames Wright // Order queue
479bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
480bd882c8aSJames Wright data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e});
481bd882c8aSJames Wright }
482bd882c8aSJames Wright
483bd882c8aSJames Wright // Apply basis operation
484bd882c8aSJames Wright switch (eval_mode) {
485d07cdbe5SJeremy L Thompson case CEED_EVAL_INTERP:
4860ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisApplyNonTensorInterp_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v));
487d07cdbe5SJeremy L Thompson break;
488d07cdbe5SJeremy L Thompson case CEED_EVAL_GRAD:
4890ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisApplyNonTensorGrad_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v));
490d07cdbe5SJeremy L Thompson break;
491d07cdbe5SJeremy L Thompson case CEED_EVAL_WEIGHT:
492097cc795SJames Wright CeedCheck(impl->d_q_weight, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weights not set", CeedEvalModes[eval_mode]);
493bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyNonTensorWeight_Sycl(data->sycl_queue, num_elem, impl, d_v));
494d07cdbe5SJeremy L Thompson break;
4950ae60fd3SJeremy L Thompson case CEED_EVAL_NONE: /* handled separately below */
4960ae60fd3SJeremy L Thompson break;
497bd882c8aSJames Wright // LCOV_EXCL_START
498bd882c8aSJames Wright case CEED_EVAL_DIV:
499bd882c8aSJames Wright case CEED_EVAL_CURL:
5009d1bceceSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
501bd882c8aSJames Wright // LCOV_EXCL_STOP
502bd882c8aSJames Wright }
503bd882c8aSJames Wright
5040ae60fd3SJeremy L Thompson // Restore vectors, cover CEED_EVAL_NONE
505bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
5060ae60fd3SJeremy L Thompson if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
5070ae60fd3SJeremy L Thompson if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
5089bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
509bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
510bd882c8aSJames Wright }
511bd882c8aSJames Wright
512bd882c8aSJames Wright //------------------------------------------------------------------------------
513bd882c8aSJames Wright // Destroy tensor basis
514bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisDestroy_Sycl(CeedBasis basis)515bd882c8aSJames Wright static int CeedBasisDestroy_Sycl(CeedBasis basis) {
516bd882c8aSJames Wright Ceed ceed;
517bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
518bd882c8aSJames Wright CeedBasis_Sycl *impl;
519bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl));
520bd882c8aSJames Wright Ceed_Sycl *data;
521bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data));
522bd882c8aSJames Wright
523bd882c8aSJames Wright // Wait for all work to finish before freeing memory
524bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
525bd882c8aSJames Wright
526097cc795SJames Wright if (impl->d_q_weight_1d) CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
527bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
528bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
529bd882c8aSJames Wright
530bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl));
5319bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
532bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
533bd882c8aSJames Wright }
534bd882c8aSJames Wright
535bd882c8aSJames Wright //------------------------------------------------------------------------------
536bd882c8aSJames Wright // Destroy non-tensor basis
537bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisDestroyNonTensor_Sycl(CeedBasis basis)538bd882c8aSJames Wright static int CeedBasisDestroyNonTensor_Sycl(CeedBasis basis) {
539bd882c8aSJames Wright Ceed ceed;
540bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
541bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl;
542bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl));
543bd882c8aSJames Wright Ceed_Sycl *data;
544bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data));
545bd882c8aSJames Wright
546bd882c8aSJames Wright // Wait for all work to finish before freeing memory
547bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
548bd882c8aSJames Wright
549097cc795SJames Wright if (impl->d_q_weight) CeedCallSycl(ceed, sycl::free(impl->d_q_weight, data->sycl_context));
550bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp, data->sycl_context));
551bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad, data->sycl_context));
552bd882c8aSJames Wright
553bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl));
5549bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
555bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
556bd882c8aSJames Wright }
557bd882c8aSJames Wright
558bd882c8aSJames Wright //------------------------------------------------------------------------------
559bd882c8aSJames Wright // Create tensor
560bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisCreateTensorH1_Sycl(CeedInt dim,CeedInt P_1d,CeedInt Q_1d,const CeedScalar * interp_1d,const CeedScalar * grad_1d,const CeedScalar * q_ref_1d,const CeedScalar * q_weight_1d,CeedBasis basis)561bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
562bd882c8aSJames Wright const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
563bd882c8aSJames Wright Ceed ceed;
564bd882c8aSJames Wright CeedBasis_Sycl *impl;
565bd882c8aSJames Wright Ceed_Sycl *data;
5669bc66399SJeremy L Thompson
5679bc66399SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
5689bc66399SJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl));
569bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data));
570bd882c8aSJames Wright
571bd882c8aSJames Wright CeedInt num_comp;
572bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
573bd882c8aSJames Wright
574bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim);
575bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim);
576bd882c8aSJames Wright
577bd882c8aSJames Wright impl->dim = dim;
578bd882c8aSJames Wright impl->P_1d = P_1d;
579bd882c8aSJames Wright impl->Q_1d = Q_1d;
580bd882c8aSJames Wright impl->num_comp = num_comp;
581bd882c8aSJames Wright impl->num_nodes = num_nodes;
582bd882c8aSJames Wright impl->num_qpts = num_qpts;
583bd882c8aSJames Wright impl->buf_len = num_comp * CeedIntMax(num_nodes, num_qpts);
584bd882c8aSJames Wright impl->op_len = Q_1d * P_1d;
585bd882c8aSJames Wright
5861f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
5871f4b1b45SUmesh Unnikrishnan
5881f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};
589bd882c8aSJames Wright
590097cc795SJames Wright std::vector<sycl::event> copy_events;
591097cc795SJames Wright if (q_weight_1d) {
592bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
5931f4b1b45SUmesh Unnikrishnan sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, e);
594097cc795SJames Wright copy_events.push_back(copy_weight);
595097cc795SJames Wright }
596bd882c8aSJames Wright
597bd882c8aSJames Wright const CeedInt interp_length = Q_1d * P_1d;
598bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
5991f4b1b45SUmesh Unnikrishnan sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, e);
600097cc795SJames Wright copy_events.push_back(copy_interp);
601bd882c8aSJames Wright
602bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
6031f4b1b45SUmesh Unnikrishnan sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, e);
604097cc795SJames Wright copy_events.push_back(copy_grad);
605bd882c8aSJames Wright
606097cc795SJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
607bd882c8aSJames Wright
608bd882c8aSJames Wright std::vector<sycl::kernel_id> kernel_ids = {sycl::get_kernel_id<CeedBasisSyclInterp<1>>(), sycl::get_kernel_id<CeedBasisSyclInterp<0>>(),
609bd882c8aSJames Wright sycl::get_kernel_id<CeedBasisSyclGrad<1>>(), sycl::get_kernel_id<CeedBasisSyclGrad<0>>()};
610bd882c8aSJames Wright
611bd882c8aSJames Wright sycl::kernel_bundle<sycl::bundle_state::input> input_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(data->sycl_context, kernel_ids);
612bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_DIM_ID>(dim);
613bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_NUM_COMP_ID>(num_comp);
614bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_Q_1D_ID>(Q_1d);
615bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_P_1D_ID>(P_1d);
616bd882c8aSJames Wright
617bd882c8aSJames Wright CeedCallSycl(ceed, impl->sycl_module = new SyclModule_t(sycl::build(input_bundle)));
618bd882c8aSJames Wright
619bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl));
620bd882c8aSJames Wright
621bd882c8aSJames Wright // Register backend functions
622bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApply_Sycl));
623bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl));
6249bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
625bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
626bd882c8aSJames Wright }
627bd882c8aSJames Wright
628bd882c8aSJames Wright //------------------------------------------------------------------------------
629bd882c8aSJames Wright // Create non-tensor
630bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisCreateH1_Sycl(CeedElemTopology topo,CeedInt dim,CeedInt num_nodes,CeedInt num_qpts,const CeedScalar * interp,const CeedScalar * grad,const CeedScalar * q_ref,const CeedScalar * q_weight,CeedBasis basis)631bd882c8aSJames Wright int CeedBasisCreateH1_Sycl(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
632dd64fc84SJeremy L Thompson const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
633bd882c8aSJames Wright Ceed ceed;
634bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl;
635bd882c8aSJames Wright Ceed_Sycl *data;
6369bc66399SJeremy L Thompson
6379bc66399SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
6389bc66399SJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl));
639bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data));
640bd882c8aSJames Wright
641bd882c8aSJames Wright CeedInt num_comp;
642bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
643bd882c8aSJames Wright
644bd882c8aSJames Wright impl->dim = dim;
645bd882c8aSJames Wright impl->num_comp = num_comp;
646bd882c8aSJames Wright impl->num_nodes = num_nodes;
647bd882c8aSJames Wright impl->num_qpts = num_qpts;
648bd882c8aSJames Wright
6491f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
6501f4b1b45SUmesh Unnikrishnan
6511f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};
652bd882c8aSJames Wright
653097cc795SJames Wright std::vector<sycl::event> copy_events;
654097cc795SJames Wright if (q_weight) {
655bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight = sycl::malloc_device<CeedScalar>(num_qpts, data->sycl_device, data->sycl_context));
6561f4b1b45SUmesh Unnikrishnan sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight, impl->d_q_weight, num_qpts, e);
657097cc795SJames Wright copy_events.push_back(copy_weight);
658097cc795SJames Wright }
659bd882c8aSJames Wright
660bd882c8aSJames Wright const CeedInt interp_length = num_qpts * num_nodes;
661bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
6621f4b1b45SUmesh Unnikrishnan sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp, impl->d_interp, interp_length, e);
663097cc795SJames Wright copy_events.push_back(copy_interp);
664bd882c8aSJames Wright
665bd882c8aSJames Wright const CeedInt grad_length = num_qpts * num_nodes * dim;
666bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad = sycl::malloc_device<CeedScalar>(grad_length, data->sycl_device, data->sycl_context));
6671f4b1b45SUmesh Unnikrishnan sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad, impl->d_grad, grad_length, e);
668097cc795SJames Wright copy_events.push_back(copy_grad);
669bd882c8aSJames Wright
670097cc795SJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
671bd882c8aSJames Wright
672bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl));
673bd882c8aSJames Wright
674bd882c8aSJames Wright // Register backend functions
675bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Sycl));
676bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Sycl));
6779bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
678bd882c8aSJames Wright return CEED_ERROR_SUCCESS;
679bd882c8aSJames Wright }
680ff1e7120SSebastian Grimberg
681bd882c8aSJames Wright //------------------------------------------------------------------------------
682