15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright #include <ceed/backend.h> 9bd882c8aSJames Wright #include <ceed/ceed.h> 10bd882c8aSJames Wright #include <ceed/jit-tools.h> 11bd882c8aSJames Wright 12bd882c8aSJames Wright #include <sycl/sycl.hpp> 13bd882c8aSJames Wright #include <vector> 14bd882c8aSJames Wright 15bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 16bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 17bd882c8aSJames Wright 18bd882c8aSJames Wright template <int> 19bd882c8aSJames Wright class CeedBasisSyclInterp; 20bd882c8aSJames Wright template <int> 21bd882c8aSJames Wright class CeedBasisSyclGrad; 22bd882c8aSJames Wright class CeedBasisSyclWeight; 23bd882c8aSJames Wright 24bd882c8aSJames Wright class CeedBasisSyclInterpNT; 25bd882c8aSJames Wright class CeedBasisSyclGradNT; 26bd882c8aSJames Wright class CeedBasisSyclWeightNT; 27bd882c8aSJames Wright 28bd882c8aSJames Wright using SpecID = sycl::specialization_id<CeedInt>; 29bd882c8aSJames Wright 30bd882c8aSJames Wright static constexpr SpecID BASIS_DIM_ID; 31bd882c8aSJames Wright static constexpr SpecID BASIS_NUM_COMP_ID; 32bd882c8aSJames Wright static constexpr SpecID BASIS_P_1D_ID; 33bd882c8aSJames Wright static constexpr SpecID BASIS_Q_1D_ID; 34bd882c8aSJames Wright 35bd882c8aSJames Wright //------------------------------------------------------------------------------ 36bd882c8aSJames Wright // Interpolation kernel - tensor 37bd882c8aSJames Wright //------------------------------------------------------------------------------ 380ae60fd3SJeremy L Thompson template <int is_transpose> 39bd882c8aSJames Wright static int CeedBasisApplyInterp_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl, 40bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) { 41bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len; 42bd882c8aSJames Wright const CeedInt op_len = impl->op_len; 43bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d; 44bd882c8aSJames Wright 45bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device(); 46bd882c8aSJames Wright const CeedInt max_work_group_size = 32; 47bd882c8aSJames Wright const CeedInt work_group_size = CeedIntMin(impl->num_qpts, max_work_group_size); 48bd882c8aSJames Wright sycl::range<1> local_range(work_group_size); 49bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size); 50bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range); 51bd882c8aSJames Wright 521f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 531f4b1b45SUmesh Unnikrishnan 541f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 551f4b1b45SUmesh Unnikrishnan 56bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) { 571f4b1b45SUmesh Unnikrishnan cgh.depends_on(e); 58bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module); 59bd882c8aSJames Wright 60bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(op_len + 2 * buf_len, cgh); 61bd882c8aSJames Wright 620ae60fd3SJeremy L Thompson cgh.parallel_for<CeedBasisSyclInterp<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) { 63bd882c8aSJames Wright //--------------------------------------------------------------> 64bd882c8aSJames Wright // Retrieve spec constant values 65bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>(); 66bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>(); 67bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>(); 68bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>(); 69bd882c8aSJames Wright //--------------------------------------------------------------> 70bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 71bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 720ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? Q_1d : P_1d; 730ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? P_1d : Q_1d; 740ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : P_1d; 750ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? P_1d : 1; 760ae60fd3SJeremy L Thompson const CeedInt u_stride = is_transpose ? num_qpts : num_nodes; 770ae60fd3SJeremy L Thompson const CeedInt v_stride = is_transpose ? num_nodes : num_qpts; 78bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride; 79bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride; 80bd882c8aSJames Wright const CeedInt u_size = u_stride; 81bd882c8aSJames Wright 82bd882c8aSJames Wright sycl::group work_group = work_item.get_group(); 83bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id(); 84bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range(); 85bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id(); 86bd882c8aSJames Wright 8733bb61d4SKris Rowe CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get(); 88bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_interp_1d + Q * P; 89bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len; 90bd882c8aSJames Wright 91bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) { 92bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k]; 93bd882c8aSJames Wright } 94bd882c8aSJames Wright 95bd882c8aSJames Wright // Apply basis element by element 96bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 97bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride; 98bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride; 99bd882c8aSJames Wright 100bd882c8aSJames Wright for (CeedInt k = i; k < u_size; k += group_size) { 101bd882c8aSJames Wright s_buffer_1[k] = cur_u[k]; 102bd882c8aSJames Wright } 103bd882c8aSJames Wright 104bd882c8aSJames Wright CeedInt pre = u_size; 105bd882c8aSJames Wright CeedInt post = 1; 106bd882c8aSJames Wright 107bd882c8aSJames Wright for (CeedInt d = 0; d < dim; d++) { 108bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons 109bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed 110bd882c8aSJames Wright // sycl::group_barrier(work_group); 111bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space); 112bd882c8aSJames Wright 113bd882c8aSJames Wright pre /= P; 114bd882c8aSJames Wright const CeedScalar *in = d % 2 ? s_buffer_2 : s_buffer_1; 115bd882c8aSJames Wright CeedScalar *out = d == dim - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2); 116bd882c8aSJames Wright 117bd882c8aSJames Wright // Contract along middle index 118bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q; 119bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) { 120bd882c8aSJames Wright const CeedInt c = k % post; 121bd882c8aSJames Wright const CeedInt j = (k / post) % Q; 122bd882c8aSJames Wright const CeedInt a = k / (post * Q); 123bd882c8aSJames Wright 124bd882c8aSJames Wright CeedScalar vk = 0; 125bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) { 126bd882c8aSJames Wright vk += s_interp_1d[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c]; 127bd882c8aSJames Wright } 128bd882c8aSJames Wright out[k] = vk; 129bd882c8aSJames Wright } 130bd882c8aSJames Wright post *= Q; 131bd882c8aSJames Wright } 132bd882c8aSJames Wright } 133bd882c8aSJames Wright }); 134bd882c8aSJames Wright }); 135bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 136bd882c8aSJames Wright } 137bd882c8aSJames Wright 138bd882c8aSJames Wright //------------------------------------------------------------------------------ 139bd882c8aSJames Wright // Gradient kernel - tensor 140bd882c8aSJames Wright //------------------------------------------------------------------------------ 1410ae60fd3SJeremy L Thompson template <int is_transpose> 142bd882c8aSJames Wright static int CeedBasisApplyGrad_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl, 143bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) { 144bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len; 145bd882c8aSJames Wright const CeedInt op_len = impl->op_len; 146bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d; 147bd882c8aSJames Wright const CeedScalar *grad_1d = impl->d_grad_1d; 148bd882c8aSJames Wright 149bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device(); 150bd882c8aSJames Wright const CeedInt work_group_size = 32; 151bd882c8aSJames Wright sycl::range<1> local_range(work_group_size); 152bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size); 153bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range); 154bd882c8aSJames Wright 1551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 1561f4b1b45SUmesh Unnikrishnan 1571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 1581f4b1b45SUmesh Unnikrishnan 159bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) { 1601f4b1b45SUmesh Unnikrishnan cgh.depends_on(e); 161bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module); 162bd882c8aSJames Wright 163bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(2 * (op_len + buf_len), cgh); 164bd882c8aSJames Wright 1650ae60fd3SJeremy L Thompson cgh.parallel_for<CeedBasisSyclGrad<is_transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) { 166bd882c8aSJames Wright //--------------------------------------------------------------> 167bd882c8aSJames Wright // Retrieve spec constant values 168bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>(); 169bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>(); 170bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>(); 171bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>(); 172bd882c8aSJames Wright //--------------------------------------------------------------> 173bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 174bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 1750ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? Q_1d : P_1d; 1760ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? P_1d : Q_1d; 1770ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : P_1d; 1780ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? P_1d : 1; 1790ae60fd3SJeremy L Thompson const CeedInt u_stride = is_transpose ? num_qpts : num_nodes; 1800ae60fd3SJeremy L Thompson const CeedInt v_stride = is_transpose ? num_nodes : num_qpts; 181bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride; 182bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride; 1830ae60fd3SJeremy L Thompson const CeedInt u_dim_stride = is_transpose ? num_elem * num_qpts * num_comp : 0; 1840ae60fd3SJeremy L Thompson const CeedInt v_dim_stride = is_transpose ? 0 : num_elem * num_qpts * num_comp; 185bd882c8aSJames Wright sycl::group work_group = work_item.get_group(); 186bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id(); 187bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range(); 188bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id(); 189bd882c8aSJames Wright 19033bb61d4SKris Rowe CeedScalar *s_interp_1d = s_mem.get_multi_ptr<sycl::access::decorated::yes>().get(); 191bd882c8aSJames Wright CeedScalar *s_grad_1d = s_interp_1d + P * Q; 192bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_grad_1d + P * Q; 193bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len; 194bd882c8aSJames Wright 195bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) { 196bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k]; 197bd882c8aSJames Wright s_grad_1d[k] = grad_1d[k]; 198bd882c8aSJames Wright } 199bd882c8aSJames Wright 200bd882c8aSJames Wright // Apply basis element by element 201bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 202bd882c8aSJames Wright for (CeedInt dim_1 = 0; dim_1 < dim; dim_1++) { 2030ae60fd3SJeremy L Thompson CeedInt pre = is_transpose ? num_qpts : num_nodes; 204bd882c8aSJames Wright CeedInt post = 1; 205bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride; 206bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride; 207bd882c8aSJames Wright 208bd882c8aSJames Wright for (CeedInt dim_2 = 0; dim_2 < dim; dim_2++) { 209bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons 210bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed 211bd882c8aSJames Wright // sycl::group_barrier(work_group); 212bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space); 213bd882c8aSJames Wright 214bd882c8aSJames Wright pre /= P; 215bd882c8aSJames Wright const CeedScalar *op = dim_1 == dim_2 ? s_grad_1d : s_interp_1d; 216bd882c8aSJames Wright const CeedScalar *in = (dim_2 == 0 ? cur_u : (dim_2 % 2 ? s_buffer_2 : s_buffer_1)); 217bd882c8aSJames Wright CeedScalar *out = dim_2 == dim - 1 ? cur_v : (dim_2 % 2 ? s_buffer_1 : s_buffer_2); 218bd882c8aSJames Wright 219bd882c8aSJames Wright // Contract along middle index 220bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q; 221bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) { 222bd882c8aSJames Wright const CeedInt c = k % post; 223bd882c8aSJames Wright const CeedInt j = (k / post) % Q; 224bd882c8aSJames Wright const CeedInt a = k / (post * Q); 225bd882c8aSJames Wright 226bd882c8aSJames Wright CeedScalar v_k = 0; 227bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) v_k += op[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c]; 228bd882c8aSJames Wright 2290ae60fd3SJeremy L Thompson if (is_transpose && dim_2 == dim - 1) out[k] += v_k; 230bd882c8aSJames Wright else out[k] = v_k; 231bd882c8aSJames Wright } 232bd882c8aSJames Wright 233bd882c8aSJames Wright post *= Q; 234bd882c8aSJames Wright } 235bd882c8aSJames Wright } 236bd882c8aSJames Wright } 237bd882c8aSJames Wright }); 238bd882c8aSJames Wright }); 239bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 240bd882c8aSJames Wright } 241bd882c8aSJames Wright 242bd882c8aSJames Wright //------------------------------------------------------------------------------ 243bd882c8aSJames Wright // Weight kernel - tensor 244bd882c8aSJames Wright //------------------------------------------------------------------------------ 245bd882c8aSJames Wright static int CeedBasisApplyWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasis_Sycl *impl, CeedScalar *w) { 246bd882c8aSJames Wright const CeedInt dim = impl->dim; 247bd882c8aSJames Wright const CeedInt Q_1d = impl->Q_1d; 248bd882c8aSJames Wright const CeedScalar *q_weight_1d = impl->d_q_weight_1d; 249bd882c8aSJames Wright 250bd882c8aSJames Wright const CeedInt num_quad_x = Q_1d; 251bd882c8aSJames Wright const CeedInt num_quad_y = (dim > 1) ? Q_1d : 1; 252bd882c8aSJames Wright const CeedInt num_quad_z = (dim > 2) ? Q_1d : 1; 253bd882c8aSJames Wright sycl::range<3> kernel_range(num_elem * num_quad_z, num_quad_y, num_quad_x); 254bd882c8aSJames Wright 2551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 2561f4b1b45SUmesh Unnikrishnan 2571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 2581f4b1b45SUmesh Unnikrishnan 2591f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclWeight>(kernel_range, e, [=](sycl::item<3> work_item) { 260bd882c8aSJames Wright if (dim == 1) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]]; 261bd882c8aSJames Wright if (dim == 2) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]]; 262bd882c8aSJames Wright if (dim == 3) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]] * q_weight_1d[work_item[0] % Q_1d]; 263bd882c8aSJames Wright }); 264bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 265bd882c8aSJames Wright } 266bd882c8aSJames Wright 267bd882c8aSJames Wright //------------------------------------------------------------------------------ 268bd882c8aSJames Wright // Basis apply - tensor 269bd882c8aSJames Wright //------------------------------------------------------------------------------ 270bd882c8aSJames Wright static int CeedBasisApply_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 271bd882c8aSJames Wright CeedVector v) { 272bd882c8aSJames Wright Ceed ceed; 2730ae60fd3SJeremy L Thompson const CeedInt is_transpose = t_mode == CEED_TRANSPOSE; 274bd882c8aSJames Wright const CeedScalar *d_u; 275bd882c8aSJames Wright CeedScalar *d_v; 2760ae60fd3SJeremy L Thompson Ceed_Sycl *data; 2770ae60fd3SJeremy L Thompson CeedBasis_Sycl *impl; 2780ae60fd3SJeremy L Thompson 2790ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 2800ae60fd3SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 2810ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetData(basis, &impl)); 2820ae60fd3SJeremy L Thompson 2830ae60fd3SJeremy L Thompson // Get read/write access to u, v 2840ae60fd3SJeremy L Thompson if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 2850ae60fd3SJeremy L Thompson else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); 286bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 287bd882c8aSJames Wright 288bd882c8aSJames Wright // Clear v for transpose operation 2890ae60fd3SJeremy L Thompson if (is_transpose) { 290bd882c8aSJames Wright CeedSize length; 291bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length)); 2921f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 2931f4b1b45SUmesh Unnikrishnan 2941f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 2951f4b1b45SUmesh Unnikrishnan data->sycl_queue.fill<CeedScalar>(d_v, 0, length, e); 296bd882c8aSJames Wright } 297bd882c8aSJames Wright 298bd882c8aSJames Wright // Basis action 299bd882c8aSJames Wright switch (eval_mode) { 300d07cdbe5SJeremy L Thompson case CEED_EVAL_INTERP: 3010ae60fd3SJeremy L Thompson if (is_transpose) { 302d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyInterp_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 303bd882c8aSJames Wright } else { 304d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyInterp_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 305bd882c8aSJames Wright } 306d07cdbe5SJeremy L Thompson break; 307d07cdbe5SJeremy L Thompson case CEED_EVAL_GRAD: 3080ae60fd3SJeremy L Thompson if (is_transpose) { 309d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyGrad_Sycl<true>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 310bd882c8aSJames Wright } else { 311d07cdbe5SJeremy L Thompson CeedCallBackend(CeedBasisApplyGrad_Sycl<false>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 312bd882c8aSJames Wright } 313d07cdbe5SJeremy L Thompson break; 314d07cdbe5SJeremy L Thompson case CEED_EVAL_WEIGHT: 315097cc795SJames Wright CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[eval_mode]); 316bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyWeight_Sycl(data->sycl_queue, num_elem, impl, d_v)); 317d07cdbe5SJeremy L Thompson break; 3180ae60fd3SJeremy L Thompson case CEED_EVAL_NONE: /* handled separately below */ 3190ae60fd3SJeremy L Thompson break; 320bd882c8aSJames Wright // LCOV_EXCL_START 321bd882c8aSJames Wright case CEED_EVAL_DIV: 322bd882c8aSJames Wright case CEED_EVAL_CURL: 3234e3038a5SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]); 324bd882c8aSJames Wright // LCOV_EXCL_STOP 325bd882c8aSJames Wright } 326bd882c8aSJames Wright 3270ae60fd3SJeremy L Thompson // Restore vectors, cover CEED_EVAL_NONE 328bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 3290ae60fd3SJeremy L Thompson if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u)); 3300ae60fd3SJeremy L Thompson if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 331*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 332bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 333bd882c8aSJames Wright } 334bd882c8aSJames Wright 335bd882c8aSJames Wright //------------------------------------------------------------------------------ 336bd882c8aSJames Wright // Interpolation kernel - non-tensor 337bd882c8aSJames Wright //------------------------------------------------------------------------------ 3380ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorInterp_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl, 339bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) { 340bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 3410ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? impl->num_qpts : impl->num_nodes; 3420ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? impl->num_nodes : impl->num_qpts; 3430ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : impl->num_nodes; 3440ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? impl->num_nodes : 1; 345bd882c8aSJames Wright const CeedInt u_stride = P; 346bd882c8aSJames Wright const CeedInt v_stride = Q; 347bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem; 348bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem; 349bd882c8aSJames Wright const CeedInt u_size = P; 350bd882c8aSJames Wright const CeedInt v_size = Q; 351bd882c8aSJames Wright const CeedScalar *d_B = impl->d_interp; 352bd882c8aSJames Wright 353bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size); 354bd882c8aSJames Wright 3551f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 3561f4b1b45SUmesh Unnikrishnan 3571f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 3581f4b1b45SUmesh Unnikrishnan 3591f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclInterpNT>(kernel_range, e, [=](sycl::id<2> indx) { 360bd882c8aSJames Wright const CeedInt i = indx[1]; 361bd882c8aSJames Wright const CeedInt elem = indx[0]; 362bd882c8aSJames Wright 363bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 364bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride; 365bd882c8aSJames Wright CeedScalar V = 0.0; 366bd882c8aSJames Wright 367bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) { 368bd882c8aSJames Wright V += d_B[i * stride_0 + j * stride_1] * U[j]; 369bd882c8aSJames Wright } 370bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride] = V; 371bd882c8aSJames Wright } 372bd882c8aSJames Wright }); 373bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 374bd882c8aSJames Wright } 375bd882c8aSJames Wright 376bd882c8aSJames Wright //------------------------------------------------------------------------------ 377bd882c8aSJames Wright // Gradient kernel - non-tensor 378bd882c8aSJames Wright //------------------------------------------------------------------------------ 3790ae60fd3SJeremy L Thompson static int CeedBasisApplyNonTensorGrad_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt is_transpose, const CeedBasisNonTensor_Sycl *impl, 380bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) { 381bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 3820ae60fd3SJeremy L Thompson const CeedInt P = is_transpose ? impl->num_qpts : impl->num_nodes; 3830ae60fd3SJeremy L Thompson const CeedInt Q = is_transpose ? impl->num_nodes : impl->num_qpts; 3840ae60fd3SJeremy L Thompson const CeedInt stride_0 = is_transpose ? 1 : impl->num_nodes; 3850ae60fd3SJeremy L Thompson const CeedInt stride_1 = is_transpose ? impl->num_nodes : 1; 386bd882c8aSJames Wright const CeedInt g_dim_stride = P * Q; 387bd882c8aSJames Wright const CeedInt u_stride = P; 388bd882c8aSJames Wright const CeedInt v_stride = Q; 389bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem; 390bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem; 391bd882c8aSJames Wright const CeedInt u_dim_stride = u_comp_stride * num_comp; 392bd882c8aSJames Wright const CeedInt v_dim_stride = v_comp_stride * num_comp; 393bd882c8aSJames Wright const CeedInt u_size = P; 394bd882c8aSJames Wright const CeedInt v_size = Q; 3950ae60fd3SJeremy L Thompson const CeedInt in_dim = is_transpose ? impl->dim : 1; 3960ae60fd3SJeremy L Thompson const CeedInt out_dim = is_transpose ? 1 : impl->dim; 397bd882c8aSJames Wright const CeedScalar *d_G = impl->d_grad; 398bd882c8aSJames Wright 399bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size); 400bd882c8aSJames Wright 4011f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 4021f4b1b45SUmesh Unnikrishnan 4031f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 4041f4b1b45SUmesh Unnikrishnan 4051f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclGradNT>(kernel_range, e, [=](sycl::id<2> indx) { 406bd882c8aSJames Wright const CeedInt i = indx[1]; 407bd882c8aSJames Wright const CeedInt elem = indx[0]; 408bd882c8aSJames Wright 409bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 410bd882c8aSJames Wright CeedScalar V[3] = {0.0, 0.0, 0.0}; 411bd882c8aSJames Wright 412bd882c8aSJames Wright for (CeedInt d1 = 0; d1 < in_dim; ++d1) { 413bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride + d1 * u_dim_stride; 414bd882c8aSJames Wright const CeedScalar *G = d_G + i * stride_0 + d1 * g_dim_stride; 415bd882c8aSJames Wright 416bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) { 417bd882c8aSJames Wright const CeedScalar Uj = U[j]; 418bd882c8aSJames Wright 419bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) { 420bd882c8aSJames Wright V[d0] += G[j * stride_1 + d0 * g_dim_stride] * Uj; 421bd882c8aSJames Wright } 422bd882c8aSJames Wright } 423bd882c8aSJames Wright } 424bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) { 425bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride + d0 * v_dim_stride] = V[d0]; 426bd882c8aSJames Wright } 427bd882c8aSJames Wright } 428bd882c8aSJames Wright }); 429bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 430bd882c8aSJames Wright } 431bd882c8aSJames Wright 432bd882c8aSJames Wright //------------------------------------------------------------------------------ 433bd882c8aSJames Wright // Weight kernel - non-tensor 434bd882c8aSJames Wright //------------------------------------------------------------------------------ 435bd882c8aSJames Wright static int CeedBasisApplyNonTensorWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasisNonTensor_Sycl *impl, CeedScalar *d_V) { 436bd882c8aSJames Wright const CeedInt num_qpts = impl->num_qpts; 437bd882c8aSJames Wright const CeedScalar *q_weight = impl->d_q_weight; 438bd882c8aSJames Wright 439bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, num_qpts); 440bd882c8aSJames Wright 4411f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 4421f4b1b45SUmesh Unnikrishnan 4431f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 4441f4b1b45SUmesh Unnikrishnan 4451f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedBasisSyclWeightNT>(kernel_range, e, [=](sycl::id<2> indx) { 446bd882c8aSJames Wright const CeedInt i = indx[1]; 447bd882c8aSJames Wright const CeedInt elem = indx[0]; 448bd882c8aSJames Wright d_V[i + elem * num_qpts] = q_weight[i]; 449bd882c8aSJames Wright }); 450bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 451bd882c8aSJames Wright } 452bd882c8aSJames Wright 453bd882c8aSJames Wright //------------------------------------------------------------------------------ 454bd882c8aSJames Wright // Basis apply - non-tensor 455bd882c8aSJames Wright //------------------------------------------------------------------------------ 456bd882c8aSJames Wright static int CeedBasisApplyNonTensor_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 457bd882c8aSJames Wright CeedVector v) { 458bd882c8aSJames Wright Ceed ceed; 4590ae60fd3SJeremy L Thompson const CeedInt is_transpose = t_mode == CEED_TRANSPOSE; 460bd882c8aSJames Wright const CeedScalar *d_u; 461bd882c8aSJames Wright CeedScalar *d_v; 4620ae60fd3SJeremy L Thompson CeedBasisNonTensor_Sycl *impl; 4630ae60fd3SJeremy L Thompson Ceed_Sycl *data; 4640ae60fd3SJeremy L Thompson 4650ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 4660ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisGetData(basis, &impl)); 4670ae60fd3SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 4680ae60fd3SJeremy L Thompson 4690ae60fd3SJeremy L Thompson // Get read/write access to u, v 4700ae60fd3SJeremy L Thompson if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 4710ae60fd3SJeremy L Thompson else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); 472bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 473bd882c8aSJames Wright 474bd882c8aSJames Wright // Clear v for transpose operation 4750ae60fd3SJeremy L Thompson if (is_transpose) { 476bd882c8aSJames Wright CeedSize length; 477bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length)); 478bd882c8aSJames Wright // Order queue 479bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 480bd882c8aSJames Wright data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e}); 481bd882c8aSJames Wright } 482bd882c8aSJames Wright 483bd882c8aSJames Wright // Apply basis operation 484bd882c8aSJames Wright switch (eval_mode) { 485d07cdbe5SJeremy L Thompson case CEED_EVAL_INTERP: 4860ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisApplyNonTensorInterp_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v)); 487d07cdbe5SJeremy L Thompson break; 488d07cdbe5SJeremy L Thompson case CEED_EVAL_GRAD: 4890ae60fd3SJeremy L Thompson CeedCallBackend(CeedBasisApplyNonTensorGrad_Sycl(data->sycl_queue, num_elem, is_transpose, impl, d_u, d_v)); 490d07cdbe5SJeremy L Thompson break; 491d07cdbe5SJeremy L Thompson case CEED_EVAL_WEIGHT: 492097cc795SJames Wright CeedCheck(impl->d_q_weight, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weights not set", CeedEvalModes[eval_mode]); 493bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyNonTensorWeight_Sycl(data->sycl_queue, num_elem, impl, d_v)); 494d07cdbe5SJeremy L Thompson break; 4950ae60fd3SJeremy L Thompson case CEED_EVAL_NONE: /* handled separately below */ 4960ae60fd3SJeremy L Thompson break; 497bd882c8aSJames Wright // LCOV_EXCL_START 498bd882c8aSJames Wright case CEED_EVAL_DIV: 499bd882c8aSJames Wright case CEED_EVAL_CURL: 5009d1bceceSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]); 501bd882c8aSJames Wright // LCOV_EXCL_STOP 502bd882c8aSJames Wright } 503bd882c8aSJames Wright 5040ae60fd3SJeremy L Thompson // Restore vectors, cover CEED_EVAL_NONE 505bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 5060ae60fd3SJeremy L Thompson if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u)); 5070ae60fd3SJeremy L Thompson if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 508*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 509bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 510bd882c8aSJames Wright } 511bd882c8aSJames Wright 512bd882c8aSJames Wright //------------------------------------------------------------------------------ 513bd882c8aSJames Wright // Destroy tensor basis 514bd882c8aSJames Wright //------------------------------------------------------------------------------ 515bd882c8aSJames Wright static int CeedBasisDestroy_Sycl(CeedBasis basis) { 516bd882c8aSJames Wright Ceed ceed; 517bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 518bd882c8aSJames Wright CeedBasis_Sycl *impl; 519bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 520bd882c8aSJames Wright Ceed_Sycl *data; 521bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 522bd882c8aSJames Wright 523bd882c8aSJames Wright // Wait for all work to finish before freeing memory 524bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 525bd882c8aSJames Wright 526097cc795SJames Wright if (impl->d_q_weight_1d) CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context)); 527bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context)); 528bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context)); 529bd882c8aSJames Wright 530bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 531*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 532bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 533bd882c8aSJames Wright } 534bd882c8aSJames Wright 535bd882c8aSJames Wright //------------------------------------------------------------------------------ 536bd882c8aSJames Wright // Destroy non-tensor basis 537bd882c8aSJames Wright //------------------------------------------------------------------------------ 538bd882c8aSJames Wright static int CeedBasisDestroyNonTensor_Sycl(CeedBasis basis) { 539bd882c8aSJames Wright Ceed ceed; 540bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 541bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl; 542bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 543bd882c8aSJames Wright Ceed_Sycl *data; 544bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 545bd882c8aSJames Wright 546bd882c8aSJames Wright // Wait for all work to finish before freeing memory 547bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 548bd882c8aSJames Wright 549097cc795SJames Wright if (impl->d_q_weight) CeedCallSycl(ceed, sycl::free(impl->d_q_weight, data->sycl_context)); 550bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp, data->sycl_context)); 551bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad, data->sycl_context)); 552bd882c8aSJames Wright 553bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 554*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 555bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 556bd882c8aSJames Wright } 557bd882c8aSJames Wright 558bd882c8aSJames Wright //------------------------------------------------------------------------------ 559bd882c8aSJames Wright // Create tensor 560bd882c8aSJames Wright //------------------------------------------------------------------------------ 561bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d, 562bd882c8aSJames Wright const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) { 563bd882c8aSJames Wright Ceed ceed; 564bd882c8aSJames Wright CeedBasis_Sycl *impl; 565bd882c8aSJames Wright Ceed_Sycl *data; 566*9bc66399SJeremy L Thompson 567*9bc66399SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 568*9bc66399SJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 569bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 570bd882c8aSJames Wright 571bd882c8aSJames Wright CeedInt num_comp; 572bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 573bd882c8aSJames Wright 574bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 575bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 576bd882c8aSJames Wright 577bd882c8aSJames Wright impl->dim = dim; 578bd882c8aSJames Wright impl->P_1d = P_1d; 579bd882c8aSJames Wright impl->Q_1d = Q_1d; 580bd882c8aSJames Wright impl->num_comp = num_comp; 581bd882c8aSJames Wright impl->num_nodes = num_nodes; 582bd882c8aSJames Wright impl->num_qpts = num_qpts; 583bd882c8aSJames Wright impl->buf_len = num_comp * CeedIntMax(num_nodes, num_qpts); 584bd882c8aSJames Wright impl->op_len = Q_1d * P_1d; 585bd882c8aSJames Wright 5861f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 5871f4b1b45SUmesh Unnikrishnan 5881f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 589bd882c8aSJames Wright 590097cc795SJames Wright std::vector<sycl::event> copy_events; 591097cc795SJames Wright if (q_weight_1d) { 592bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context)); 5931f4b1b45SUmesh Unnikrishnan sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, e); 594097cc795SJames Wright copy_events.push_back(copy_weight); 595097cc795SJames Wright } 596bd882c8aSJames Wright 597bd882c8aSJames Wright const CeedInt interp_length = Q_1d * P_1d; 598bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 5991f4b1b45SUmesh Unnikrishnan sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, e); 600097cc795SJames Wright copy_events.push_back(copy_interp); 601bd882c8aSJames Wright 602bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 6031f4b1b45SUmesh Unnikrishnan sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, e); 604097cc795SJames Wright copy_events.push_back(copy_grad); 605bd882c8aSJames Wright 606097cc795SJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events)); 607bd882c8aSJames Wright 608bd882c8aSJames Wright std::vector<sycl::kernel_id> kernel_ids = {sycl::get_kernel_id<CeedBasisSyclInterp<1>>(), sycl::get_kernel_id<CeedBasisSyclInterp<0>>(), 609bd882c8aSJames Wright sycl::get_kernel_id<CeedBasisSyclGrad<1>>(), sycl::get_kernel_id<CeedBasisSyclGrad<0>>()}; 610bd882c8aSJames Wright 611bd882c8aSJames Wright sycl::kernel_bundle<sycl::bundle_state::input> input_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(data->sycl_context, kernel_ids); 612bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_DIM_ID>(dim); 613bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_NUM_COMP_ID>(num_comp); 614bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_Q_1D_ID>(Q_1d); 615bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_P_1D_ID>(P_1d); 616bd882c8aSJames Wright 617bd882c8aSJames Wright CeedCallSycl(ceed, impl->sycl_module = new SyclModule_t(sycl::build(input_bundle))); 618bd882c8aSJames Wright 619bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 620bd882c8aSJames Wright 621bd882c8aSJames Wright // Register backend functions 622bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApply_Sycl)); 623bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl)); 624*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 625bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 626bd882c8aSJames Wright } 627bd882c8aSJames Wright 628bd882c8aSJames Wright //------------------------------------------------------------------------------ 629bd882c8aSJames Wright // Create non-tensor 630bd882c8aSJames Wright //------------------------------------------------------------------------------ 631bd882c8aSJames Wright int CeedBasisCreateH1_Sycl(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad, 632dd64fc84SJeremy L Thompson const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) { 633bd882c8aSJames Wright Ceed ceed; 634bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl; 635bd882c8aSJames Wright Ceed_Sycl *data; 636*9bc66399SJeremy L Thompson 637*9bc66399SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 638*9bc66399SJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 639bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 640bd882c8aSJames Wright 641bd882c8aSJames Wright CeedInt num_comp; 642bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 643bd882c8aSJames Wright 644bd882c8aSJames Wright impl->dim = dim; 645bd882c8aSJames Wright impl->num_comp = num_comp; 646bd882c8aSJames Wright impl->num_nodes = num_nodes; 647bd882c8aSJames Wright impl->num_qpts = num_qpts; 648bd882c8aSJames Wright 6491f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 6501f4b1b45SUmesh Unnikrishnan 6511f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 652bd882c8aSJames Wright 653097cc795SJames Wright std::vector<sycl::event> copy_events; 654097cc795SJames Wright if (q_weight) { 655bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight = sycl::malloc_device<CeedScalar>(num_qpts, data->sycl_device, data->sycl_context)); 6561f4b1b45SUmesh Unnikrishnan sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight, impl->d_q_weight, num_qpts, e); 657097cc795SJames Wright copy_events.push_back(copy_weight); 658097cc795SJames Wright } 659bd882c8aSJames Wright 660bd882c8aSJames Wright const CeedInt interp_length = num_qpts * num_nodes; 661bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 6621f4b1b45SUmesh Unnikrishnan sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp, impl->d_interp, interp_length, e); 663097cc795SJames Wright copy_events.push_back(copy_interp); 664bd882c8aSJames Wright 665bd882c8aSJames Wright const CeedInt grad_length = num_qpts * num_nodes * dim; 666bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad = sycl::malloc_device<CeedScalar>(grad_length, data->sycl_device, data->sycl_context)); 6671f4b1b45SUmesh Unnikrishnan sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad, impl->d_grad, grad_length, e); 668097cc795SJames Wright copy_events.push_back(copy_grad); 669bd882c8aSJames Wright 670097cc795SJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events)); 671bd882c8aSJames Wright 672bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 673bd882c8aSJames Wright 674bd882c8aSJames Wright // Register backend functions 675bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Sycl)); 676bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Sycl)); 677*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 678bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 679bd882c8aSJames Wright } 680ff1e7120SSebastian Grimberg 681bd882c8aSJames Wright //------------------------------------------------------------------------------ 682