1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other 2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3bd882c8aSJames Wright // files for details. 4bd882c8aSJames Wright // 5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6bd882c8aSJames Wright // 7bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8bd882c8aSJames Wright 9bd882c8aSJames Wright #include <ceed/backend.h> 10bd882c8aSJames Wright #include <ceed/ceed.h> 11bd882c8aSJames Wright #include <ceed/jit-tools.h> 12bd882c8aSJames Wright 13bd882c8aSJames Wright #include <string> 14bd882c8aSJames Wright #include <sycl/sycl.hpp> 15bd882c8aSJames Wright 16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 18bd882c8aSJames Wright 19bd882c8aSJames Wright class CeedElemRestrSyclStridedNT; 20bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT; 21bd882c8aSJames Wright class CeedElemRestrSyclStridedT; 22bd882c8aSJames Wright class CeedElemRestrSyclOffsetT; 23bd882c8aSJames Wright 24bd882c8aSJames Wright //------------------------------------------------------------------------------ 25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided 26bd882c8aSJames Wright //------------------------------------------------------------------------------ 27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 28bd882c8aSJames Wright CeedScalar *v) { 29bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 30bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 31bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 32bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 33bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 34bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 35bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 36bd882c8aSJames Wright 371f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 381f4b1b45SUmesh Unnikrishnan 391f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 401f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, e, [=](sycl::id<1> node) { 41bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 42bd882c8aSJames Wright const CeedInt elem = node / elem_size; 43bd882c8aSJames Wright 44bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 45bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem]; 46bd882c8aSJames Wright } 47bd882c8aSJames Wright }); 48bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 49bd882c8aSJames Wright } 50bd882c8aSJames Wright 51bd882c8aSJames Wright //------------------------------------------------------------------------------ 52bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided 53bd882c8aSJames Wright //------------------------------------------------------------------------------ 54bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 55bd882c8aSJames Wright CeedScalar *v) { 56bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 57bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 58bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 59bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 60f59ebe5eSJeremy L Thompson const CeedInt *indices = impl->d_offsets; 61bd882c8aSJames Wright 62bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 63bd882c8aSJames Wright 641f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 651f4b1b45SUmesh Unnikrishnan 661f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 671f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, e, [=](sycl::id<1> node) { 68bd882c8aSJames Wright const CeedInt ind = indices[node]; 69bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 70bd882c8aSJames Wright const CeedInt elem = node / elem_size; 71bd882c8aSJames Wright 72bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 73bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride]; 74bd882c8aSJames Wright } 75bd882c8aSJames Wright }); 76bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 77bd882c8aSJames Wright } 78bd882c8aSJames Wright 79bd882c8aSJames Wright //------------------------------------------------------------------------------ 80bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided 81bd882c8aSJames Wright //------------------------------------------------------------------------------ 82bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 83bd882c8aSJames Wright CeedScalar *v) { 84bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 85bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 86bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 87bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 88bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 89bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 90bd882c8aSJames Wright 91bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 92bd882c8aSJames Wright 931f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 941f4b1b45SUmesh Unnikrishnan 951f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 961f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, e, [=](sycl::id<1> node) { 97bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 98bd882c8aSJames Wright const CeedInt elem = node / elem_size; 99bd882c8aSJames Wright 100bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 101bd882c8aSJames Wright v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 102bd882c8aSJames Wright } 103bd882c8aSJames Wright }); 104bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 105bd882c8aSJames Wright } 106bd882c8aSJames Wright 107bd882c8aSJames Wright //------------------------------------------------------------------------------ 108bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided 109bd882c8aSJames Wright //------------------------------------------------------------------------------ 110bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 111bd882c8aSJames Wright CeedScalar *v) { 112bd882c8aSJames Wright const CeedInt num_nodes = impl->num_nodes; 113bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 114bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 115bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 116bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 117bd882c8aSJames Wright const CeedInt *l_vec_indices = impl->d_l_vec_indices; 118bd882c8aSJames Wright const CeedInt *t_offsets = impl->d_t_offsets; 119bd882c8aSJames Wright const CeedInt *t_indices = impl->d_t_indices; 120bd882c8aSJames Wright 121bd882c8aSJames Wright sycl::range<1> kernel_range(num_nodes * num_comp); 122bd882c8aSJames Wright 1231f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 1241f4b1b45SUmesh Unnikrishnan 1251f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 1261f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, e, [=](sycl::id<1> id) { 127bd882c8aSJames Wright const CeedInt node = id % num_nodes; 128bd882c8aSJames Wright const CeedInt comp = id / num_nodes; 129bd882c8aSJames Wright const CeedInt ind = l_vec_indices[node]; 130bd882c8aSJames Wright const CeedInt range_1 = t_offsets[node]; 131bd882c8aSJames Wright const CeedInt range_N = t_offsets[node + 1]; 132bd882c8aSJames Wright CeedScalar value = 0.0; 133bd882c8aSJames Wright 134bd882c8aSJames Wright for (CeedInt j = range_1; j < range_N; j++) { 135bd882c8aSJames Wright const CeedInt t_ind = t_indices[j]; 136bd882c8aSJames Wright CeedInt loc_node = t_ind % elem_size; 137bd882c8aSJames Wright CeedInt elem = t_ind / elem_size; 138bd882c8aSJames Wright 139bd882c8aSJames Wright value += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 140bd882c8aSJames Wright } 141bd882c8aSJames Wright v[ind + comp * comp_stride] += value; 142bd882c8aSJames Wright }); 143bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 144bd882c8aSJames Wright } 145bd882c8aSJames Wright 146bd882c8aSJames Wright //------------------------------------------------------------------------------ 147bd882c8aSJames Wright // Apply restriction 148bd882c8aSJames Wright //------------------------------------------------------------------------------ 149dce49693SSebastian Grimberg static int CeedElemRestrictionApply_Sycl(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 150bd882c8aSJames Wright Ceed ceed; 151bd882c8aSJames Wright Ceed_Sycl *data; 152dd64fc84SJeremy L Thompson const CeedScalar *d_u; 153dd64fc84SJeremy L Thompson CeedScalar *d_v; 154dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 155dd64fc84SJeremy L Thompson 156dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 157dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 158bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 159bd882c8aSJames Wright 160bd882c8aSJames Wright // Get vectors 161bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 162bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 163bd882c8aSJames Wright // Sum into for transpose mode, e-vec to l-vec 164bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 165bd882c8aSJames Wright } else { 166bd882c8aSJames Wright // Overwrite for notranspose mode, l-vec to e-vec 167bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 168bd882c8aSJames Wright } 169bd882c8aSJames Wright 170bd882c8aSJames Wright // Restrict 171bd882c8aSJames Wright if (t_mode == CEED_NOTRANSPOSE) { 172bd882c8aSJames Wright // L-vector -> E-vector 173f59ebe5eSJeremy L Thompson if (impl->d_offsets) { 174bd882c8aSJames Wright // -- Offsets provided 175bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 176bd882c8aSJames Wright } else { 177bd882c8aSJames Wright // -- Strided restriction 178bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 179bd882c8aSJames Wright } 180bd882c8aSJames Wright } else { 181bd882c8aSJames Wright // E-vector -> L-vector 182f59ebe5eSJeremy L Thompson if (impl->d_offsets) { 183bd882c8aSJames Wright // -- Offsets provided 184bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 185bd882c8aSJames Wright } else { 186bd882c8aSJames Wright // -- Strided restriction 187bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 188bd882c8aSJames Wright } 189bd882c8aSJames Wright } 190bd882c8aSJames Wright // Wait for queues to be completed. NOTE: This may not be necessary 191bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 192bd882c8aSJames Wright 193bd882c8aSJames Wright if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 194bd882c8aSJames Wright 195bd882c8aSJames Wright // Restore arrays 196bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 197bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 1989bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 199bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 200bd882c8aSJames Wright } 201bd882c8aSJames Wright 202bd882c8aSJames Wright //------------------------------------------------------------------------------ 203bd882c8aSJames Wright // Get offsets 204bd882c8aSJames Wright //------------------------------------------------------------------------------ 205dce49693SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction rstr, CeedMemType m_type, const CeedInt **offsets) { 206bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 207dd64fc84SJeremy L Thompson 208dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 209bd882c8aSJames Wright 210bd882c8aSJames Wright switch (m_type) { 211bd882c8aSJames Wright case CEED_MEM_HOST: 212f59ebe5eSJeremy L Thompson *offsets = impl->h_offsets; 213bd882c8aSJames Wright break; 214bd882c8aSJames Wright case CEED_MEM_DEVICE: 215f59ebe5eSJeremy L Thompson *offsets = impl->d_offsets; 216bd882c8aSJames Wright break; 217bd882c8aSJames Wright } 218bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 219bd882c8aSJames Wright } 220bd882c8aSJames Wright 221bd882c8aSJames Wright //------------------------------------------------------------------------------ 222bd882c8aSJames Wright // Destroy restriction 223bd882c8aSJames Wright //------------------------------------------------------------------------------ 224dce49693SSebastian Grimberg static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction rstr) { 225bd882c8aSJames Wright Ceed ceed; 226bd882c8aSJames Wright Ceed_Sycl *data; 227dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 228dd64fc84SJeremy L Thompson 229dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 230dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 231bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 232bd882c8aSJames Wright 233bd882c8aSJames Wright // Wait for all work to finish before freeing memory 234bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 235bd882c8aSJames Wright 236f59ebe5eSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_offsets_owned)); 237f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, sycl::free(impl->d_offsets_owned, data->sycl_context)); 238bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context)); 239bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context)); 240bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context)); 241bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 2429bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 243bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 244bd882c8aSJames Wright } 245bd882c8aSJames Wright 246bd882c8aSJames Wright //------------------------------------------------------------------------------ 247bd882c8aSJames Wright // Create transpose offsets and indices 248bd882c8aSJames Wright //------------------------------------------------------------------------------ 249dce49693SSebastian Grimberg static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction rstr, const CeedInt *indices) { 250bd882c8aSJames Wright Ceed ceed; 251dd64fc84SJeremy L Thompson Ceed_Sycl *data; 252dd64fc84SJeremy L Thompson bool *is_node; 253bd882c8aSJames Wright CeedSize l_size; 254dd64fc84SJeremy L Thompson CeedInt num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 255dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 256dd64fc84SJeremy L Thompson 257dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 258dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 259dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 260dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 261dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 262dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 263bd882c8aSJames Wright 264bd882c8aSJames Wright // Count num_nodes 265bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &is_node)); 266bd882c8aSJames Wright const CeedInt size_indices = num_elem * elem_size; 267dd64fc84SJeremy L Thompson 268bd882c8aSJames Wright for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 269bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 270bd882c8aSJames Wright impl->num_nodes = num_nodes; 271bd882c8aSJames Wright 272bd882c8aSJames Wright // L-vector offsets array 273bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 274bd882c8aSJames Wright CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 275dd64fc84SJeremy L Thompson for (CeedInt i = 0, j = 0; i < l_size; i++) { 276bd882c8aSJames Wright if (is_node[i]) { 277bd882c8aSJames Wright l_vec_indices[j] = i; 278bd882c8aSJames Wright ind_to_offset[i] = j++; 279bd882c8aSJames Wright } 280bd882c8aSJames Wright } 281bd882c8aSJames Wright CeedCallBackend(CeedFree(&is_node)); 282bd882c8aSJames Wright 283bd882c8aSJames Wright // Compute transpose offsets and indices 284bd882c8aSJames Wright const CeedInt size_offsets = num_nodes + 1; 285dd64fc84SJeremy L Thompson 286bd882c8aSJames Wright CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 287bd882c8aSJames Wright CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 288bd882c8aSJames Wright // Count node multiplicity 289bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 290bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 291bd882c8aSJames Wright } 292bd882c8aSJames Wright // Convert to running sum 293bd882c8aSJames Wright for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 294bd882c8aSJames Wright // List all E-vec indices associated with L-vec node 295bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 296bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) { 297bd882c8aSJames Wright const CeedInt lid = elem_size * e + i; 298bd882c8aSJames Wright const CeedInt gid = indices[lid]; 299bd882c8aSJames Wright t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 300bd882c8aSJames Wright } 301bd882c8aSJames Wright } 302bd882c8aSJames Wright // Reset running sum 303bd882c8aSJames Wright for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 304bd882c8aSJames Wright t_offsets[0] = 0; 305bd882c8aSJames Wright 306bd882c8aSJames Wright // Copy data to device 307bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 308bd882c8aSJames Wright 3091f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 3101f4b1b45SUmesh Unnikrishnan 3111f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 312bd882c8aSJames Wright 313bd882c8aSJames Wright // -- L-vector indices 314bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context)); 3151f4b1b45SUmesh Unnikrishnan sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, e); 316bd882c8aSJames Wright // -- Transpose offsets 317bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context)); 3181f4b1b45SUmesh Unnikrishnan sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, e); 319bd882c8aSJames Wright // -- Transpose indices 320bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context)); 3211f4b1b45SUmesh Unnikrishnan sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, e); 322bd882c8aSJames Wright 323bd882c8aSJames Wright // Wait for all copies to complete and handle exceptions 324bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices})); 325bd882c8aSJames Wright 326bd882c8aSJames Wright // Cleanup 327bd882c8aSJames Wright CeedCallBackend(CeedFree(&ind_to_offset)); 328bd882c8aSJames Wright CeedCallBackend(CeedFree(&l_vec_indices)); 329bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_offsets)); 330bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_indices)); 3319bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 332bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 333bd882c8aSJames Wright } 334bd882c8aSJames Wright 335bd882c8aSJames Wright //------------------------------------------------------------------------------ 336bd882c8aSJames Wright // Create restriction 337bd882c8aSJames Wright //------------------------------------------------------------------------------ 338f59ebe5eSJeremy L Thompson int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients, 339dce49693SSebastian Grimberg const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 340bd882c8aSJames Wright Ceed ceed; 341bd882c8aSJames Wright Ceed_Sycl *data; 342dd64fc84SJeremy L Thompson bool is_strided; 343dd64fc84SJeremy L Thompson CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 344dd64fc84SJeremy L Thompson CeedRestrictionType rstr_type; 345bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 346dd64fc84SJeremy L Thompson 347dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 348dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 349dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 350dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 351dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 352dce49693SSebastian Grimberg const CeedInt size = num_elem * elem_size; 353bd882c8aSJames Wright CeedInt strides[3] = {1, size, elem_size}; 354bd882c8aSJames Wright 355dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 35600125730SSebastian Grimberg CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 35700125730SSebastian Grimberg "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 35800125730SSebastian Grimberg 359bd882c8aSJames Wright // Stride data 360dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionIsStrided(rstr, &is_strided)); 361bd882c8aSJames Wright if (is_strided) { 362bd882c8aSJames Wright bool has_backend_strides; 363dd64fc84SJeremy L Thompson 364dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 365bd882c8aSJames Wright if (!has_backend_strides) { 36656c48462SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides)); 367bd882c8aSJames Wright } 368bd882c8aSJames Wright } else { 369dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 370bd882c8aSJames Wright } 371bd882c8aSJames Wright 372dce49693SSebastian Grimberg CeedCallBackend(CeedCalloc(1, &impl)); 373bd882c8aSJames Wright impl->num_nodes = size; 374bd882c8aSJames Wright impl->num_elem = num_elem; 375bd882c8aSJames Wright impl->num_comp = num_comp; 376bd882c8aSJames Wright impl->elem_size = elem_size; 377bd882c8aSJames Wright impl->comp_stride = comp_stride; 378bd882c8aSJames Wright impl->strides[0] = strides[0]; 379bd882c8aSJames Wright impl->strides[1] = strides[1]; 380bd882c8aSJames Wright impl->strides[2] = strides[2]; 381dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 38222eb1385SJeremy L Thompson 38322eb1385SJeremy L Thompson // Set layouts 38422eb1385SJeremy L Thompson { 38522eb1385SJeremy L Thompson bool has_backend_strides; 38622eb1385SJeremy L Thompson CeedInt layout[3] = {1, size, elem_size}; 38722eb1385SJeremy L Thompson 388dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 38922eb1385SJeremy L Thompson if (rstr_type == CEED_RESTRICTION_STRIDED) { 39022eb1385SJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 39122eb1385SJeremy L Thompson if (has_backend_strides) { 39222eb1385SJeremy L Thompson CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout)); 39322eb1385SJeremy L Thompson } 39422eb1385SJeremy L Thompson } 39522eb1385SJeremy L Thompson } 396bd882c8aSJames Wright 397bd882c8aSJames Wright // Set up device indices/offset arrays 3989d1bceceSJames Wright switch (mem_type) { 39942b3fd1bSJeremy L Thompson case CEED_MEM_HOST: { 400bd882c8aSJames Wright switch (copy_mode) { 401f59ebe5eSJeremy L Thompson case CEED_COPY_VALUES: 402f59ebe5eSJeremy L Thompson if (offsets != NULL) { 403f59ebe5eSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned)); 404f59ebe5eSJeremy L Thompson memcpy(impl->h_offsets_owned, offsets, elem_size * num_elem * sizeof(CeedInt)); 405f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = NULL; 406f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 407f59ebe5eSJeremy L Thompson } 408f59ebe5eSJeremy L Thompson break; 409bd882c8aSJames Wright case CEED_OWN_POINTER: 410f59ebe5eSJeremy L Thompson impl->h_offsets_owned = (CeedInt *)offsets; 411f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = NULL; 412f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 413bd882c8aSJames Wright break; 414bd882c8aSJames Wright case CEED_USE_POINTER: 415f59ebe5eSJeremy L Thompson impl->h_offsets_owned = NULL; 416f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = (CeedInt *)offsets; 417f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_borrowed; 418bd882c8aSJames Wright break; 419bd882c8aSJames Wright } 420f59ebe5eSJeremy L Thompson if (offsets != NULL) { 421f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 422bd882c8aSJames Wright // Copy from host to device 423f59ebe5eSJeremy L Thompson // -- Order queue 424f59ebe5eSJeremy L Thompson sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 425f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->h_offsets, impl->d_offsets_owned, size, {e}); 426f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 427bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 428f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 429f59ebe5eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets)); 430bd882c8aSJames Wright } 43142b3fd1bSJeremy L Thompson } break; 4329d1bceceSJames Wright case CEED_MEM_DEVICE: { 433bd882c8aSJames Wright switch (copy_mode) { 434bd882c8aSJames Wright case CEED_COPY_VALUES: 435f59ebe5eSJeremy L Thompson if (offsets != NULL) { 436f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 437bd882c8aSJames Wright // Copy from device to device 438f59ebe5eSJeremy L Thompson // -- Order queue 439bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 440f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(offsets, impl->d_offsets_owned, size, {e}); 441f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 442bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 443f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 444bd882c8aSJames Wright } 445bd882c8aSJames Wright break; 446bd882c8aSJames Wright case CEED_OWN_POINTER: 447f59ebe5eSJeremy L Thompson impl->d_offsets_owned = (CeedInt *)offsets; 448f59ebe5eSJeremy L Thompson impl->d_offsets_borrowed = NULL; 449f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 450bd882c8aSJames Wright break; 451bd882c8aSJames Wright case CEED_USE_POINTER: 452f59ebe5eSJeremy L Thompson impl->d_offsets_owned = NULL; 453f59ebe5eSJeremy L Thompson impl->d_offsets_borrowed = (CeedInt *)offsets; 454f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_borrowed; 455bd882c8aSJames Wright } 456f59ebe5eSJeremy L Thompson if (offsets != NULL) { 457f59ebe5eSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned)); 458bd882c8aSJames Wright // Copy from device to host 459f59ebe5eSJeremy L Thompson // -- Order queue 460f59ebe5eSJeremy L Thompson sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 461f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_offsets, impl->h_offsets_owned, elem_size * num_elem, {e}); 462f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 463bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 464f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 465f59ebe5eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets)); 466bd882c8aSJames Wright } 467bd882c8aSJames Wright } 4689d1bceceSJames Wright } 469bd882c8aSJames Wright 470bd882c8aSJames Wright // Register backend functions 471dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Sycl)); 472dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApply_Sycl)); 473dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApply_Sycl)); 474dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl)); 475dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Sycl)); 4769bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 477bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 478bd882c8aSJames Wright } 479