1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other 2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3bd882c8aSJames Wright // files for details. 4bd882c8aSJames Wright // 5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6bd882c8aSJames Wright // 7bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8bd882c8aSJames Wright 9bd882c8aSJames Wright #include <ceed/backend.h> 10bd882c8aSJames Wright #include <ceed/ceed.h> 11bd882c8aSJames Wright #include <ceed/jit-tools.h> 12bd882c8aSJames Wright 13bd882c8aSJames Wright #include <string> 14bd882c8aSJames Wright #include <sycl/sycl.hpp> 15bd882c8aSJames Wright 16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 18bd882c8aSJames Wright 19bd882c8aSJames Wright class CeedElemRestrSyclStridedNT; 20bd882c8aSJames Wright class CeedElemRestrSyclOffsetNT; 21bd882c8aSJames Wright class CeedElemRestrSyclStridedT; 22bd882c8aSJames Wright class CeedElemRestrSyclOffsetT; 23bd882c8aSJames Wright 24bd882c8aSJames Wright //------------------------------------------------------------------------------ 25bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, strided 26bd882c8aSJames Wright //------------------------------------------------------------------------------ 27bd882c8aSJames Wright static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 28bd882c8aSJames Wright CeedScalar *v) { 29bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 30bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 31bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 32bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 33bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 34bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 35bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 36bd882c8aSJames Wright 37bd882c8aSJames Wright // Order queue 38bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 39bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) { 40bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 41bd882c8aSJames Wright const CeedInt elem = node / elem_size; 42bd882c8aSJames Wright 43bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 44bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem]; 45bd882c8aSJames Wright } 46bd882c8aSJames Wright }); 47bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 48bd882c8aSJames Wright } 49bd882c8aSJames Wright 50bd882c8aSJames Wright //------------------------------------------------------------------------------ 51bd882c8aSJames Wright // Restriction Kernel : L-vector -> E-vector, offsets provided 52bd882c8aSJames Wright //------------------------------------------------------------------------------ 53bd882c8aSJames Wright static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 54bd882c8aSJames Wright CeedScalar *v) { 55bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 56bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 57bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 58bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 59bd882c8aSJames Wright 60bd882c8aSJames Wright const CeedInt *indices = impl->d_ind; 61bd882c8aSJames Wright 62bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 63bd882c8aSJames Wright 64bd882c8aSJames Wright // Order queue 65bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 66bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) { 67bd882c8aSJames Wright const CeedInt ind = indices[node]; 68bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 69bd882c8aSJames Wright const CeedInt elem = node / elem_size; 70bd882c8aSJames Wright 71bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 72bd882c8aSJames Wright v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride]; 73bd882c8aSJames Wright } 74bd882c8aSJames Wright }); 75bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 76bd882c8aSJames Wright } 77bd882c8aSJames Wright 78bd882c8aSJames Wright //------------------------------------------------------------------------------ 79bd882c8aSJames Wright // Kernel: E-vector -> L-vector, strided 80bd882c8aSJames Wright //------------------------------------------------------------------------------ 81bd882c8aSJames Wright static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 82bd882c8aSJames Wright CeedScalar *v) { 83bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 84bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 85bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 86bd882c8aSJames Wright const CeedInt stride_nodes = impl->strides[0]; 87bd882c8aSJames Wright const CeedInt stride_comp = impl->strides[1]; 88bd882c8aSJames Wright const CeedInt stride_elem = impl->strides[2]; 89bd882c8aSJames Wright 90bd882c8aSJames Wright sycl::range<1> kernel_range(num_elem * elem_size); 91bd882c8aSJames Wright 92bd882c8aSJames Wright // Order queue 93bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 94bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) { 95bd882c8aSJames Wright const CeedInt loc_node = node % elem_size; 96bd882c8aSJames Wright const CeedInt elem = node / elem_size; 97bd882c8aSJames Wright 98bd882c8aSJames Wright for (CeedInt comp = 0; comp < num_comp; comp++) { 99bd882c8aSJames Wright v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 100bd882c8aSJames Wright } 101bd882c8aSJames Wright }); 102bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 103bd882c8aSJames Wright } 104bd882c8aSJames Wright 105bd882c8aSJames Wright //------------------------------------------------------------------------------ 106bd882c8aSJames Wright // Kernel: E-vector -> L-vector, offsets provided 107bd882c8aSJames Wright //------------------------------------------------------------------------------ 108bd882c8aSJames Wright static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u, 109bd882c8aSJames Wright CeedScalar *v) { 110bd882c8aSJames Wright const CeedInt num_nodes = impl->num_nodes; 111bd882c8aSJames Wright const CeedInt elem_size = impl->elem_size; 112bd882c8aSJames Wright const CeedInt num_elem = impl->num_elem; 113bd882c8aSJames Wright const CeedInt num_comp = impl->num_comp; 114bd882c8aSJames Wright const CeedInt comp_stride = impl->comp_stride; 115bd882c8aSJames Wright 116bd882c8aSJames Wright const CeedInt *l_vec_indices = impl->d_l_vec_indices; 117bd882c8aSJames Wright const CeedInt *t_offsets = impl->d_t_offsets; 118bd882c8aSJames Wright const CeedInt *t_indices = impl->d_t_indices; 119bd882c8aSJames Wright 120bd882c8aSJames Wright sycl::range<1> kernel_range(num_nodes * num_comp); 121bd882c8aSJames Wright 122bd882c8aSJames Wright // Order queue 123bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 124bd882c8aSJames Wright sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) { 125bd882c8aSJames Wright const CeedInt node = id % num_nodes; 126bd882c8aSJames Wright const CeedInt comp = id / num_nodes; 127bd882c8aSJames Wright const CeedInt ind = l_vec_indices[node]; 128bd882c8aSJames Wright const CeedInt range_1 = t_offsets[node]; 129bd882c8aSJames Wright const CeedInt range_N = t_offsets[node + 1]; 130bd882c8aSJames Wright 131bd882c8aSJames Wright CeedScalar value = 0.0; 132bd882c8aSJames Wright 133bd882c8aSJames Wright for (CeedInt j = range_1; j < range_N; j++) { 134bd882c8aSJames Wright const CeedInt t_ind = t_indices[j]; 135bd882c8aSJames Wright CeedInt loc_node = t_ind % elem_size; 136bd882c8aSJames Wright CeedInt elem = t_ind / elem_size; 137bd882c8aSJames Wright 138bd882c8aSJames Wright value += u[loc_node + comp * elem_size * num_elem + elem * elem_size]; 139bd882c8aSJames Wright } 140bd882c8aSJames Wright v[ind + comp * comp_stride] += value; 141bd882c8aSJames Wright }); 142bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 143bd882c8aSJames Wright } 144bd882c8aSJames Wright 145bd882c8aSJames Wright //------------------------------------------------------------------------------ 146bd882c8aSJames Wright // Apply restriction 147bd882c8aSJames Wright //------------------------------------------------------------------------------ 148bd882c8aSJames Wright static int CeedElemRestrictionApply_Sycl(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 149bd882c8aSJames Wright Ceed ceed; 150bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 151bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 152bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 153bd882c8aSJames Wright Ceed_Sycl *data; 154bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 155bd882c8aSJames Wright 156bd882c8aSJames Wright // Get vectors 157bd882c8aSJames Wright const CeedScalar *d_u; 158bd882c8aSJames Wright CeedScalar *d_v; 159bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 160bd882c8aSJames Wright if (t_mode == CEED_TRANSPOSE) { 161bd882c8aSJames Wright // Sum into for transpose mode, e-vec to l-vec 162bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 163bd882c8aSJames Wright } else { 164bd882c8aSJames Wright // Overwrite for notranspose mode, l-vec to e-vec 165bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 166bd882c8aSJames Wright } 167bd882c8aSJames Wright 168bd882c8aSJames Wright // Restrict 169bd882c8aSJames Wright if (t_mode == CEED_NOTRANSPOSE) { 170bd882c8aSJames Wright // L-vector -> E-vector 171bd882c8aSJames Wright if (impl->d_ind) { 172bd882c8aSJames Wright // -- Offsets provided 173bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 174bd882c8aSJames Wright } else { 175bd882c8aSJames Wright // -- Strided restriction 176bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 177bd882c8aSJames Wright } 178bd882c8aSJames Wright } else { 179bd882c8aSJames Wright // E-vector -> L-vector 180bd882c8aSJames Wright if (impl->d_ind) { 181bd882c8aSJames Wright // -- Offsets provided 182bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 183bd882c8aSJames Wright } else { 184bd882c8aSJames Wright // -- Strided restriction 185bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v)); 186bd882c8aSJames Wright } 187bd882c8aSJames Wright } 188bd882c8aSJames Wright // Wait for queues to be completed. NOTE: This may not be necessary 189bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 190bd882c8aSJames Wright 191bd882c8aSJames Wright if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 192bd882c8aSJames Wright 193bd882c8aSJames Wright // Restore arrays 194bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 195bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 196bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 197bd882c8aSJames Wright } 198bd882c8aSJames Wright 199bd882c8aSJames Wright //------------------------------------------------------------------------------ 200bd882c8aSJames Wright // Get offsets 201bd882c8aSJames Wright //------------------------------------------------------------------------------ 202bd882c8aSJames Wright static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction r, CeedMemType m_type, const CeedInt **offsets) { 203bd882c8aSJames Wright Ceed ceed; 204bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 205bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 206bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 207bd882c8aSJames Wright 208bd882c8aSJames Wright switch (m_type) { 209bd882c8aSJames Wright case CEED_MEM_HOST: 210bd882c8aSJames Wright *offsets = impl->h_ind; 211bd882c8aSJames Wright break; 212bd882c8aSJames Wright case CEED_MEM_DEVICE: 213bd882c8aSJames Wright *offsets = impl->d_ind; 214bd882c8aSJames Wright break; 215bd882c8aSJames Wright } 216bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 217bd882c8aSJames Wright } 218bd882c8aSJames Wright 219bd882c8aSJames Wright //------------------------------------------------------------------------------ 220bd882c8aSJames Wright // Destroy restriction 221bd882c8aSJames Wright //------------------------------------------------------------------------------ 222bd882c8aSJames Wright static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction r) { 223bd882c8aSJames Wright Ceed ceed; 224bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 225bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 226bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 227bd882c8aSJames Wright Ceed_Sycl *data; 228bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 229bd882c8aSJames Wright 230bd882c8aSJames Wright // Wait for all work to finish before freeing memory 231bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 232bd882c8aSJames Wright 233bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 234bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context)); 235bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context)); 236bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context)); 237bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context)); 238bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 239bd882c8aSJames Wright 240bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 241bd882c8aSJames Wright } 242bd882c8aSJames Wright 243bd882c8aSJames Wright //------------------------------------------------------------------------------ 244bd882c8aSJames Wright // Create transpose offsets and indices 245bd882c8aSJames Wright //------------------------------------------------------------------------------ 246bd882c8aSJames Wright static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction r, const CeedInt *indices) { 247bd882c8aSJames Wright Ceed ceed; 248bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 249bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 250bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 251bd882c8aSJames Wright CeedSize l_size; 252bd882c8aSJames Wright CeedInt num_elem, elem_size, num_comp; 253bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 254bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 255bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 256bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 257bd882c8aSJames Wright 258bd882c8aSJames Wright // Count num_nodes 259bd882c8aSJames Wright bool *is_node; 260bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &is_node)); 261bd882c8aSJames Wright const CeedInt size_indices = num_elem * elem_size; 262bd882c8aSJames Wright for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 263bd882c8aSJames Wright CeedInt num_nodes = 0; 264bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 265bd882c8aSJames Wright impl->num_nodes = num_nodes; 266bd882c8aSJames Wright 267bd882c8aSJames Wright // L-vector offsets array 268bd882c8aSJames Wright CeedInt *ind_to_offset, *l_vec_indices; 269bd882c8aSJames Wright CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 270bd882c8aSJames Wright CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 271bd882c8aSJames Wright CeedInt j = 0; 272bd882c8aSJames Wright for (CeedInt i = 0; i < l_size; i++) { 273bd882c8aSJames Wright if (is_node[i]) { 274bd882c8aSJames Wright l_vec_indices[j] = i; 275bd882c8aSJames Wright ind_to_offset[i] = j++; 276bd882c8aSJames Wright } 277bd882c8aSJames Wright } 278bd882c8aSJames Wright CeedCallBackend(CeedFree(&is_node)); 279bd882c8aSJames Wright 280bd882c8aSJames Wright // Compute transpose offsets and indices 281bd882c8aSJames Wright const CeedInt size_offsets = num_nodes + 1; 282bd882c8aSJames Wright CeedInt *t_offsets; 283bd882c8aSJames Wright CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 284bd882c8aSJames Wright CeedInt *t_indices; 285bd882c8aSJames Wright CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 286bd882c8aSJames Wright // Count node multiplicity 287bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 288bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 289bd882c8aSJames Wright } 290bd882c8aSJames Wright // Convert to running sum 291bd882c8aSJames Wright for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 292bd882c8aSJames Wright // List all E-vec indices associated with L-vec node 293bd882c8aSJames Wright for (CeedInt e = 0; e < num_elem; ++e) { 294bd882c8aSJames Wright for (CeedInt i = 0; i < elem_size; ++i) { 295bd882c8aSJames Wright const CeedInt lid = elem_size * e + i; 296bd882c8aSJames Wright const CeedInt gid = indices[lid]; 297bd882c8aSJames Wright t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 298bd882c8aSJames Wright } 299bd882c8aSJames Wright } 300bd882c8aSJames Wright // Reset running sum 301bd882c8aSJames Wright for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 302bd882c8aSJames Wright t_offsets[0] = 0; 303bd882c8aSJames Wright 304bd882c8aSJames Wright // Copy data to device 305bd882c8aSJames Wright Ceed_Sycl *data; 306bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 307bd882c8aSJames Wright 308bd882c8aSJames Wright // Order queue 309bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 310bd882c8aSJames Wright 311bd882c8aSJames Wright // -- L-vector indices 312bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context)); 313bd882c8aSJames Wright sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e}); 314bd882c8aSJames Wright // -- Transpose offsets 315bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context)); 316bd882c8aSJames Wright sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e}); 317bd882c8aSJames Wright // -- Transpose indices 318bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context)); 319bd882c8aSJames Wright sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e}); 320bd882c8aSJames Wright 321bd882c8aSJames Wright // Wait for all copies to complete and handle exceptions 322bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices})); 323bd882c8aSJames Wright 324bd882c8aSJames Wright // Cleanup 325bd882c8aSJames Wright CeedCallBackend(CeedFree(&ind_to_offset)); 326bd882c8aSJames Wright CeedCallBackend(CeedFree(&l_vec_indices)); 327bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_offsets)); 328bd882c8aSJames Wright CeedCallBackend(CeedFree(&t_indices)); 329bd882c8aSJames Wright 330bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 331bd882c8aSJames Wright } 332bd882c8aSJames Wright 333bd882c8aSJames Wright //------------------------------------------------------------------------------ 334bd882c8aSJames Wright // Create restriction 335bd882c8aSJames Wright //------------------------------------------------------------------------------ 336*00125730SSebastian Grimberg int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 337*00125730SSebastian Grimberg const CeedInt8 *curl_orients, CeedElemRestriction r) { 338bd882c8aSJames Wright Ceed ceed; 339bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 340bd882c8aSJames Wright Ceed_Sycl *data; 341bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 342bd882c8aSJames Wright CeedElemRestriction_Sycl *impl; 343bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 344bd882c8aSJames Wright CeedInt num_elem, num_comp, elem_size; 345bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 346bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 347bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 348bd882c8aSJames Wright CeedInt size = num_elem * elem_size; 349bd882c8aSJames Wright CeedInt strides[3] = {1, size, elem_size}; 350bd882c8aSJames Wright CeedInt comp_stride = 1; 351bd882c8aSJames Wright 352*00125730SSebastian Grimberg CeedRestrictionType rstr_type; 353*00125730SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 354*00125730SSebastian Grimberg CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 355*00125730SSebastian Grimberg "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 356*00125730SSebastian Grimberg 357bd882c8aSJames Wright // Stride data 358bd882c8aSJames Wright bool is_strided; 359bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 360bd882c8aSJames Wright if (is_strided) { 361bd882c8aSJames Wright bool has_backend_strides; 362bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 363bd882c8aSJames Wright if (!has_backend_strides) { 364bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 365bd882c8aSJames Wright } 366bd882c8aSJames Wright } else { 367bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 368bd882c8aSJames Wright } 369bd882c8aSJames Wright 370bd882c8aSJames Wright impl->h_ind = NULL; 371bd882c8aSJames Wright impl->h_ind_allocated = NULL; 372bd882c8aSJames Wright impl->d_ind = NULL; 373bd882c8aSJames Wright impl->d_ind_allocated = NULL; 374bd882c8aSJames Wright impl->d_t_indices = NULL; 375bd882c8aSJames Wright impl->d_t_offsets = NULL; 376bd882c8aSJames Wright impl->num_nodes = size; 377bd882c8aSJames Wright impl->num_elem = num_elem; 378bd882c8aSJames Wright impl->num_comp = num_comp; 379bd882c8aSJames Wright impl->elem_size = elem_size; 380bd882c8aSJames Wright impl->comp_stride = comp_stride; 381bd882c8aSJames Wright impl->strides[0] = strides[0]; 382bd882c8aSJames Wright impl->strides[1] = strides[1]; 383bd882c8aSJames Wright impl->strides[2] = strides[2]; 384bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 385bd882c8aSJames Wright CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 386bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 387bd882c8aSJames Wright 388bd882c8aSJames Wright // Set up device indices/offset arrays 389bd882c8aSJames Wright if (m_type == CEED_MEM_HOST) { 390bd882c8aSJames Wright switch (copy_mode) { 391bd882c8aSJames Wright case CEED_OWN_POINTER: 392bd882c8aSJames Wright impl->h_ind_allocated = (CeedInt *)indices; 393bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 394bd882c8aSJames Wright break; 395bd882c8aSJames Wright case CEED_USE_POINTER: 396bd882c8aSJames Wright impl->h_ind = (CeedInt *)indices; 397bd882c8aSJames Wright break; 398bd882c8aSJames Wright case CEED_COPY_VALUES: 399bd882c8aSJames Wright if (indices != NULL) { 400bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 401bd882c8aSJames Wright memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 402bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 403bd882c8aSJames Wright } 404bd882c8aSJames Wright break; 405bd882c8aSJames Wright } 406bd882c8aSJames Wright if (indices != NULL) { 407bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 408bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 409bd882c8aSJames Wright // Order queue 410bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 411bd882c8aSJames Wright // Copy from host to device 412bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 413bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 414bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 415bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 416bd882c8aSJames Wright } 417bd882c8aSJames Wright } else if (m_type == CEED_MEM_DEVICE) { 418bd882c8aSJames Wright switch (copy_mode) { 419bd882c8aSJames Wright case CEED_COPY_VALUES: 420bd882c8aSJames Wright if (indices != NULL) { 421bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context)); 422bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; // We own the device memory 423bd882c8aSJames Wright // Copy from device to device 424bd882c8aSJames Wright // Order queue 425bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 426bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e}); 427bd882c8aSJames Wright // Wait for copy to finish and handle exceptions 428bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 429bd882c8aSJames Wright } 430bd882c8aSJames Wright break; 431bd882c8aSJames Wright case CEED_OWN_POINTER: 432bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 433bd882c8aSJames Wright impl->d_ind_allocated = impl->d_ind; 434bd882c8aSJames Wright break; 435bd882c8aSJames Wright case CEED_USE_POINTER: 436bd882c8aSJames Wright impl->d_ind = (CeedInt *)indices; 437bd882c8aSJames Wright } 438bd882c8aSJames Wright if (indices != NULL) { 439bd882c8aSJames Wright CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 440bd882c8aSJames Wright // Order queue 441bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 442bd882c8aSJames Wright // Copy from device to host 443bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e}); 444bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 445bd882c8aSJames Wright impl->h_ind = impl->h_ind_allocated; 446bd882c8aSJames Wright CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices)); 447bd882c8aSJames Wright } 448bd882c8aSJames Wright } else { 449bd882c8aSJames Wright // LCOV_EXCL_START 450bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 451bd882c8aSJames Wright // LCOV_EXCL_STOP 452bd882c8aSJames Wright } 453bd882c8aSJames Wright 454bd882c8aSJames Wright // Register backend functions 455bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Sycl)); 456bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Sycl)); 4577c1dbaffSSebastian Grimberg CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Sycl)); 458bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl)); 459bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Sycl)); 460bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 461bd882c8aSJames Wright } 462