1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other 2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3bd882c8aSJames Wright // files for details. 4bd882c8aSJames Wright // 5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6bd882c8aSJames Wright // 7bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8bd882c8aSJames Wright 9bd882c8aSJames Wright #include <ceed/backend.h> 10bd882c8aSJames Wright #include <ceed/ceed.h> 11bd882c8aSJames Wright #include <ceed/jit-tools.h> 12bd882c8aSJames Wright 13bd882c8aSJames Wright #include <string> 14bd882c8aSJames Wright #include <sycl/sycl.hpp> 15bd882c8aSJames Wright 16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 18bd882c8aSJames Wright 19bd882c8aSJames Wright class CeedElemRestrSyclStridedNT; 20bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT; 21bd882c8aSJames Wright class CeedElemRestrSyclStridedT; 22bd882c8aSJames Wright class CeedElemRestrSyclOffsetT; 23bd882c8aSJames Wright 24bd882c8aSJames Wright //------------------------------------------------------------------------------ 25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided 26bd882c8aSJames Wright //------------------------------------------------------------------------------ 27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 28bd882c8aSJames Wright CeedScalar *v) { 29bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 30bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 31bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 32bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 33bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 34bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 35bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 36bd882c8aSJames Wright 37bd882c8aSJames Wright // Order queue 38bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 39bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) { 40bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 41bd882c8aSJames Wright const CeedInt elem = node / elem_size; 42bd882c8aSJames Wright 43bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 44bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem]; 45bd882c8aSJames Wright } 46bd882c8aSJames Wright }); 47bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 48bd882c8aSJames Wright } 49bd882c8aSJames Wright 50bd882c8aSJames Wright //------------------------------------------------------------------------------ 51bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided 52bd882c8aSJames Wright //------------------------------------------------------------------------------ 53bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 54bd882c8aSJames Wright CeedScalar *v) { 55bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 56bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 57bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 58bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 59bd882c8aSJames Wright const CeedInt *indices = impl->d_ind; 60bd882c8aSJames Wright 61bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 62bd882c8aSJames Wright 63bd882c8aSJames Wright // Order queue 64bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 65bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) { 66bd882c8aSJames Wright const CeedInt ind = indices[node]; 67bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 68bd882c8aSJames Wright const CeedInt elem = node / elem_size; 69bd882c8aSJames Wright 70bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 71bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride]; 72bd882c8aSJames Wright } 73bd882c8aSJames Wright }); 74bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 75bd882c8aSJames Wright } 76bd882c8aSJames Wright 77bd882c8aSJames Wright //------------------------------------------------------------------------------ 78bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided 79bd882c8aSJames Wright //------------------------------------------------------------------------------ 80bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 81bd882c8aSJames Wright CeedScalar *v) { 82bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 83bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 84bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 85bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 86bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 87bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 88bd882c8aSJames Wright 89bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 90bd882c8aSJames Wright 91bd882c8aSJames Wright // Order queue 92bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 93bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) { 94bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 95bd882c8aSJames Wright const CeedInt elem = node / elem_size; 96bd882c8aSJames Wright 97bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 98bd882c8aSJames Wright v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 99bd882c8aSJames Wright } 100bd882c8aSJames Wright }); 101bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 102bd882c8aSJames Wright } 103bd882c8aSJames Wright 104bd882c8aSJames Wright //------------------------------------------------------------------------------ 105bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided 106bd882c8aSJames Wright //------------------------------------------------------------------------------ 107bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 108bd882c8aSJames Wright CeedScalar *v) { 109bd882c8aSJames Wright const CeedInt num_nodes = impl->num_nodes; 110bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 111bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 112bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 113bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 114bd882c8aSJames Wright const CeedInt *l_vec_indices = impl->d_l_vec_indices; 115bd882c8aSJames Wright const CeedInt *t_offsets = impl->d_t_offsets; 116bd882c8aSJames Wright const CeedInt *t_indices = impl->d_t_indices; 117bd882c8aSJames Wright 118bd882c8aSJames Wright sycl::range<1> kernel_range(num_nodes * num_comp); 119bd882c8aSJames Wright 120bd882c8aSJames Wright // Order queue 121bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 122bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) { 123bd882c8aSJames Wright const CeedInt node = id % num_nodes; 124bd882c8aSJames Wright const CeedInt comp = id / num_nodes; 125bd882c8aSJames Wright const CeedInt ind = l_vec_indices[node]; 126bd882c8aSJames Wright const CeedInt range_1 = t_offsets[node]; 127bd882c8aSJames Wright const CeedInt range_N = t_offsets[node + 1]; 128bd882c8aSJames Wright CeedScalar value = 0.0; 129bd882c8aSJames Wright 130bd882c8aSJames Wright for (CeedInt j = range_1; j < range_N; j++) { 131bd882c8aSJames Wright const CeedInt t_ind = t_indices[j]; 132bd882c8aSJames Wright CeedInt loc_node = t_ind % elem_size; 133bd882c8aSJames Wright CeedInt elem = t_ind / elem_size; 134bd882c8aSJames Wright 135bd882c8aSJames Wright value += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 136bd882c8aSJames Wright } 137bd882c8aSJames Wright v[ind + comp * comp_stride] += value; 138bd882c8aSJames Wright }); 139bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 140bd882c8aSJames Wright } 141bd882c8aSJames Wright 142bd882c8aSJames Wright //------------------------------------------------------------------------------ 143bd882c8aSJames Wright // Apply restriction 144bd882c8aSJames Wright //------------------------------------------------------------------------------ 145bd882c8aSJames Wright static int CeedElemRestrictionApply_Sycl(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 146bd882c8aSJames Wright Ceed ceed; 147bd882c8aSJames Wright Ceed_Sycl *data; 148*dd64fc84SJeremy L Thompson const CeedScalar *d_u; 149*dd64fc84SJeremy L Thompson CeedScalar *d_v; 150*dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 151*dd64fc84SJeremy L Thompson 152*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 153*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 154bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 155bd882c8aSJames Wright 156bd882c8aSJames Wright // Get vectors 157bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 158bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 159bd882c8aSJames Wright // Sum into for transpose mode, e-vec to l-vec 160bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 161bd882c8aSJames Wright } else { 162bd882c8aSJames Wright // Overwrite for notranspose mode, l-vec to e-vec 163bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 164bd882c8aSJames Wright } 165bd882c8aSJames Wright 166bd882c8aSJames Wright // Restrict 167bd882c8aSJames Wright if (t_mode == CEED_NOTRANSPOSE) { 168bd882c8aSJames Wright // L-vector -> E-vector 169bd882c8aSJames Wright if (impl->d_ind) { 170bd882c8aSJames Wright // -- Offsets provided 171bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 172bd882c8aSJames Wright } else { 173bd882c8aSJames Wright // -- Strided restriction 174bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 175bd882c8aSJames Wright } 176bd882c8aSJames Wright } else { 177bd882c8aSJames Wright // E-vector -> L-vector 178bd882c8aSJames Wright if (impl->d_ind) { 179bd882c8aSJames Wright // -- Offsets provided 180bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 181bd882c8aSJames Wright } else { 182bd882c8aSJames Wright // -- Strided restriction 183bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 184bd882c8aSJames Wright } 185bd882c8aSJames Wright } 186bd882c8aSJames Wright // Wait for queues to be completed. NOTE: This may not be necessary 187bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 188bd882c8aSJames Wright 189bd882c8aSJames Wright if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 190bd882c8aSJames Wright 191bd882c8aSJames Wright // Restore arrays 192bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 193bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 194bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 195bd882c8aSJames Wright } 196bd882c8aSJames Wright 197bd882c8aSJames Wright //------------------------------------------------------------------------------ 198bd882c8aSJames Wright // Get offsets 199bd882c8aSJames Wright //------------------------------------------------------------------------------ 200bd882c8aSJames Wright static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction r, CeedMemType m_type, const CeedInt **offsets) { 201bd882c8aSJames Wright Ceed ceed; 202bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 203*dd64fc84SJeremy L Thompson 204*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 205bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 206bd882c8aSJames Wright 207bd882c8aSJames Wright switch (m_type) { 208bd882c8aSJames Wright case CEED_MEM_HOST: 209bd882c8aSJames Wright *offsets = impl->h_ind; 210bd882c8aSJames Wright break; 211bd882c8aSJames Wright case CEED_MEM_DEVICE: 212bd882c8aSJames Wright *offsets = impl->d_ind; 213bd882c8aSJames Wright break; 214bd882c8aSJames Wright } 215bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 216bd882c8aSJames Wright } 217bd882c8aSJames Wright 218bd882c8aSJames Wright //------------------------------------------------------------------------------ 219bd882c8aSJames Wright // Destroy restriction 220bd882c8aSJames Wright //------------------------------------------------------------------------------ 221bd882c8aSJames Wright static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction r) { 222bd882c8aSJames Wright Ceed ceed; 223bd882c8aSJames Wright Ceed_Sycl *data; 224*dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 225*dd64fc84SJeremy L Thompson 226*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 227*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 228bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 229bd882c8aSJames Wright 230bd882c8aSJames Wright // Wait for all work to finish before freeing memory 231bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 232bd882c8aSJames Wright 233bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 234bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context)); 235bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context)); 236bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context)); 237bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context)); 238bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 239bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 240bd882c8aSJames Wright } 241bd882c8aSJames Wright 242bd882c8aSJames Wright //------------------------------------------------------------------------------ 243bd882c8aSJames Wright // Create transpose offsets and indices 244bd882c8aSJames Wright //------------------------------------------------------------------------------ 245bd882c8aSJames Wright static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction r, const CeedInt *indices) { 246bd882c8aSJames Wright Ceed ceed; 247*dd64fc84SJeremy L Thompson Ceed_Sycl *data; 248*dd64fc84SJeremy L Thompson bool *is_node; 249bd882c8aSJames Wright CeedSize l_size; 250*dd64fc84SJeremy L Thompson CeedInt num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 251*dd64fc84SJeremy L Thompson CeedElemRestriction_Sycl *impl; 252*dd64fc84SJeremy L Thompson 253*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 254*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 255bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 256bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 257bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 258bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 259bd882c8aSJames Wright 260bd882c8aSJames Wright // Count num_nodes 261bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &is_node)); 262bd882c8aSJames Wright const CeedInt size_indices = num_elem * elem_size; 263*dd64fc84SJeremy L Thompson 264bd882c8aSJames Wright for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 265bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 266bd882c8aSJames Wright impl->num_nodes = num_nodes; 267bd882c8aSJames Wright 268bd882c8aSJames Wright // L-vector offsets array 269bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 270bd882c8aSJames Wright CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 271*dd64fc84SJeremy L Thompson for (CeedInt i = 0, j = 0; i < l_size; i++) { 272bd882c8aSJames Wright if (is_node[i]) { 273bd882c8aSJames Wright l_vec_indices[j] = i; 274bd882c8aSJames Wright ind_to_offset[i] = j++; 275bd882c8aSJames Wright } 276bd882c8aSJames Wright } 277bd882c8aSJames Wright CeedCallBackend(CeedFree(&is_node)); 278bd882c8aSJames Wright 279bd882c8aSJames Wright // Compute transpose offsets and indices 280bd882c8aSJames Wright const CeedInt size_offsets = num_nodes + 1; 281*dd64fc84SJeremy L Thompson 282bd882c8aSJames Wright CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 283bd882c8aSJames Wright CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 284bd882c8aSJames Wright // Count node multiplicity 285bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 286bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 287bd882c8aSJames Wright } 288bd882c8aSJames Wright // Convert to running sum 289bd882c8aSJames Wright for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 290bd882c8aSJames Wright // List all E-vec indices associated with L-vec node 291bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 292bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) { 293bd882c8aSJames Wright const CeedInt lid = elem_size * e + i; 294bd882c8aSJames Wright const CeedInt gid = indices[lid]; 295bd882c8aSJames Wright t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 296bd882c8aSJames Wright } 297bd882c8aSJames Wright } 298bd882c8aSJames Wright // Reset running sum 299bd882c8aSJames Wright for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 300bd882c8aSJames Wright t_offsets[0] = 0; 301bd882c8aSJames Wright 302bd882c8aSJames Wright // Copy data to device 303bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 304bd882c8aSJames Wright 305bd882c8aSJames Wright // Order queue 306bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 307bd882c8aSJames Wright 308bd882c8aSJames Wright // -- L-vector indices 309bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context)); 310bd882c8aSJames Wright sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e}); 311bd882c8aSJames Wright // -- Transpose offsets 312bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context)); 313bd882c8aSJames Wright sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e}); 314bd882c8aSJames Wright // -- Transpose indices 315bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context)); 316bd882c8aSJames Wright sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e}); 317bd882c8aSJames Wright 318bd882c8aSJames Wright // Wait for all copies to complete and handle exceptions 319bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices})); 320bd882c8aSJames Wright 321bd882c8aSJames Wright // Cleanup 322bd882c8aSJames Wright CeedCallBackend(CeedFree(&ind_to_offset)); 323bd882c8aSJames Wright CeedCallBackend(CeedFree(&l_vec_indices)); 324bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_offsets)); 325bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_indices)); 326bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 327bd882c8aSJames Wright } 328bd882c8aSJames Wright 329bd882c8aSJames Wright //------------------------------------------------------------------------------ 330bd882c8aSJames Wright // Create restriction 331bd882c8aSJames Wright //------------------------------------------------------------------------------ 33200125730SSebastian Grimberg int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 33300125730SSebastian Grimberg const CeedInt8 *curl_orients, CeedElemRestriction r) { 334bd882c8aSJames Wright Ceed ceed; 335bd882c8aSJames Wright Ceed_Sycl *data; 336*dd64fc84SJeremy L Thompson bool is_strided; 337*dd64fc84SJeremy L Thompson CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 338*dd64fc84SJeremy L Thompson CeedRestrictionType rstr_type; 339bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 340*dd64fc84SJeremy L Thompson 341*dd64fc84SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 342*dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 343bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 344bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 345bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 346bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 347bd882c8aSJames Wright CeedInt size = num_elem * elem_size; 348bd882c8aSJames Wright CeedInt strides[3] = {1, size, elem_size}; 349bd882c8aSJames Wright 35000125730SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 35100125730SSebastian Grimberg CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 35200125730SSebastian Grimberg "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 35300125730SSebastian Grimberg 354bd882c8aSJames Wright // Stride data 355bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 356bd882c8aSJames Wright if (is_strided) { 357bd882c8aSJames Wright bool has_backend_strides; 358*dd64fc84SJeremy L Thompson 359bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 360bd882c8aSJames Wright if (!has_backend_strides) { 361bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 362bd882c8aSJames Wright } 363bd882c8aSJames Wright } else { 364bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 365bd882c8aSJames Wright } 366bd882c8aSJames Wright 367bd882c8aSJames Wright impl->h_ind = NULL; 368bd882c8aSJames Wright impl->h_ind_allocated = NULL; 369bd882c8aSJames Wright impl->d_ind = NULL; 370bd882c8aSJames Wright impl->d_ind_allocated = NULL; 371bd882c8aSJames Wright impl->d_t_indices = NULL; 372bd882c8aSJames Wright impl->d_t_offsets = NULL; 373bd882c8aSJames Wright impl->num_nodes = size; 374bd882c8aSJames Wright impl->num_elem = num_elem; 375bd882c8aSJames Wright impl->num_comp = num_comp; 376bd882c8aSJames Wright impl->elem_size = elem_size; 377bd882c8aSJames Wright impl->comp_stride = comp_stride; 378bd882c8aSJames Wright impl->strides[0] = strides[0]; 379bd882c8aSJames Wright impl->strides[1] = strides[1]; 380bd882c8aSJames Wright impl->strides[2] = strides[2]; 381bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 382bd882c8aSJames Wright CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 383bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 384bd882c8aSJames Wright 385bd882c8aSJames Wright // Set up device indices/offset arrays 386*dd64fc84SJeremy L Thompson if (mem_type == CEED_MEM_HOST) { 387bd882c8aSJames Wright switch (copy_mode) { 388bd882c8aSJames Wright case CEED_OWN_POINTER: 389bd882c8aSJames Wright impl->h_ind_allocated = (CeedInt *)indices; 390bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 391bd882c8aSJames Wright break; 392bd882c8aSJames Wright case CEED_USE_POINTER: 393bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 394bd882c8aSJames Wright break; 395bd882c8aSJames Wright case CEED_COPY_VALUES: 396bd882c8aSJames Wright if (indices != NULL) { 397bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 398bd882c8aSJames Wright memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 399bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 400bd882c8aSJames Wright } 401bd882c8aSJames Wright break; 402bd882c8aSJames Wright } 403bd882c8aSJames Wright if (indices != NULL) { 404bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 405bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 406bd882c8aSJames Wright // Order queue 407bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 408bd882c8aSJames Wright // Copy from host to device 409bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 410bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 411bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 412bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 413bd882c8aSJames Wright } 414*dd64fc84SJeremy L Thompson } else if (mem_type == CEED_MEM_DEVICE) { 415bd882c8aSJames Wright switch (copy_mode) { 416bd882c8aSJames Wright case CEED_COPY_VALUES: 417bd882c8aSJames Wright if (indices != NULL) { 418bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 419bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 420bd882c8aSJames Wright // Copy from device to device 421bd882c8aSJames Wright // Order queue 422bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 423bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 424bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 425bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 426bd882c8aSJames Wright } 427bd882c8aSJames Wright break; 428bd882c8aSJames Wright case CEED_OWN_POINTER: 429bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 430bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; 431bd882c8aSJames Wright break; 432bd882c8aSJames Wright case CEED_USE_POINTER: 433bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 434bd882c8aSJames Wright } 435bd882c8aSJames Wright if (indices != NULL) { 436bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 437bd882c8aSJames Wright // Order queue 438bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 439bd882c8aSJames Wright // Copy from device to host 440bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e}); 441bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 442bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 443bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 444bd882c8aSJames Wright } 445bd882c8aSJames Wright } else { 446bd882c8aSJames Wright // LCOV_EXCL_START 447bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 448bd882c8aSJames Wright // LCOV_EXCL_STOP 449bd882c8aSJames Wright } 450bd882c8aSJames Wright 451bd882c8aSJames Wright // Register backend functions 452bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Sycl)); 453bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Sycl)); 4547c1dbaffSSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Sycl)); 455bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl)); 456bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Sycl)); 457bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 458bd882c8aSJames Wright } 459