1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other 2*bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3*bd882c8aSJames Wright // files for details. 4*bd882c8aSJames Wright // 5*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6*bd882c8aSJames Wright // 7*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8*bd882c8aSJames Wright 9*bd882c8aSJames Wright #include <ceed/backend.h> 10*bd882c8aSJames Wright #include <ceed/ceed.h> 11*bd882c8aSJames Wright #include <ceed/jit-tools.h> 12*bd882c8aSJames Wright 13*bd882c8aSJames Wright #include <string> 14*bd882c8aSJames Wright #include <sycl/sycl.hpp> 15*bd882c8aSJames Wright 16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 18*bd882c8aSJames Wright 19*bd882c8aSJames Wright class CeedElemRestrSyclStridedNT; 20*bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT; 21*bd882c8aSJames Wright class CeedElemRestrSyclStridedT; 22*bd882c8aSJames Wright class CeedElemRestrSyclOffsetT; 23*bd882c8aSJames Wright 24*bd882c8aSJames Wright //------------------------------------------------------------------------------ 25*bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided 26*bd882c8aSJames Wright //------------------------------------------------------------------------------ 27*bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 28*bd882c8aSJames Wright CeedScalar *v) { 29*bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 30*bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 31*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 32*bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 33*bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 34*bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 35*bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 36*bd882c8aSJames Wright 37*bd882c8aSJames Wright // Order queue 38*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 39*bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) { 40*bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 41*bd882c8aSJames Wright const CeedInt elem = node / elem_size; 42*bd882c8aSJames Wright 43*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 44*bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem]; 45*bd882c8aSJames Wright } 46*bd882c8aSJames Wright }); 47*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 48*bd882c8aSJames Wright } 49*bd882c8aSJames Wright 50*bd882c8aSJames Wright //------------------------------------------------------------------------------ 51*bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided 52*bd882c8aSJames Wright //------------------------------------------------------------------------------ 53*bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 54*bd882c8aSJames Wright CeedScalar *v) { 55*bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 56*bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 57*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 58*bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 59*bd882c8aSJames Wright 60*bd882c8aSJames Wright const CeedInt *indices = impl->d_ind; 61*bd882c8aSJames Wright 62*bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 63*bd882c8aSJames Wright 64*bd882c8aSJames Wright // Order queue 65*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 66*bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) { 67*bd882c8aSJames Wright const CeedInt ind = indices[node]; 68*bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 69*bd882c8aSJames Wright const CeedInt elem = node / elem_size; 70*bd882c8aSJames Wright 71*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 72*bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride]; 73*bd882c8aSJames Wright } 74*bd882c8aSJames Wright }); 75*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 76*bd882c8aSJames Wright } 77*bd882c8aSJames Wright 78*bd882c8aSJames Wright //------------------------------------------------------------------------------ 79*bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided 80*bd882c8aSJames Wright //------------------------------------------------------------------------------ 81*bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 82*bd882c8aSJames Wright CeedScalar *v) { 83*bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 84*bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 85*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 86*bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 87*bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 88*bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 89*bd882c8aSJames Wright 90*bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 91*bd882c8aSJames Wright 92*bd882c8aSJames Wright // Order queue 93*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 94*bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) { 95*bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 96*bd882c8aSJames Wright const CeedInt elem = node / elem_size; 97*bd882c8aSJames Wright 98*bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 99*bd882c8aSJames Wright v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 100*bd882c8aSJames Wright } 101*bd882c8aSJames Wright }); 102*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 103*bd882c8aSJames Wright } 104*bd882c8aSJames Wright 105*bd882c8aSJames Wright //------------------------------------------------------------------------------ 106*bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided 107*bd882c8aSJames Wright //------------------------------------------------------------------------------ 108*bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 109*bd882c8aSJames Wright CeedScalar *v) { 110*bd882c8aSJames Wright const CeedInt num_nodes = impl->num_nodes; 111*bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 112*bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 113*bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 114*bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 115*bd882c8aSJames Wright 116*bd882c8aSJames Wright const CeedInt *l_vec_indices = impl->d_l_vec_indices; 117*bd882c8aSJames Wright const CeedInt *t_offsets = impl->d_t_offsets; 118*bd882c8aSJames Wright const CeedInt *t_indices = impl->d_t_indices; 119*bd882c8aSJames Wright 120*bd882c8aSJames Wright sycl::range<1> kernel_range(num_nodes * num_comp); 121*bd882c8aSJames Wright 122*bd882c8aSJames Wright // Order queue 123*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 124*bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) { 125*bd882c8aSJames Wright const CeedInt node = id % num_nodes; 126*bd882c8aSJames Wright const CeedInt comp = id / num_nodes; 127*bd882c8aSJames Wright const CeedInt ind = l_vec_indices[node]; 128*bd882c8aSJames Wright const CeedInt range_1 = t_offsets[node]; 129*bd882c8aSJames Wright const CeedInt range_N = t_offsets[node + 1]; 130*bd882c8aSJames Wright 131*bd882c8aSJames Wright CeedScalar value = 0.0; 132*bd882c8aSJames Wright 133*bd882c8aSJames Wright for (CeedInt j = range_1; j < range_N; j++) { 134*bd882c8aSJames Wright const CeedInt t_ind = t_indices[j]; 135*bd882c8aSJames Wright CeedInt loc_node = t_ind % elem_size; 136*bd882c8aSJames Wright CeedInt elem = t_ind / elem_size; 137*bd882c8aSJames Wright 138*bd882c8aSJames Wright value += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 139*bd882c8aSJames Wright } 140*bd882c8aSJames Wright v[ind + comp * comp_stride] += value; 141*bd882c8aSJames Wright }); 142*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 143*bd882c8aSJames Wright } 144*bd882c8aSJames Wright 145*bd882c8aSJames Wright //------------------------------------------------------------------------------ 146*bd882c8aSJames Wright // Apply restriction 147*bd882c8aSJames Wright //------------------------------------------------------------------------------ 148*bd882c8aSJames Wright static int CeedElemRestrictionApply_Sycl(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 149*bd882c8aSJames Wright Ceed ceed; 150*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 151*bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 152*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 153*bd882c8aSJames Wright Ceed_Sycl *data; 154*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 155*bd882c8aSJames Wright 156*bd882c8aSJames Wright // Get vectors 157*bd882c8aSJames Wright const CeedScalar *d_u; 158*bd882c8aSJames Wright CeedScalar *d_v; 159*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 160*bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 161*bd882c8aSJames Wright // Sum into for transpose mode, e-vec to l-vec 162*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 163*bd882c8aSJames Wright } else { 164*bd882c8aSJames Wright // Overwrite for notranspose mode, l-vec to e-vec 165*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 166*bd882c8aSJames Wright } 167*bd882c8aSJames Wright 168*bd882c8aSJames Wright // Restrict 169*bd882c8aSJames Wright if (t_mode == CEED_NOTRANSPOSE) { 170*bd882c8aSJames Wright // L-vector -> E-vector 171*bd882c8aSJames Wright if (impl->d_ind) { 172*bd882c8aSJames Wright // -- Offsets provided 173*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 174*bd882c8aSJames Wright } else { 175*bd882c8aSJames Wright // -- Strided restriction 176*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 177*bd882c8aSJames Wright } 178*bd882c8aSJames Wright } else { 179*bd882c8aSJames Wright // E-vector -> L-vector 180*bd882c8aSJames Wright if (impl->d_ind) { 181*bd882c8aSJames Wright // -- Offsets provided 182*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 183*bd882c8aSJames Wright } else { 184*bd882c8aSJames Wright // -- Strided restriction 185*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 186*bd882c8aSJames Wright } 187*bd882c8aSJames Wright } 188*bd882c8aSJames Wright // Wait for queues to be completed. NOTE: This may not be necessary 189*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 190*bd882c8aSJames Wright 191*bd882c8aSJames Wright if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 192*bd882c8aSJames Wright 193*bd882c8aSJames Wright // Restore arrays 194*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 195*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 196*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 197*bd882c8aSJames Wright } 198*bd882c8aSJames Wright 199*bd882c8aSJames Wright //------------------------------------------------------------------------------ 200*bd882c8aSJames Wright // Get offsets 201*bd882c8aSJames Wright //------------------------------------------------------------------------------ 202*bd882c8aSJames Wright static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction r, CeedMemType m_type, const CeedInt **offsets) { 203*bd882c8aSJames Wright Ceed ceed; 204*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 205*bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 206*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 207*bd882c8aSJames Wright 208*bd882c8aSJames Wright switch (m_type) { 209*bd882c8aSJames Wright case CEED_MEM_HOST: 210*bd882c8aSJames Wright *offsets = impl->h_ind; 211*bd882c8aSJames Wright break; 212*bd882c8aSJames Wright case CEED_MEM_DEVICE: 213*bd882c8aSJames Wright *offsets = impl->d_ind; 214*bd882c8aSJames Wright break; 215*bd882c8aSJames Wright } 216*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 217*bd882c8aSJames Wright } 218*bd882c8aSJames Wright 219*bd882c8aSJames Wright //------------------------------------------------------------------------------ 220*bd882c8aSJames Wright // Destroy restriction 221*bd882c8aSJames Wright //------------------------------------------------------------------------------ 222*bd882c8aSJames Wright static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction r) { 223*bd882c8aSJames Wright Ceed ceed; 224*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 225*bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 226*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 227*bd882c8aSJames Wright Ceed_Sycl *data; 228*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 229*bd882c8aSJames Wright 230*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 231*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 232*bd882c8aSJames Wright 233*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 234*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context)); 235*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context)); 236*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context)); 237*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context)); 238*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 239*bd882c8aSJames Wright 240*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 241*bd882c8aSJames Wright } 242*bd882c8aSJames Wright 243*bd882c8aSJames Wright //------------------------------------------------------------------------------ 244*bd882c8aSJames Wright // Create transpose offsets and indices 245*bd882c8aSJames Wright //------------------------------------------------------------------------------ 246*bd882c8aSJames Wright static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction r, const CeedInt *indices) { 247*bd882c8aSJames Wright Ceed ceed; 248*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 249*bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 250*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 251*bd882c8aSJames Wright CeedSize l_size; 252*bd882c8aSJames Wright CeedInt num_elem, elem_size, num_comp; 253*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 254*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 255*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 256*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 257*bd882c8aSJames Wright 258*bd882c8aSJames Wright // Count num_nodes 259*bd882c8aSJames Wright bool *is_node; 260*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &is_node)); 261*bd882c8aSJames Wright const CeedInt size_indices = num_elem * elem_size; 262*bd882c8aSJames Wright for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 263*bd882c8aSJames Wright CeedInt num_nodes = 0; 264*bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 265*bd882c8aSJames Wright impl->num_nodes = num_nodes; 266*bd882c8aSJames Wright 267*bd882c8aSJames Wright // L-vector offsets array 268*bd882c8aSJames Wright CeedInt *ind_to_offset, *l_vec_indices; 269*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 270*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 271*bd882c8aSJames Wright CeedInt j = 0; 272*bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) { 273*bd882c8aSJames Wright if (is_node[i]) { 274*bd882c8aSJames Wright l_vec_indices[j] = i; 275*bd882c8aSJames Wright ind_to_offset[i] = j++; 276*bd882c8aSJames Wright } 277*bd882c8aSJames Wright } 278*bd882c8aSJames Wright CeedCallBackend(CeedFree(&is_node)); 279*bd882c8aSJames Wright 280*bd882c8aSJames Wright // Compute transpose offsets and indices 281*bd882c8aSJames Wright const CeedInt size_offsets = num_nodes + 1; 282*bd882c8aSJames Wright CeedInt *t_offsets; 283*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 284*bd882c8aSJames Wright CeedInt *t_indices; 285*bd882c8aSJames Wright CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 286*bd882c8aSJames Wright // Count node multiplicity 287*bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 288*bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 289*bd882c8aSJames Wright } 290*bd882c8aSJames Wright // Convert to running sum 291*bd882c8aSJames Wright for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 292*bd882c8aSJames Wright // List all E-vec indices associated with L-vec node 293*bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 294*bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) { 295*bd882c8aSJames Wright const CeedInt lid = elem_size * e + i; 296*bd882c8aSJames Wright const CeedInt gid = indices[lid]; 297*bd882c8aSJames Wright t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 298*bd882c8aSJames Wright } 299*bd882c8aSJames Wright } 300*bd882c8aSJames Wright // Reset running sum 301*bd882c8aSJames Wright for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 302*bd882c8aSJames Wright t_offsets[0] = 0; 303*bd882c8aSJames Wright 304*bd882c8aSJames Wright // Copy data to device 305*bd882c8aSJames Wright Ceed_Sycl *data; 306*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 307*bd882c8aSJames Wright 308*bd882c8aSJames Wright // Order queue 309*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 310*bd882c8aSJames Wright 311*bd882c8aSJames Wright // -- L-vector indices 312*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context)); 313*bd882c8aSJames Wright sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e}); 314*bd882c8aSJames Wright // -- Transpose offsets 315*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context)); 316*bd882c8aSJames Wright sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e}); 317*bd882c8aSJames Wright // -- Transpose indices 318*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context)); 319*bd882c8aSJames Wright sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e}); 320*bd882c8aSJames Wright 321*bd882c8aSJames Wright // Wait for all copies to complete and handle exceptions 322*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices})); 323*bd882c8aSJames Wright 324*bd882c8aSJames Wright // Cleanup 325*bd882c8aSJames Wright CeedCallBackend(CeedFree(&ind_to_offset)); 326*bd882c8aSJames Wright CeedCallBackend(CeedFree(&l_vec_indices)); 327*bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_offsets)); 328*bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_indices)); 329*bd882c8aSJames Wright 330*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 331*bd882c8aSJames Wright } 332*bd882c8aSJames Wright 333*bd882c8aSJames Wright //------------------------------------------------------------------------------ 334*bd882c8aSJames Wright // Create restriction 335*bd882c8aSJames Wright //------------------------------------------------------------------------------ 336*bd882c8aSJames Wright int CeedElemRestrictionCreate_Sycl(CeedMemType m_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r) { 337*bd882c8aSJames Wright Ceed ceed; 338*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 339*bd882c8aSJames Wright Ceed_Sycl *data; 340*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 341*bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 342*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 343*bd882c8aSJames Wright CeedInt num_elem, num_comp, elem_size; 344*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 345*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 346*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 347*bd882c8aSJames Wright CeedInt size = num_elem * elem_size; 348*bd882c8aSJames Wright CeedInt strides[3] = {1, size, elem_size}; 349*bd882c8aSJames Wright CeedInt comp_stride = 1; 350*bd882c8aSJames Wright 351*bd882c8aSJames Wright // Stride data 352*bd882c8aSJames Wright bool is_strided; 353*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 354*bd882c8aSJames Wright if (is_strided) { 355*bd882c8aSJames Wright bool has_backend_strides; 356*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 357*bd882c8aSJames Wright if (!has_backend_strides) { 358*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 359*bd882c8aSJames Wright } 360*bd882c8aSJames Wright } else { 361*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 362*bd882c8aSJames Wright } 363*bd882c8aSJames Wright 364*bd882c8aSJames Wright impl->h_ind = NULL; 365*bd882c8aSJames Wright impl->h_ind_allocated = NULL; 366*bd882c8aSJames Wright impl->d_ind = NULL; 367*bd882c8aSJames Wright impl->d_ind_allocated = NULL; 368*bd882c8aSJames Wright impl->d_t_indices = NULL; 369*bd882c8aSJames Wright impl->d_t_offsets = NULL; 370*bd882c8aSJames Wright impl->num_nodes = size; 371*bd882c8aSJames Wright impl->num_elem = num_elem; 372*bd882c8aSJames Wright impl->num_comp = num_comp; 373*bd882c8aSJames Wright impl->elem_size = elem_size; 374*bd882c8aSJames Wright impl->comp_stride = comp_stride; 375*bd882c8aSJames Wright impl->strides[0] = strides[0]; 376*bd882c8aSJames Wright impl->strides[1] = strides[1]; 377*bd882c8aSJames Wright impl->strides[2] = strides[2]; 378*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 379*bd882c8aSJames Wright CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 380*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 381*bd882c8aSJames Wright 382*bd882c8aSJames Wright // Set up device indices/offset arrays 383*bd882c8aSJames Wright if (m_type == CEED_MEM_HOST) { 384*bd882c8aSJames Wright switch (copy_mode) { 385*bd882c8aSJames Wright case CEED_OWN_POINTER: 386*bd882c8aSJames Wright impl->h_ind_allocated = (CeedInt *)indices; 387*bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 388*bd882c8aSJames Wright break; 389*bd882c8aSJames Wright case CEED_USE_POINTER: 390*bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 391*bd882c8aSJames Wright break; 392*bd882c8aSJames Wright case CEED_COPY_VALUES: 393*bd882c8aSJames Wright if (indices != NULL) { 394*bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 395*bd882c8aSJames Wright memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 396*bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 397*bd882c8aSJames Wright } 398*bd882c8aSJames Wright break; 399*bd882c8aSJames Wright } 400*bd882c8aSJames Wright if (indices != NULL) { 401*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 402*bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 403*bd882c8aSJames Wright // Order queue 404*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 405*bd882c8aSJames Wright // Copy from host to device 406*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 407*bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 408*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 409*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 410*bd882c8aSJames Wright } 411*bd882c8aSJames Wright } else if (m_type == CEED_MEM_DEVICE) { 412*bd882c8aSJames Wright switch (copy_mode) { 413*bd882c8aSJames Wright case CEED_COPY_VALUES: 414*bd882c8aSJames Wright if (indices != NULL) { 415*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 416*bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 417*bd882c8aSJames Wright // Copy from device to device 418*bd882c8aSJames Wright // Order queue 419*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 420*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 421*bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 422*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 423*bd882c8aSJames Wright } 424*bd882c8aSJames Wright break; 425*bd882c8aSJames Wright case CEED_OWN_POINTER: 426*bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 427*bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; 428*bd882c8aSJames Wright break; 429*bd882c8aSJames Wright case CEED_USE_POINTER: 430*bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 431*bd882c8aSJames Wright } 432*bd882c8aSJames Wright if (indices != NULL) { 433*bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 434*bd882c8aSJames Wright // Order queue 435*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 436*bd882c8aSJames Wright // Copy from device to host 437*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e}); 438*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 439*bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 440*bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 441*bd882c8aSJames Wright } 442*bd882c8aSJames Wright } else { 443*bd882c8aSJames Wright // LCOV_EXCL_START 444*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 445*bd882c8aSJames Wright // LCOV_EXCL_STOP 446*bd882c8aSJames Wright } 447*bd882c8aSJames Wright 448*bd882c8aSJames Wright // Register backend functions 449*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Sycl)); 450*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Sycl)); 451*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl)); 452*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Sycl)); 453*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 454*bd882c8aSJames Wright } 455