1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other 2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3bd882c8aSJames Wright // files for details. 4bd882c8aSJames Wright // 5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6bd882c8aSJames Wright // 7bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8bd882c8aSJames Wright 9bd882c8aSJames Wright #include <ceed/backend.h> 10bd882c8aSJames Wright #include <ceed/ceed.h> 11bd882c8aSJames Wright #include <ceed/jit-tools.h> 12bd882c8aSJames Wright 13bd882c8aSJames Wright #include <string> 14bd882c8aSJames Wright #include <sycl/sycl.hpp> 15bd882c8aSJames Wright 16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 18bd882c8aSJames Wright 19bd882c8aSJames Wright class CeedElemRestrSyclStridedNT; 20bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT; 21bd882c8aSJames Wright class CeedElemRestrSyclStridedT; 22bd882c8aSJames Wright class CeedElemRestrSyclOffsetT; 23bd882c8aSJames Wright 24bd882c8aSJames Wright //------------------------------------------------------------------------------ 25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided 26bd882c8aSJames Wright //------------------------------------------------------------------------------ 27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 28bd882c8aSJames Wright CeedScalar *v) { 29bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 30bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 31bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 32bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 33bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 34bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 35bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 36bd882c8aSJames Wright 37bd882c8aSJames Wright // Order queue 38bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 39bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) { 40bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 41bd882c8aSJames Wright const CeedInt elem = node / elem_size; 42bd882c8aSJames Wright 43bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 44bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem]; 45bd882c8aSJames Wright } 46bd882c8aSJames Wright }); 47bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 48bd882c8aSJames Wright } 49bd882c8aSJames Wright 50bd882c8aSJames Wright //------------------------------------------------------------------------------ 51bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided 52bd882c8aSJames Wright //------------------------------------------------------------------------------ 53bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 54bd882c8aSJames Wright CeedScalar *v) { 55bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 56bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 57bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 58bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 59f59ebe5eSJeremy L Thompson const CeedInt *indices = impl->d_offsets; 60bd882c8aSJames Wright 61bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 62bd882c8aSJames Wright 63bd882c8aSJames Wright // Order queue 64bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 65bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) { 66bd882c8aSJames Wright const CeedInt ind = indices[node]; 67bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 68bd882c8aSJames Wright const CeedInt elem = node / elem_size; 69bd882c8aSJames Wright 70bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 71bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride]; 72bd882c8aSJames Wright } 73bd882c8aSJames Wright }); 74bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 75bd882c8aSJames Wright } 76bd882c8aSJames Wright 77bd882c8aSJames Wright //------------------------------------------------------------------------------ 78bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided 79bd882c8aSJames Wright //------------------------------------------------------------------------------ 80bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 81bd882c8aSJames Wright CeedScalar *v) { 82bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 83bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 84bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 85bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 86bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 87bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 88bd882c8aSJames Wright 89bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 90bd882c8aSJames Wright 91bd882c8aSJames Wright // Order queue 92bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 93bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) { 94bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 95bd882c8aSJames Wright const CeedInt elem = node / elem_size; 96bd882c8aSJames Wright 97bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 98bd882c8aSJames Wright v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 99bd882c8aSJames Wright } 100bd882c8aSJames Wright }); 101bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 102bd882c8aSJames Wright } 103bd882c8aSJames Wright 104bd882c8aSJames Wright //------------------------------------------------------------------------------ 105bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided 106bd882c8aSJames Wright //------------------------------------------------------------------------------ 107bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 108bd882c8aSJames Wright CeedScalar *v) { 109bd882c8aSJames Wright const CeedInt num_nodes = impl->num_nodes; 110bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 111bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 112bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 113bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 114bd882c8aSJames Wright const CeedInt *l_vec_indices = impl->d_l_vec_indices; 115bd882c8aSJames Wright const CeedInt *t_offsets = impl->d_t_offsets; 116bd882c8aSJames Wright const CeedInt *t_indices = impl->d_t_indices; 117bd882c8aSJames Wright 118bd882c8aSJames Wright sycl::range<1> kernel_range(num_nodes * num_comp); 119bd882c8aSJames Wright 120bd882c8aSJames Wright // Order queue 121bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 122bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) { 123bd882c8aSJames Wright const CeedInt node = id % num_nodes; 124bd882c8aSJames Wright const CeedInt comp = id / num_nodes; 125bd882c8aSJames Wright const CeedInt ind = l_vec_indices[node]; 126bd882c8aSJames Wright const CeedInt range_1 = t_offsets[node]; 127bd882c8aSJames Wright const CeedInt range_N = t_offsets[node + 1]; 128bd882c8aSJames Wright CeedScalar value = 0.0; 129bd882c8aSJames Wright 130bd882c8aSJames Wright for (CeedInt j = range_1; j < range_N; j++) { 131bd882c8aSJames Wright const CeedInt t_ind = t_indices[j]; 132bd882c8aSJames Wright CeedInt loc_node = t_ind % elem_size; 133bd882c8aSJames Wright CeedInt elem = t_ind / elem_size; 134bd882c8aSJames Wright 135bd882c8aSJames Wright value += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 136bd882c8aSJames Wright } 137bd882c8aSJames Wright v[ind + comp * comp_stride] += value; 138bd882c8aSJames Wright }); 139bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 140bd882c8aSJames Wright } 141bd882c8aSJames Wright 142bd882c8aSJames Wright //------------------------------------------------------------------------------ 143bd882c8aSJames Wright // Apply restriction 144bd882c8aSJames Wright //------------------------------------------------------------------------------ 145dce49693SSebastian Grimberg static int CeedElemRestrictionApply_Sycl(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 146bd882c8aSJames Wright Ceed ceed; 147bd882c8aSJames Wright Ceed_Sycl *data; 148dd64fc84SJeremy L Thompson const CeedScalar *d_u; 149dd64fc84SJeremy L Thompson CeedScalar *d_v; 150dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 151dd64fc84SJeremy L Thompson 152dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 153dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 154bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 155bd882c8aSJames Wright 156bd882c8aSJames Wright // Get vectors 157bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 158bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 159bd882c8aSJames Wright // Sum into for transpose mode, e-vec to l-vec 160bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 161bd882c8aSJames Wright } else { 162bd882c8aSJames Wright // Overwrite for notranspose mode, l-vec to e-vec 163bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 164bd882c8aSJames Wright } 165bd882c8aSJames Wright 166bd882c8aSJames Wright // Restrict 167bd882c8aSJames Wright if (t_mode == CEED_NOTRANSPOSE) { 168bd882c8aSJames Wright // L-vector -> E-vector 169f59ebe5eSJeremy L Thompson if (impl->d_offsets) { 170bd882c8aSJames Wright // -- Offsets provided 171bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 172bd882c8aSJames Wright } else { 173bd882c8aSJames Wright // -- Strided restriction 174bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 175bd882c8aSJames Wright } 176bd882c8aSJames Wright } else { 177bd882c8aSJames Wright // E-vector -> L-vector 178f59ebe5eSJeremy L Thompson if (impl->d_offsets) { 179bd882c8aSJames Wright // -- Offsets provided 180bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 181bd882c8aSJames Wright } else { 182bd882c8aSJames Wright // -- Strided restriction 183bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 184bd882c8aSJames Wright } 185bd882c8aSJames Wright } 186bd882c8aSJames Wright // Wait for queues to be completed. NOTE: This may not be necessary 187bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 188bd882c8aSJames Wright 189bd882c8aSJames Wright if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 190bd882c8aSJames Wright 191bd882c8aSJames Wright // Restore arrays 192bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 193bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 194bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 195bd882c8aSJames Wright } 196bd882c8aSJames Wright 197bd882c8aSJames Wright //------------------------------------------------------------------------------ 198bd882c8aSJames Wright // Get offsets 199bd882c8aSJames Wright //------------------------------------------------------------------------------ 200dce49693SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction rstr, CeedMemType m_type, const CeedInt **offsets) { 201bd882c8aSJames Wright Ceed ceed; 202bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 203dd64fc84SJeremy L Thompson 204dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 205dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 206bd882c8aSJames Wright 207bd882c8aSJames Wright switch (m_type) { 208bd882c8aSJames Wright case CEED_MEM_HOST: 209f59ebe5eSJeremy L Thompson *offsets = impl->h_offsets; 210bd882c8aSJames Wright break; 211bd882c8aSJames Wright case CEED_MEM_DEVICE: 212f59ebe5eSJeremy L Thompson *offsets = impl->d_offsets; 213bd882c8aSJames Wright break; 214bd882c8aSJames Wright } 215bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 216bd882c8aSJames Wright } 217bd882c8aSJames Wright 218bd882c8aSJames Wright //------------------------------------------------------------------------------ 219bd882c8aSJames Wright // Destroy restriction 220bd882c8aSJames Wright //------------------------------------------------------------------------------ 221dce49693SSebastian Grimberg static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction rstr) { 222bd882c8aSJames Wright Ceed ceed; 223bd882c8aSJames Wright Ceed_Sycl *data; 224dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 225dd64fc84SJeremy L Thompson 226dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 227dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 228bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 229bd882c8aSJames Wright 230bd882c8aSJames Wright // Wait for all work to finish before freeing memory 231bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 232bd882c8aSJames Wright 233f59ebe5eSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_offsets_owned)); 234f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, sycl::free(impl->d_offsets_owned, data->sycl_context)); 235bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context)); 236bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context)); 237bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context)); 238bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 239bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 240bd882c8aSJames Wright } 241bd882c8aSJames Wright 242bd882c8aSJames Wright //------------------------------------------------------------------------------ 243bd882c8aSJames Wright // Create transpose offsets and indices 244bd882c8aSJames Wright //------------------------------------------------------------------------------ 245dce49693SSebastian Grimberg static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction rstr, const CeedInt *indices) { 246bd882c8aSJames Wright Ceed ceed; 247dd64fc84SJeremy L Thompson Ceed_Sycl *data; 248dd64fc84SJeremy L Thompson bool *is_node; 249bd882c8aSJames Wright CeedSize l_size; 250dd64fc84SJeremy L Thompson CeedInt num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 251dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 252dd64fc84SJeremy L Thompson 253dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 254dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 255dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 256dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 257dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 258dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 259bd882c8aSJames Wright 260bd882c8aSJames Wright // Count num_nodes 261bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &is_node)); 262bd882c8aSJames Wright const CeedInt size_indices = num_elem * elem_size; 263dd64fc84SJeremy L Thompson 264bd882c8aSJames Wright for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 265bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 266bd882c8aSJames Wright impl->num_nodes = num_nodes; 267bd882c8aSJames Wright 268bd882c8aSJames Wright // L-vector offsets array 269bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 270bd882c8aSJames Wright CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 271dd64fc84SJeremy L Thompson for (CeedInt i = 0, j = 0; i < l_size; i++) { 272bd882c8aSJames Wright if (is_node[i]) { 273bd882c8aSJames Wright l_vec_indices[j] = i; 274bd882c8aSJames Wright ind_to_offset[i] = j++; 275bd882c8aSJames Wright } 276bd882c8aSJames Wright } 277bd882c8aSJames Wright CeedCallBackend(CeedFree(&is_node)); 278bd882c8aSJames Wright 279bd882c8aSJames Wright // Compute transpose offsets and indices 280bd882c8aSJames Wright const CeedInt size_offsets = num_nodes + 1; 281dd64fc84SJeremy L Thompson 282bd882c8aSJames Wright CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 283bd882c8aSJames Wright CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 284bd882c8aSJames Wright // Count node multiplicity 285bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 286bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 287bd882c8aSJames Wright } 288bd882c8aSJames Wright // Convert to running sum 289bd882c8aSJames Wright for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 290bd882c8aSJames Wright // List all E-vec indices associated with L-vec node 291bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 292bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) { 293bd882c8aSJames Wright const CeedInt lid = elem_size * e + i; 294bd882c8aSJames Wright const CeedInt gid = indices[lid]; 295bd882c8aSJames Wright t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 296bd882c8aSJames Wright } 297bd882c8aSJames Wright } 298bd882c8aSJames Wright // Reset running sum 299bd882c8aSJames Wright for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 300bd882c8aSJames Wright t_offsets[0] = 0; 301bd882c8aSJames Wright 302bd882c8aSJames Wright // Copy data to device 303bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 304bd882c8aSJames Wright 305bd882c8aSJames Wright // Order queue 306bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 307bd882c8aSJames Wright 308bd882c8aSJames Wright // -- L-vector indices 309bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context)); 310bd882c8aSJames Wright sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e}); 311bd882c8aSJames Wright // -- Transpose offsets 312bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context)); 313bd882c8aSJames Wright sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e}); 314bd882c8aSJames Wright // -- Transpose indices 315bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context)); 316bd882c8aSJames Wright sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e}); 317bd882c8aSJames Wright 318bd882c8aSJames Wright // Wait for all copies to complete and handle exceptions 319bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices})); 320bd882c8aSJames Wright 321bd882c8aSJames Wright // Cleanup 322bd882c8aSJames Wright CeedCallBackend(CeedFree(&ind_to_offset)); 323bd882c8aSJames Wright CeedCallBackend(CeedFree(&l_vec_indices)); 324bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_offsets)); 325bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_indices)); 326bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 327bd882c8aSJames Wright } 328bd882c8aSJames Wright 329bd882c8aSJames Wright //------------------------------------------------------------------------------ 330bd882c8aSJames Wright // Create restriction 331bd882c8aSJames Wright //------------------------------------------------------------------------------ 332f59ebe5eSJeremy L Thompson int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients, 333dce49693SSebastian Grimberg const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 334bd882c8aSJames Wright Ceed ceed; 335bd882c8aSJames Wright Ceed_Sycl *data; 336dd64fc84SJeremy L Thompson bool is_strided; 337dd64fc84SJeremy L Thompson CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 338dd64fc84SJeremy L Thompson CeedRestrictionType rstr_type; 339bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 340dd64fc84SJeremy L Thompson 341dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 342dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 343dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 344dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 345dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 346dce49693SSebastian Grimberg const CeedInt size = num_elem * elem_size; 347bd882c8aSJames Wright CeedInt strides[3] = {1, size, elem_size}; 348bd882c8aSJames Wright 349dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 35000125730SSebastian Grimberg CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 35100125730SSebastian Grimberg "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 35200125730SSebastian Grimberg 353bd882c8aSJames Wright // Stride data 354dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionIsStrided(rstr, &is_strided)); 355bd882c8aSJames Wright if (is_strided) { 356bd882c8aSJames Wright bool has_backend_strides; 357dd64fc84SJeremy L Thompson 358dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 359bd882c8aSJames Wright if (!has_backend_strides) { 36056c48462SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides)); 361bd882c8aSJames Wright } 362bd882c8aSJames Wright } else { 363dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 364bd882c8aSJames Wright } 365bd882c8aSJames Wright 366dce49693SSebastian Grimberg CeedCallBackend(CeedCalloc(1, &impl)); 367bd882c8aSJames Wright impl->num_nodes = size; 368bd882c8aSJames Wright impl->num_elem = num_elem; 369bd882c8aSJames Wright impl->num_comp = num_comp; 370bd882c8aSJames Wright impl->elem_size = elem_size; 371bd882c8aSJames Wright impl->comp_stride = comp_stride; 372bd882c8aSJames Wright impl->strides[0] = strides[0]; 373bd882c8aSJames Wright impl->strides[1] = strides[1]; 374bd882c8aSJames Wright impl->strides[2] = strides[2]; 375dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 37622eb1385SJeremy L Thompson 37722eb1385SJeremy L Thompson // Set layouts 37822eb1385SJeremy L Thompson { 37922eb1385SJeremy L Thompson bool has_backend_strides; 38022eb1385SJeremy L Thompson CeedInt layout[3] = {1, size, elem_size}; 38122eb1385SJeremy L Thompson 382dce49693SSebastian Grimberg CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 38322eb1385SJeremy L Thompson if (rstr_type == CEED_RESTRICTION_STRIDED) { 38422eb1385SJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 38522eb1385SJeremy L Thompson if (has_backend_strides) { 38622eb1385SJeremy L Thompson CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout)); 38722eb1385SJeremy L Thompson } 38822eb1385SJeremy L Thompson } 38922eb1385SJeremy L Thompson } 390bd882c8aSJames Wright 391bd882c8aSJames Wright // Set up device indices/offset arrays 3929d1bceceSJames Wright switch (mem_type) { 39342b3fd1bSJeremy L Thompson case CEED_MEM_HOST: { 394bd882c8aSJames Wright switch (copy_mode) { 395f59ebe5eSJeremy L Thompson case CEED_COPY_VALUES: 396f59ebe5eSJeremy L Thompson if (offsets != NULL) { 397f59ebe5eSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned)); 398f59ebe5eSJeremy L Thompson memcpy(impl->h_offsets_owned, offsets, elem_size * num_elem * sizeof(CeedInt)); 399f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = NULL; 400f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 401f59ebe5eSJeremy L Thompson } 402f59ebe5eSJeremy L Thompson break; 403bd882c8aSJames Wright case CEED_OWN_POINTER: 404f59ebe5eSJeremy L Thompson impl->h_offsets_owned = (CeedInt *)offsets; 405f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = NULL; 406f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 407bd882c8aSJames Wright break; 408bd882c8aSJames Wright case CEED_USE_POINTER: 409f59ebe5eSJeremy L Thompson impl->h_offsets_owned = NULL; 410f59ebe5eSJeremy L Thompson impl->h_offsets_borrowed = (CeedInt *)offsets; 411f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_borrowed; 412bd882c8aSJames Wright break; 413bd882c8aSJames Wright } 414f59ebe5eSJeremy L Thompson if (offsets != NULL) { 415f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 416bd882c8aSJames Wright // Copy from host to device 417f59ebe5eSJeremy L Thompson // -- Order queue 418f59ebe5eSJeremy L Thompson sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 419f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->h_offsets, impl->d_offsets_owned, size, {e}); 420f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 421bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 422f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 423f59ebe5eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets)); 424bd882c8aSJames Wright } 42542b3fd1bSJeremy L Thompson } break; 4269d1bceceSJames Wright case CEED_MEM_DEVICE: { 427bd882c8aSJames Wright switch (copy_mode) { 428bd882c8aSJames Wright case CEED_COPY_VALUES: 429f59ebe5eSJeremy L Thompson if (offsets != NULL) { 430f59ebe5eSJeremy L Thompson CeedCallSycl(ceed, impl->d_offsets_owned = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 431bd882c8aSJames Wright // Copy from device to device 432f59ebe5eSJeremy L Thompson // -- Order queue 433bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 434f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(offsets, impl->d_offsets_owned, size, {e}); 435f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 436bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 437f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 438bd882c8aSJames Wright } 439bd882c8aSJames Wright break; 440bd882c8aSJames Wright case CEED_OWN_POINTER: 441f59ebe5eSJeremy L Thompson impl->d_offsets_owned = (CeedInt *)offsets; 442f59ebe5eSJeremy L Thompson impl->d_offsets_borrowed = NULL; 443f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_owned; 444bd882c8aSJames Wright break; 445bd882c8aSJames Wright case CEED_USE_POINTER: 446f59ebe5eSJeremy L Thompson impl->d_offsets_owned = NULL; 447f59ebe5eSJeremy L Thompson impl->d_offsets_borrowed = (CeedInt *)offsets; 448f59ebe5eSJeremy L Thompson impl->d_offsets = impl->d_offsets_borrowed; 449bd882c8aSJames Wright } 450f59ebe5eSJeremy L Thompson if (offsets != NULL) { 451f59ebe5eSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_offsets_owned)); 452bd882c8aSJames Wright // Copy from device to host 453f59ebe5eSJeremy L Thompson // -- Order queue 454f59ebe5eSJeremy L Thompson sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 455f59ebe5eSJeremy L Thompson sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_offsets, impl->h_offsets_owned, elem_size * num_elem, {e}); 456f59ebe5eSJeremy L Thompson // -- Wait for copy to finish and handle exceptions 457bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 458f59ebe5eSJeremy L Thompson impl->h_offsets = impl->h_offsets_owned; 459f59ebe5eSJeremy L Thompson CeedCallBackend(CeedElemRestrictionOffset_Sycl(rstr, offsets)); 460bd882c8aSJames Wright } 461bd882c8aSJames Wright } 4629d1bceceSJames Wright } 463bd882c8aSJames Wright 464bd882c8aSJames Wright // Register backend functions 465dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Sycl)); 466dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApply_Sycl)); 467dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApply_Sycl)); 468dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl)); 469dce49693SSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Sycl)); 470bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 471bd882c8aSJames Wright } 472