1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*bd882c8aSJames Wright // 4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*bd882c8aSJames Wright // 6*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7*bd882c8aSJames Wright 8*bd882c8aSJames Wright #include <ceed/backend.h> 9*bd882c8aSJames Wright #include <ceed/ceed.h> 10*bd882c8aSJames Wright #include <ceed/jit-tools.h> 11*bd882c8aSJames Wright 12*bd882c8aSJames Wright #include <sycl/sycl.hpp> 13*bd882c8aSJames Wright #include <vector> 14*bd882c8aSJames Wright 15*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 16*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 17*bd882c8aSJames Wright 18*bd882c8aSJames Wright template <int> 19*bd882c8aSJames Wright class CeedBasisSyclInterp; 20*bd882c8aSJames Wright template <int> 21*bd882c8aSJames Wright class CeedBasisSyclGrad; 22*bd882c8aSJames Wright class CeedBasisSyclWeight; 23*bd882c8aSJames Wright 24*bd882c8aSJames Wright class CeedBasisSyclInterpNT; 25*bd882c8aSJames Wright class CeedBasisSyclGradNT; 26*bd882c8aSJames Wright class CeedBasisSyclWeightNT; 27*bd882c8aSJames Wright 28*bd882c8aSJames Wright using SpecID = sycl::specialization_id<CeedInt>; 29*bd882c8aSJames Wright 30*bd882c8aSJames Wright static constexpr SpecID BASIS_DIM_ID; 31*bd882c8aSJames Wright static constexpr SpecID BASIS_NUM_COMP_ID; 32*bd882c8aSJames Wright static constexpr SpecID BASIS_P_1D_ID; 33*bd882c8aSJames Wright static constexpr SpecID BASIS_Q_1D_ID; 34*bd882c8aSJames Wright 35*bd882c8aSJames Wright //------------------------------------------------------------------------------ 36*bd882c8aSJames Wright // Interpolation kernel - tensor 37*bd882c8aSJames Wright //------------------------------------------------------------------------------ 38*bd882c8aSJames Wright template <int transpose> 39*bd882c8aSJames Wright static int CeedBasisApplyInterp_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl, 40*bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) { 41*bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len; 42*bd882c8aSJames Wright const CeedInt op_len = impl->op_len; 43*bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d; 44*bd882c8aSJames Wright 45*bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device(); 46*bd882c8aSJames Wright const CeedInt max_work_group_size = 32; 47*bd882c8aSJames Wright const CeedInt work_group_size = CeedIntMin(impl->num_qpts, max_work_group_size); 48*bd882c8aSJames Wright sycl::range<1> local_range(work_group_size); 49*bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size); 50*bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range); 51*bd882c8aSJames Wright 52*bd882c8aSJames Wright // Order queue 53*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 54*bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) { 55*bd882c8aSJames Wright cgh.depends_on({e}); 56*bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module); 57*bd882c8aSJames Wright 58*bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(op_len + 2 * buf_len, cgh); 59*bd882c8aSJames Wright 60*bd882c8aSJames Wright cgh.parallel_for<CeedBasisSyclInterp<transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) { 61*bd882c8aSJames Wright //--------------------------------------------------------------> 62*bd882c8aSJames Wright // Retrieve spec constant values 63*bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>(); 64*bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>(); 65*bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>(); 66*bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>(); 67*bd882c8aSJames Wright //--------------------------------------------------------------> 68*bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 69*bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 70*bd882c8aSJames Wright const CeedInt P = transpose ? Q_1d : P_1d; 71*bd882c8aSJames Wright const CeedInt Q = transpose ? P_1d : Q_1d; 72*bd882c8aSJames Wright const CeedInt stride_0 = transpose ? 1 : P_1d; 73*bd882c8aSJames Wright const CeedInt stride_1 = transpose ? P_1d : 1; 74*bd882c8aSJames Wright const CeedInt u_stride = transpose ? num_qpts : num_nodes; 75*bd882c8aSJames Wright const CeedInt v_stride = transpose ? num_nodes : num_qpts; 76*bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride; 77*bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride; 78*bd882c8aSJames Wright const CeedInt u_size = u_stride; 79*bd882c8aSJames Wright 80*bd882c8aSJames Wright sycl::group work_group = work_item.get_group(); 81*bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id(); 82*bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range(); 83*bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id(); 84*bd882c8aSJames Wright 85*bd882c8aSJames Wright CeedScalar *s_interp_1d = s_mem.get_pointer(); 86*bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_interp_1d + Q * P; 87*bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len; 88*bd882c8aSJames Wright 89*bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) { 90*bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k]; 91*bd882c8aSJames Wright } 92*bd882c8aSJames Wright 93*bd882c8aSJames Wright // Apply basis element by element 94*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 95*bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride; 96*bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride; 97*bd882c8aSJames Wright 98*bd882c8aSJames Wright for (CeedInt k = i; k < u_size; k += group_size) { 99*bd882c8aSJames Wright s_buffer_1[k] = cur_u[k]; 100*bd882c8aSJames Wright } 101*bd882c8aSJames Wright 102*bd882c8aSJames Wright CeedInt pre = u_size; 103*bd882c8aSJames Wright CeedInt post = 1; 104*bd882c8aSJames Wright 105*bd882c8aSJames Wright for (CeedInt d = 0; d < dim; d++) { 106*bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons 107*bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed 108*bd882c8aSJames Wright // sycl::group_barrier(work_group); 109*bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space); 110*bd882c8aSJames Wright 111*bd882c8aSJames Wright pre /= P; 112*bd882c8aSJames Wright const CeedScalar *in = d % 2 ? s_buffer_2 : s_buffer_1; 113*bd882c8aSJames Wright CeedScalar *out = d == dim - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2); 114*bd882c8aSJames Wright 115*bd882c8aSJames Wright // Contract along middle index 116*bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q; 117*bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) { 118*bd882c8aSJames Wright const CeedInt c = k % post; 119*bd882c8aSJames Wright const CeedInt j = (k / post) % Q; 120*bd882c8aSJames Wright const CeedInt a = k / (post * Q); 121*bd882c8aSJames Wright 122*bd882c8aSJames Wright CeedScalar vk = 0; 123*bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) { 124*bd882c8aSJames Wright vk += s_interp_1d[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c]; 125*bd882c8aSJames Wright } 126*bd882c8aSJames Wright out[k] = vk; 127*bd882c8aSJames Wright } 128*bd882c8aSJames Wright post *= Q; 129*bd882c8aSJames Wright } 130*bd882c8aSJames Wright } 131*bd882c8aSJames Wright }); 132*bd882c8aSJames Wright }); 133*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 134*bd882c8aSJames Wright } 135*bd882c8aSJames Wright 136*bd882c8aSJames Wright //------------------------------------------------------------------------------ 137*bd882c8aSJames Wright // Gradient kernel - tensor 138*bd882c8aSJames Wright //------------------------------------------------------------------------------ 139*bd882c8aSJames Wright template <int transpose> 140*bd882c8aSJames Wright static int CeedBasisApplyGrad_Sycl(sycl::queue &sycl_queue, const SyclModule_t &sycl_module, CeedInt num_elem, const CeedBasis_Sycl *impl, 141*bd882c8aSJames Wright const CeedScalar *u, CeedScalar *v) { 142*bd882c8aSJames Wright const CeedInt buf_len = impl->buf_len; 143*bd882c8aSJames Wright const CeedInt op_len = impl->op_len; 144*bd882c8aSJames Wright const CeedScalar *interp_1d = impl->d_interp_1d; 145*bd882c8aSJames Wright const CeedScalar *grad_1d = impl->d_grad_1d; 146*bd882c8aSJames Wright 147*bd882c8aSJames Wright const sycl::device &sycl_device = sycl_queue.get_device(); 148*bd882c8aSJames Wright const CeedInt work_group_size = 32; 149*bd882c8aSJames Wright sycl::range<1> local_range(work_group_size); 150*bd882c8aSJames Wright sycl::range<1> global_range(num_elem * work_group_size); 151*bd882c8aSJames Wright sycl::nd_range<1> kernel_range(global_range, local_range); 152*bd882c8aSJames Wright 153*bd882c8aSJames Wright // Order queue 154*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 155*bd882c8aSJames Wright sycl_queue.submit([&](sycl::handler &cgh) { 156*bd882c8aSJames Wright cgh.depends_on({e}); 157*bd882c8aSJames Wright cgh.use_kernel_bundle(sycl_module); 158*bd882c8aSJames Wright 159*bd882c8aSJames Wright sycl::local_accessor<CeedScalar> s_mem(2 * (op_len + buf_len), cgh); 160*bd882c8aSJames Wright 161*bd882c8aSJames Wright cgh.parallel_for<CeedBasisSyclGrad<transpose>>(kernel_range, [=](sycl::nd_item<1> work_item, sycl::kernel_handler kh) { 162*bd882c8aSJames Wright //--------------------------------------------------------------> 163*bd882c8aSJames Wright // Retrieve spec constant values 164*bd882c8aSJames Wright const CeedInt dim = kh.get_specialization_constant<BASIS_DIM_ID>(); 165*bd882c8aSJames Wright const CeedInt num_comp = kh.get_specialization_constant<BASIS_NUM_COMP_ID>(); 166*bd882c8aSJames Wright const CeedInt P_1d = kh.get_specialization_constant<BASIS_P_1D_ID>(); 167*bd882c8aSJames Wright const CeedInt Q_1d = kh.get_specialization_constant<BASIS_Q_1D_ID>(); 168*bd882c8aSJames Wright //--------------------------------------------------------------> 169*bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 170*bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 171*bd882c8aSJames Wright const CeedInt P = transpose ? Q_1d : P_1d; 172*bd882c8aSJames Wright const CeedInt Q = transpose ? P_1d : Q_1d; 173*bd882c8aSJames Wright const CeedInt stride_0 = transpose ? 1 : P_1d; 174*bd882c8aSJames Wright const CeedInt stride_1 = transpose ? P_1d : 1; 175*bd882c8aSJames Wright const CeedInt u_stride = transpose ? num_qpts : num_nodes; 176*bd882c8aSJames Wright const CeedInt v_stride = transpose ? num_nodes : num_qpts; 177*bd882c8aSJames Wright const CeedInt u_comp_stride = num_elem * u_stride; 178*bd882c8aSJames Wright const CeedInt v_comp_stride = num_elem * v_stride; 179*bd882c8aSJames Wright const CeedInt u_dim_stride = transpose ? num_elem * num_qpts * num_comp : 0; 180*bd882c8aSJames Wright const CeedInt v_dim_stride = transpose ? 0 : num_elem * num_qpts * num_comp; 181*bd882c8aSJames Wright 182*bd882c8aSJames Wright sycl::group work_group = work_item.get_group(); 183*bd882c8aSJames Wright const CeedInt i = work_item.get_local_linear_id(); 184*bd882c8aSJames Wright const CeedInt group_size = work_group.get_local_linear_range(); 185*bd882c8aSJames Wright const CeedInt elem = work_group.get_group_linear_id(); 186*bd882c8aSJames Wright 187*bd882c8aSJames Wright CeedScalar *s_interp_1d = s_mem.get_pointer(); 188*bd882c8aSJames Wright CeedScalar *s_grad_1d = s_interp_1d + P * Q; 189*bd882c8aSJames Wright CeedScalar *s_buffer_1 = s_grad_1d + P * Q; 190*bd882c8aSJames Wright CeedScalar *s_buffer_2 = s_buffer_1 + buf_len; 191*bd882c8aSJames Wright 192*bd882c8aSJames Wright for (CeedInt k = i; k < P * Q; k += group_size) { 193*bd882c8aSJames Wright s_interp_1d[k] = interp_1d[k]; 194*bd882c8aSJames Wright s_grad_1d[k] = grad_1d[k]; 195*bd882c8aSJames Wright } 196*bd882c8aSJames Wright 197*bd882c8aSJames Wright // Apply basis element by element 198*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 199*bd882c8aSJames Wright for (CeedInt dim_1 = 0; dim_1 < dim; dim_1++) { 200*bd882c8aSJames Wright CeedInt pre = transpose ? num_qpts : num_nodes; 201*bd882c8aSJames Wright CeedInt post = 1; 202*bd882c8aSJames Wright const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride; 203*bd882c8aSJames Wright CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride; 204*bd882c8aSJames Wright 205*bd882c8aSJames Wright for (CeedInt dim_2 = 0; dim_2 < dim; dim_2++) { 206*bd882c8aSJames Wright // Use older version of sycl workgroup barrier for performance reasons 207*bd882c8aSJames Wright // Can be updated in future to align with SYCL2020 spec if performance bottleneck is removed 208*bd882c8aSJames Wright // sycl::group_barrier(work_group); 209*bd882c8aSJames Wright work_item.barrier(sycl::access::fence_space::local_space); 210*bd882c8aSJames Wright 211*bd882c8aSJames Wright pre /= P; 212*bd882c8aSJames Wright const CeedScalar *op = dim_1 == dim_2 ? s_grad_1d : s_interp_1d; 213*bd882c8aSJames Wright const CeedScalar *in = (dim_2 == 0 ? cur_u : (dim_2 % 2 ? s_buffer_2 : s_buffer_1)); 214*bd882c8aSJames Wright CeedScalar *out = dim_2 == dim - 1 ? cur_v : (dim_2 % 2 ? s_buffer_1 : s_buffer_2); 215*bd882c8aSJames Wright 216*bd882c8aSJames Wright // Contract along middle index 217*bd882c8aSJames Wright const CeedInt writeLen = pre * post * Q; 218*bd882c8aSJames Wright for (CeedInt k = i; k < writeLen; k += group_size) { 219*bd882c8aSJames Wright const CeedInt c = k % post; 220*bd882c8aSJames Wright const CeedInt j = (k / post) % Q; 221*bd882c8aSJames Wright const CeedInt a = k / (post * Q); 222*bd882c8aSJames Wright 223*bd882c8aSJames Wright CeedScalar v_k = 0; 224*bd882c8aSJames Wright for (CeedInt b = 0; b < P; b++) v_k += op[j * stride_0 + b * stride_1] * in[(a * P + b) * post + c]; 225*bd882c8aSJames Wright 226*bd882c8aSJames Wright if (transpose && dim_2 == dim - 1) out[k] += v_k; 227*bd882c8aSJames Wright else out[k] = v_k; 228*bd882c8aSJames Wright } 229*bd882c8aSJames Wright 230*bd882c8aSJames Wright post *= Q; 231*bd882c8aSJames Wright } 232*bd882c8aSJames Wright } 233*bd882c8aSJames Wright } 234*bd882c8aSJames Wright }); 235*bd882c8aSJames Wright }); 236*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 237*bd882c8aSJames Wright } 238*bd882c8aSJames Wright 239*bd882c8aSJames Wright //------------------------------------------------------------------------------ 240*bd882c8aSJames Wright // Weight kernel - tensor 241*bd882c8aSJames Wright //------------------------------------------------------------------------------ 242*bd882c8aSJames Wright static int CeedBasisApplyWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasis_Sycl *impl, CeedScalar *w) { 243*bd882c8aSJames Wright const CeedInt dim = impl->dim; 244*bd882c8aSJames Wright const CeedInt Q_1d = impl->Q_1d; 245*bd882c8aSJames Wright const CeedScalar *q_weight_1d = impl->d_q_weight_1d; 246*bd882c8aSJames Wright 247*bd882c8aSJames Wright const CeedInt num_quad_x = Q_1d; 248*bd882c8aSJames Wright const CeedInt num_quad_y = (dim > 1) ? Q_1d : 1; 249*bd882c8aSJames Wright const CeedInt num_quad_z = (dim > 2) ? Q_1d : 1; 250*bd882c8aSJames Wright sycl::range<3> kernel_range(num_elem * num_quad_z, num_quad_y, num_quad_x); 251*bd882c8aSJames Wright 252*bd882c8aSJames Wright // Order queue 253*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 254*bd882c8aSJames Wright sycl_queue.parallel_for<CeedBasisSyclWeight>(kernel_range, {e}, [=](sycl::item<3> work_item) { 255*bd882c8aSJames Wright if (dim == 1) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]]; 256*bd882c8aSJames Wright if (dim == 2) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]]; 257*bd882c8aSJames Wright if (dim == 3) w[work_item.get_linear_id()] = q_weight_1d[work_item[2]] * q_weight_1d[work_item[1]] * q_weight_1d[work_item[0] % Q_1d]; 258*bd882c8aSJames Wright }); 259*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 260*bd882c8aSJames Wright } 261*bd882c8aSJames Wright 262*bd882c8aSJames Wright //------------------------------------------------------------------------------ 263*bd882c8aSJames Wright // Basis apply - tensor 264*bd882c8aSJames Wright //------------------------------------------------------------------------------ 265*bd882c8aSJames Wright static int CeedBasisApply_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 266*bd882c8aSJames Wright CeedVector v) { 267*bd882c8aSJames Wright Ceed ceed; 268*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 269*bd882c8aSJames Wright Ceed_Sycl *data; 270*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 271*bd882c8aSJames Wright CeedBasis_Sycl *impl; 272*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 273*bd882c8aSJames Wright 274*bd882c8aSJames Wright const CeedInt transpose = t_mode == CEED_TRANSPOSE; 275*bd882c8aSJames Wright 276*bd882c8aSJames Wright // Read vectors 277*bd882c8aSJames Wright const CeedScalar *d_u; 278*bd882c8aSJames Wright CeedScalar *d_v; 279*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 280*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 281*bd882c8aSJames Wright } 282*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 283*bd882c8aSJames Wright 284*bd882c8aSJames Wright // Clear v for transpose operation 285*bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 286*bd882c8aSJames Wright CeedSize length; 287*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length)); 288*bd882c8aSJames Wright // Order queue 289*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 290*bd882c8aSJames Wright data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e}); 291*bd882c8aSJames Wright } 292*bd882c8aSJames Wright 293*bd882c8aSJames Wright // Basis action 294*bd882c8aSJames Wright switch (eval_mode) { 295*bd882c8aSJames Wright case CEED_EVAL_INTERP: { 296*bd882c8aSJames Wright if (transpose) { 297*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyInterp_Sycl<CEED_TRANSPOSE>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 298*bd882c8aSJames Wright } else { 299*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyInterp_Sycl<CEED_NOTRANSPOSE>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 300*bd882c8aSJames Wright } 301*bd882c8aSJames Wright } break; 302*bd882c8aSJames Wright case CEED_EVAL_GRAD: { 303*bd882c8aSJames Wright if (transpose) { 304*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyGrad_Sycl<1>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 305*bd882c8aSJames Wright } else { 306*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyGrad_Sycl<0>(data->sycl_queue, *impl->sycl_module, num_elem, impl, d_u, d_v)); 307*bd882c8aSJames Wright } 308*bd882c8aSJames Wright } break; 309*bd882c8aSJames Wright case CEED_EVAL_WEIGHT: { 310*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyWeight_Sycl(data->sycl_queue, num_elem, impl, d_v)); 311*bd882c8aSJames Wright } break; 312*bd882c8aSJames Wright // LCOV_EXCL_START 313*bd882c8aSJames Wright // Evaluate the divergence to/from the quadrature points 314*bd882c8aSJames Wright case CEED_EVAL_DIV: 315*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 316*bd882c8aSJames Wright // Evaluate the curl to/from the quadrature points 317*bd882c8aSJames Wright case CEED_EVAL_CURL: 318*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 319*bd882c8aSJames Wright // Take no action, BasisApply should not have been called 320*bd882c8aSJames Wright case CEED_EVAL_NONE: 321*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context"); 322*bd882c8aSJames Wright // LCOV_EXCL_STOP 323*bd882c8aSJames Wright } 324*bd882c8aSJames Wright 325*bd882c8aSJames Wright // Restore vectors 326*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 327*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 328*bd882c8aSJames Wright } 329*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 330*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 331*bd882c8aSJames Wright } 332*bd882c8aSJames Wright 333*bd882c8aSJames Wright //------------------------------------------------------------------------------ 334*bd882c8aSJames Wright // Interpolation kernel - non-tensor 335*bd882c8aSJames Wright //------------------------------------------------------------------------------ 336*bd882c8aSJames Wright static int CeedBasisApplyNonTensorInterp_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt transpose, const CeedBasisNonTensor_Sycl *impl, 337*bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) { 338*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 339*bd882c8aSJames Wright const CeedInt P = transpose ? impl->num_qpts : impl->num_nodes; 340*bd882c8aSJames Wright const CeedInt Q = transpose ? impl->num_nodes : impl->num_qpts; 341*bd882c8aSJames Wright const CeedInt stride_0 = transpose ? 1 : impl->num_nodes; 342*bd882c8aSJames Wright const CeedInt stride_1 = transpose ? impl->num_nodes : 1; 343*bd882c8aSJames Wright const CeedInt u_stride = P; 344*bd882c8aSJames Wright const CeedInt v_stride = Q; 345*bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem; 346*bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem; 347*bd882c8aSJames Wright const CeedInt u_size = P; 348*bd882c8aSJames Wright const CeedInt v_size = Q; 349*bd882c8aSJames Wright const CeedScalar *d_B = impl->d_interp; 350*bd882c8aSJames Wright 351*bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size); 352*bd882c8aSJames Wright 353*bd882c8aSJames Wright // Order queue 354*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 355*bd882c8aSJames Wright sycl_queue.parallel_for<CeedBasisSyclInterpNT>(kernel_range, {e}, [=](sycl::id<2> indx) { 356*bd882c8aSJames Wright const CeedInt i = indx[1]; 357*bd882c8aSJames Wright const CeedInt elem = indx[0]; 358*bd882c8aSJames Wright 359*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 360*bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride; 361*bd882c8aSJames Wright CeedScalar V = 0.0; 362*bd882c8aSJames Wright 363*bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) { 364*bd882c8aSJames Wright V += d_B[i * stride_0 + j * stride_1] * U[j]; 365*bd882c8aSJames Wright } 366*bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride] = V; 367*bd882c8aSJames Wright } 368*bd882c8aSJames Wright }); 369*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 370*bd882c8aSJames Wright } 371*bd882c8aSJames Wright 372*bd882c8aSJames Wright //------------------------------------------------------------------------------ 373*bd882c8aSJames Wright // Gradient kernel - non-tensor 374*bd882c8aSJames Wright //------------------------------------------------------------------------------ 375*bd882c8aSJames Wright static int CeedBasisApplyNonTensorGrad_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, CeedInt transpose, const CeedBasisNonTensor_Sycl *impl, 376*bd882c8aSJames Wright const CeedScalar *d_U, CeedScalar *d_V) { 377*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 378*bd882c8aSJames Wright const CeedInt P = transpose ? impl->num_qpts : impl->num_nodes; 379*bd882c8aSJames Wright const CeedInt Q = transpose ? impl->num_nodes : impl->num_qpts; 380*bd882c8aSJames Wright const CeedInt stride_0 = transpose ? 1 : impl->num_nodes; 381*bd882c8aSJames Wright const CeedInt stride_1 = transpose ? impl->num_nodes : 1; 382*bd882c8aSJames Wright const CeedInt g_dim_stride = P * Q; 383*bd882c8aSJames Wright const CeedInt u_stride = P; 384*bd882c8aSJames Wright const CeedInt v_stride = Q; 385*bd882c8aSJames Wright const CeedInt u_comp_stride = u_stride * num_elem; 386*bd882c8aSJames Wright const CeedInt v_comp_stride = v_stride * num_elem; 387*bd882c8aSJames Wright const CeedInt u_dim_stride = u_comp_stride * num_comp; 388*bd882c8aSJames Wright const CeedInt v_dim_stride = v_comp_stride * num_comp; 389*bd882c8aSJames Wright const CeedInt u_size = P; 390*bd882c8aSJames Wright const CeedInt v_size = Q; 391*bd882c8aSJames Wright const CeedInt in_dim = transpose ? impl->dim : 1; 392*bd882c8aSJames Wright const CeedInt out_dim = transpose ? 1 : impl->dim; 393*bd882c8aSJames Wright const CeedScalar *d_G = impl->d_grad; 394*bd882c8aSJames Wright 395*bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, v_size); 396*bd882c8aSJames Wright 397*bd882c8aSJames Wright // Order queue 398*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 399*bd882c8aSJames Wright sycl_queue.parallel_for<CeedBasisSyclGradNT>(kernel_range, {e}, [=](sycl::id<2> indx) { 400*bd882c8aSJames Wright const CeedInt i = indx[1]; 401*bd882c8aSJames Wright const CeedInt elem = indx[0]; 402*bd882c8aSJames Wright 403*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 404*bd882c8aSJames Wright CeedScalar V[3] = {0.0, 0.0, 0.0}; 405*bd882c8aSJames Wright 406*bd882c8aSJames Wright for (CeedInt d1 = 0; d1 < in_dim; ++d1) { 407*bd882c8aSJames Wright const CeedScalar *U = d_U + elem * u_stride + comp * u_comp_stride + d1 * u_dim_stride; 408*bd882c8aSJames Wright const CeedScalar *G = d_G + i * stride_0 + d1 * g_dim_stride; 409*bd882c8aSJames Wright 410*bd882c8aSJames Wright for (CeedInt j = 0; j < u_size; ++j) { 411*bd882c8aSJames Wright const CeedScalar Uj = U[j]; 412*bd882c8aSJames Wright 413*bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) { 414*bd882c8aSJames Wright V[d0] += G[j * stride_1 + d0 * g_dim_stride] * Uj; 415*bd882c8aSJames Wright } 416*bd882c8aSJames Wright } 417*bd882c8aSJames Wright } 418*bd882c8aSJames Wright for (CeedInt d0 = 0; d0 < out_dim; ++d0) { 419*bd882c8aSJames Wright d_V[i + elem * v_stride + comp * v_comp_stride + d0 * v_dim_stride] = V[d0]; 420*bd882c8aSJames Wright } 421*bd882c8aSJames Wright } 422*bd882c8aSJames Wright }); 423*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 424*bd882c8aSJames Wright } 425*bd882c8aSJames Wright 426*bd882c8aSJames Wright //------------------------------------------------------------------------------ 427*bd882c8aSJames Wright // Weight kernel - non-tensor 428*bd882c8aSJames Wright //------------------------------------------------------------------------------ 429*bd882c8aSJames Wright static int CeedBasisApplyNonTensorWeight_Sycl(sycl::queue &sycl_queue, CeedInt num_elem, const CeedBasisNonTensor_Sycl *impl, CeedScalar *d_V) { 430*bd882c8aSJames Wright const CeedInt num_qpts = impl->num_qpts; 431*bd882c8aSJames Wright const CeedScalar *q_weight = impl->d_q_weight; 432*bd882c8aSJames Wright 433*bd882c8aSJames Wright sycl::range<2> kernel_range(num_elem, num_qpts); 434*bd882c8aSJames Wright 435*bd882c8aSJames Wright // Order queue 436*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 437*bd882c8aSJames Wright sycl_queue.parallel_for<CeedBasisSyclWeightNT>(kernel_range, {e}, [=](sycl::id<2> indx) { 438*bd882c8aSJames Wright const CeedInt i = indx[1]; 439*bd882c8aSJames Wright const CeedInt elem = indx[0]; 440*bd882c8aSJames Wright d_V[i + elem * num_qpts] = q_weight[i]; 441*bd882c8aSJames Wright }); 442*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 443*bd882c8aSJames Wright } 444*bd882c8aSJames Wright 445*bd882c8aSJames Wright //------------------------------------------------------------------------------ 446*bd882c8aSJames Wright // Basis apply - non-tensor 447*bd882c8aSJames Wright //------------------------------------------------------------------------------ 448*bd882c8aSJames Wright static int CeedBasisApplyNonTensor_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 449*bd882c8aSJames Wright CeedVector v) { 450*bd882c8aSJames Wright Ceed ceed; 451*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 452*bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl; 453*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 454*bd882c8aSJames Wright Ceed_Sycl *data; 455*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 456*bd882c8aSJames Wright 457*bd882c8aSJames Wright const CeedInt transpose = t_mode == CEED_TRANSPOSE; 458*bd882c8aSJames Wright 459*bd882c8aSJames Wright // Read vectors 460*bd882c8aSJames Wright const CeedScalar *d_u; 461*bd882c8aSJames Wright CeedScalar *d_v; 462*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 463*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 464*bd882c8aSJames Wright } 465*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 466*bd882c8aSJames Wright 467*bd882c8aSJames Wright // Clear v for transpose operation 468*bd882c8aSJames Wright if (transpose) { 469*bd882c8aSJames Wright CeedSize length; 470*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(v, &length)); 471*bd882c8aSJames Wright // Order queue 472*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 473*bd882c8aSJames Wright data->sycl_queue.fill<CeedScalar>(d_v, 0, length, {e}); 474*bd882c8aSJames Wright } 475*bd882c8aSJames Wright 476*bd882c8aSJames Wright // Apply basis operation 477*bd882c8aSJames Wright switch (eval_mode) { 478*bd882c8aSJames Wright case CEED_EVAL_INTERP: { 479*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyNonTensorInterp_Sycl(data->sycl_queue, num_elem, transpose, impl, d_u, d_v)); 480*bd882c8aSJames Wright } break; 481*bd882c8aSJames Wright case CEED_EVAL_GRAD: { 482*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyNonTensorGrad_Sycl(data->sycl_queue, num_elem, transpose, impl, d_u, d_v)); 483*bd882c8aSJames Wright } break; 484*bd882c8aSJames Wright case CEED_EVAL_WEIGHT: { 485*bd882c8aSJames Wright CeedCallBackend(CeedBasisApplyNonTensorWeight_Sycl(data->sycl_queue, num_elem, impl, d_v)); 486*bd882c8aSJames Wright } break; 487*bd882c8aSJames Wright // LCOV_EXCL_START 488*bd882c8aSJames Wright // Evaluate the divergence to/from the quadrature points 489*bd882c8aSJames Wright case CEED_EVAL_DIV: 490*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 491*bd882c8aSJames Wright // Evaluate the curl to/from the quadrature points 492*bd882c8aSJames Wright case CEED_EVAL_CURL: 493*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 494*bd882c8aSJames Wright // Take no action, BasisApply should not have been called 495*bd882c8aSJames Wright case CEED_EVAL_NONE: 496*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context"); 497*bd882c8aSJames Wright // LCOV_EXCL_STOP 498*bd882c8aSJames Wright } 499*bd882c8aSJames Wright 500*bd882c8aSJames Wright // Restore vectors 501*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 502*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 503*bd882c8aSJames Wright } 504*bd882c8aSJames Wright 505*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 506*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 507*bd882c8aSJames Wright } 508*bd882c8aSJames Wright 509*bd882c8aSJames Wright //------------------------------------------------------------------------------ 510*bd882c8aSJames Wright // Destroy tensor basis 511*bd882c8aSJames Wright //------------------------------------------------------------------------------ 512*bd882c8aSJames Wright static int CeedBasisDestroy_Sycl(CeedBasis basis) { 513*bd882c8aSJames Wright Ceed ceed; 514*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 515*bd882c8aSJames Wright CeedBasis_Sycl *impl; 516*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 517*bd882c8aSJames Wright Ceed_Sycl *data; 518*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 519*bd882c8aSJames Wright 520*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 521*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 522*bd882c8aSJames Wright 523*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context)); 524*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context)); 525*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context)); 526*bd882c8aSJames Wright 527*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 528*bd882c8aSJames Wright 529*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 530*bd882c8aSJames Wright } 531*bd882c8aSJames Wright 532*bd882c8aSJames Wright //------------------------------------------------------------------------------ 533*bd882c8aSJames Wright // Destroy non-tensor basis 534*bd882c8aSJames Wright //------------------------------------------------------------------------------ 535*bd882c8aSJames Wright static int CeedBasisDestroyNonTensor_Sycl(CeedBasis basis) { 536*bd882c8aSJames Wright Ceed ceed; 537*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 538*bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl; 539*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 540*bd882c8aSJames Wright Ceed_Sycl *data; 541*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 542*bd882c8aSJames Wright 543*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 544*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 545*bd882c8aSJames Wright 546*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_q_weight, data->sycl_context)); 547*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp, data->sycl_context)); 548*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad, data->sycl_context)); 549*bd882c8aSJames Wright 550*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 551*bd882c8aSJames Wright 552*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 553*bd882c8aSJames Wright } 554*bd882c8aSJames Wright 555*bd882c8aSJames Wright //------------------------------------------------------------------------------ 556*bd882c8aSJames Wright // Create tensor 557*bd882c8aSJames Wright //------------------------------------------------------------------------------ 558*bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d, 559*bd882c8aSJames Wright const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) { 560*bd882c8aSJames Wright Ceed ceed; 561*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 562*bd882c8aSJames Wright CeedBasis_Sycl *impl; 563*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 564*bd882c8aSJames Wright Ceed_Sycl *data; 565*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 566*bd882c8aSJames Wright 567*bd882c8aSJames Wright CeedInt num_comp; 568*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 569*bd882c8aSJames Wright 570*bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 571*bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 572*bd882c8aSJames Wright 573*bd882c8aSJames Wright impl->dim = dim; 574*bd882c8aSJames Wright impl->P_1d = P_1d; 575*bd882c8aSJames Wright impl->Q_1d = Q_1d; 576*bd882c8aSJames Wright impl->num_comp = num_comp; 577*bd882c8aSJames Wright impl->num_nodes = num_nodes; 578*bd882c8aSJames Wright impl->num_qpts = num_qpts; 579*bd882c8aSJames Wright impl->buf_len = num_comp * CeedIntMax(num_nodes, num_qpts); 580*bd882c8aSJames Wright impl->op_len = Q_1d * P_1d; 581*bd882c8aSJames Wright 582*bd882c8aSJames Wright // Order queue 583*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 584*bd882c8aSJames Wright 585*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context)); 586*bd882c8aSJames Wright sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, {e}); 587*bd882c8aSJames Wright 588*bd882c8aSJames Wright const CeedInt interp_length = Q_1d * P_1d; 589*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 590*bd882c8aSJames Wright sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, {e}); 591*bd882c8aSJames Wright 592*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 593*bd882c8aSJames Wright sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, {e}); 594*bd882c8aSJames Wright 595*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad})); 596*bd882c8aSJames Wright 597*bd882c8aSJames Wright std::vector<sycl::kernel_id> kernel_ids = {sycl::get_kernel_id<CeedBasisSyclInterp<1>>(), sycl::get_kernel_id<CeedBasisSyclInterp<0>>(), 598*bd882c8aSJames Wright sycl::get_kernel_id<CeedBasisSyclGrad<1>>(), sycl::get_kernel_id<CeedBasisSyclGrad<0>>()}; 599*bd882c8aSJames Wright 600*bd882c8aSJames Wright sycl::kernel_bundle<sycl::bundle_state::input> input_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(data->sycl_context, kernel_ids); 601*bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_DIM_ID>(dim); 602*bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_NUM_COMP_ID>(num_comp); 603*bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_Q_1D_ID>(Q_1d); 604*bd882c8aSJames Wright input_bundle.set_specialization_constant<BASIS_P_1D_ID>(P_1d); 605*bd882c8aSJames Wright 606*bd882c8aSJames Wright CeedCallSycl(ceed, impl->sycl_module = new SyclModule_t(sycl::build(input_bundle))); 607*bd882c8aSJames Wright 608*bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 609*bd882c8aSJames Wright 610*bd882c8aSJames Wright // Register backend functions 611*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApply_Sycl)); 612*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl)); 613*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 614*bd882c8aSJames Wright } 615*bd882c8aSJames Wright 616*bd882c8aSJames Wright //------------------------------------------------------------------------------ 617*bd882c8aSJames Wright // Create non-tensor 618*bd882c8aSJames Wright //------------------------------------------------------------------------------ 619*bd882c8aSJames Wright int CeedBasisCreateH1_Sycl(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad, 620*bd882c8aSJames Wright const CeedScalar *qref, const CeedScalar *q_weight, CeedBasis basis) { 621*bd882c8aSJames Wright Ceed ceed; 622*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 623*bd882c8aSJames Wright CeedBasisNonTensor_Sycl *impl; 624*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 625*bd882c8aSJames Wright Ceed_Sycl *data; 626*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 627*bd882c8aSJames Wright 628*bd882c8aSJames Wright CeedInt num_comp; 629*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 630*bd882c8aSJames Wright 631*bd882c8aSJames Wright impl->dim = dim; 632*bd882c8aSJames Wright impl->num_comp = num_comp; 633*bd882c8aSJames Wright impl->num_nodes = num_nodes; 634*bd882c8aSJames Wright impl->num_qpts = num_qpts; 635*bd882c8aSJames Wright 636*bd882c8aSJames Wright // Order queue 637*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 638*bd882c8aSJames Wright 639*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight = sycl::malloc_device<CeedScalar>(num_qpts, data->sycl_device, data->sycl_context)); 640*bd882c8aSJames Wright sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight, impl->d_q_weight, num_qpts, {e}); 641*bd882c8aSJames Wright 642*bd882c8aSJames Wright const CeedInt interp_length = num_qpts * num_nodes; 643*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 644*bd882c8aSJames Wright sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp, impl->d_interp, interp_length, {e}); 645*bd882c8aSJames Wright 646*bd882c8aSJames Wright const CeedInt grad_length = num_qpts * num_nodes * dim; 647*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad = sycl::malloc_device<CeedScalar>(grad_length, data->sycl_device, data->sycl_context)); 648*bd882c8aSJames Wright sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad, impl->d_grad, grad_length, {e}); 649*bd882c8aSJames Wright 650*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad})); 651*bd882c8aSJames Wright 652*bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 653*bd882c8aSJames Wright 654*bd882c8aSJames Wright // Register backend functions 655*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Sycl)); 656*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Sycl)); 657*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 658*bd882c8aSJames Wright } 659*bd882c8aSJames Wright //------------------------------------------------------------------------------ 660