1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*bd882c8aSJames Wright // 4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*bd882c8aSJames Wright // 6*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7*bd882c8aSJames Wright 8*bd882c8aSJames Wright #include <ceed/backend.h> 9*bd882c8aSJames Wright #include <ceed/ceed.h> 10*bd882c8aSJames Wright 11*bd882c8aSJames Wright #include <cmath> 12*bd882c8aSJames Wright #include <string> 13*bd882c8aSJames Wright #include <sycl/sycl.hpp> 14*bd882c8aSJames Wright 15*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 16*bd882c8aSJames Wright 17*bd882c8aSJames Wright //------------------------------------------------------------------------------ 18*bd882c8aSJames Wright // Check if host/device sync is needed 19*bd882c8aSJames Wright //------------------------------------------------------------------------------ 20*bd882c8aSJames Wright static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21*bd882c8aSJames Wright CeedVector_Sycl *impl; 22*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 23*bd882c8aSJames Wright 24*bd882c8aSJames Wright bool has_valid_array = false; 25*bd882c8aSJames Wright CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26*bd882c8aSJames Wright switch (mem_type) { 27*bd882c8aSJames Wright case CEED_MEM_HOST: 28*bd882c8aSJames Wright *need_sync = has_valid_array && !impl->h_array; 29*bd882c8aSJames Wright break; 30*bd882c8aSJames Wright case CEED_MEM_DEVICE: 31*bd882c8aSJames Wright *need_sync = has_valid_array && !impl->d_array; 32*bd882c8aSJames Wright break; 33*bd882c8aSJames Wright } 34*bd882c8aSJames Wright 35*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 36*bd882c8aSJames Wright } 37*bd882c8aSJames Wright 38*bd882c8aSJames Wright //------------------------------------------------------------------------------ 39*bd882c8aSJames Wright // Sync host to device 40*bd882c8aSJames Wright //------------------------------------------------------------------------------ 41*bd882c8aSJames Wright static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 42*bd882c8aSJames Wright Ceed ceed; 43*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 44*bd882c8aSJames Wright CeedVector_Sycl *impl; 45*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 46*bd882c8aSJames Wright Ceed_Sycl *data; 47*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 48*bd882c8aSJames Wright 49*bd882c8aSJames Wright if (!impl->h_array) { 50*bd882c8aSJames Wright // LCOV_EXCL_START 51*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 52*bd882c8aSJames Wright // LCOV_EXCL_STOP 53*bd882c8aSJames Wright } 54*bd882c8aSJames Wright 55*bd882c8aSJames Wright CeedSize length; 56*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 57*bd882c8aSJames Wright 58*bd882c8aSJames Wright if (impl->d_array_borrowed) { 59*bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 60*bd882c8aSJames Wright } else if (impl->d_array_owned) { 61*bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 62*bd882c8aSJames Wright } else { 63*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 64*bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 65*bd882c8aSJames Wright } 66*bd882c8aSJames Wright 67*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 68*bd882c8aSJames Wright // Copy from host to device 69*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 70*bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 71*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 72*bd882c8aSJames Wright 73*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 74*bd882c8aSJames Wright } 75*bd882c8aSJames Wright 76*bd882c8aSJames Wright //------------------------------------------------------------------------------ 77*bd882c8aSJames Wright // Sync device to host 78*bd882c8aSJames Wright //------------------------------------------------------------------------------ 79*bd882c8aSJames Wright static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 80*bd882c8aSJames Wright Ceed ceed; 81*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 82*bd882c8aSJames Wright CeedVector_Sycl *impl; 83*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 84*bd882c8aSJames Wright Ceed_Sycl *data; 85*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 86*bd882c8aSJames Wright 87*bd882c8aSJames Wright CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 88*bd882c8aSJames Wright 89*bd882c8aSJames Wright CeedSize length; 90*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 91*bd882c8aSJames Wright 92*bd882c8aSJames Wright if (impl->h_array_borrowed) { 93*bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 94*bd882c8aSJames Wright } else if (impl->h_array_owned) { 95*bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 96*bd882c8aSJames Wright } else { 97*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 98*bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 99*bd882c8aSJames Wright } 100*bd882c8aSJames Wright 101*bd882c8aSJames Wright // Order queue 102*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 103*bd882c8aSJames Wright // Copy from device to host 104*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 105*bd882c8aSJames Wright 106*bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 107*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 108*bd882c8aSJames Wright 109*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 110*bd882c8aSJames Wright } 111*bd882c8aSJames Wright 112*bd882c8aSJames Wright //------------------------------------------------------------------------------ 113*bd882c8aSJames Wright // Sync arrays 114*bd882c8aSJames Wright //------------------------------------------------------------------------------ 115*bd882c8aSJames Wright static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 116*bd882c8aSJames Wright // Check whether device/host sync is needed 117*bd882c8aSJames Wright bool need_sync = false; 118*bd882c8aSJames Wright CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 119*bd882c8aSJames Wright if (!need_sync) return CEED_ERROR_SUCCESS; 120*bd882c8aSJames Wright 121*bd882c8aSJames Wright switch (mem_type) { 122*bd882c8aSJames Wright case CEED_MEM_HOST: 123*bd882c8aSJames Wright return CeedVectorSyncD2H_Sycl(vec); 124*bd882c8aSJames Wright case CEED_MEM_DEVICE: 125*bd882c8aSJames Wright return CeedVectorSyncH2D_Sycl(vec); 126*bd882c8aSJames Wright } 127*bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 128*bd882c8aSJames Wright } 129*bd882c8aSJames Wright 130*bd882c8aSJames Wright //------------------------------------------------------------------------------ 131*bd882c8aSJames Wright // Set all pointers as invalid 132*bd882c8aSJames Wright //------------------------------------------------------------------------------ 133*bd882c8aSJames Wright static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 134*bd882c8aSJames Wright CeedVector_Sycl *impl; 135*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 136*bd882c8aSJames Wright 137*bd882c8aSJames Wright impl->h_array = NULL; 138*bd882c8aSJames Wright impl->d_array = NULL; 139*bd882c8aSJames Wright 140*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 141*bd882c8aSJames Wright } 142*bd882c8aSJames Wright 143*bd882c8aSJames Wright //------------------------------------------------------------------------------ 144*bd882c8aSJames Wright // Check if CeedVector has any valid pointer 145*bd882c8aSJames Wright //------------------------------------------------------------------------------ 146*bd882c8aSJames Wright static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 147*bd882c8aSJames Wright CeedVector_Sycl *impl; 148*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 149*bd882c8aSJames Wright 150*bd882c8aSJames Wright *has_valid_array = !!impl->h_array || !!impl->d_array; 151*bd882c8aSJames Wright 152*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 153*bd882c8aSJames Wright } 154*bd882c8aSJames Wright 155*bd882c8aSJames Wright //------------------------------------------------------------------------------ 156*bd882c8aSJames Wright // Check if has array of given type 157*bd882c8aSJames Wright //------------------------------------------------------------------------------ 158*bd882c8aSJames Wright static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 159*bd882c8aSJames Wright CeedVector_Sycl *impl; 160*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 161*bd882c8aSJames Wright 162*bd882c8aSJames Wright switch (mem_type) { 163*bd882c8aSJames Wright case CEED_MEM_HOST: 164*bd882c8aSJames Wright *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 165*bd882c8aSJames Wright break; 166*bd882c8aSJames Wright case CEED_MEM_DEVICE: 167*bd882c8aSJames Wright *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 168*bd882c8aSJames Wright break; 169*bd882c8aSJames Wright } 170*bd882c8aSJames Wright 171*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 172*bd882c8aSJames Wright } 173*bd882c8aSJames Wright 174*bd882c8aSJames Wright //------------------------------------------------------------------------------ 175*bd882c8aSJames Wright // Check if has borrowed array of given type 176*bd882c8aSJames Wright //------------------------------------------------------------------------------ 177*bd882c8aSJames Wright static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 178*bd882c8aSJames Wright CeedVector_Sycl *impl; 179*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 180*bd882c8aSJames Wright 181*bd882c8aSJames Wright switch (mem_type) { 182*bd882c8aSJames Wright case CEED_MEM_HOST: 183*bd882c8aSJames Wright *has_borrowed_array_of_type = !!impl->h_array_borrowed; 184*bd882c8aSJames Wright break; 185*bd882c8aSJames Wright case CEED_MEM_DEVICE: 186*bd882c8aSJames Wright *has_borrowed_array_of_type = !!impl->d_array_borrowed; 187*bd882c8aSJames Wright break; 188*bd882c8aSJames Wright } 189*bd882c8aSJames Wright 190*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 191*bd882c8aSJames Wright } 192*bd882c8aSJames Wright 193*bd882c8aSJames Wright //------------------------------------------------------------------------------ 194*bd882c8aSJames Wright // Set array from host 195*bd882c8aSJames Wright //------------------------------------------------------------------------------ 196*bd882c8aSJames Wright static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 197*bd882c8aSJames Wright CeedVector_Sycl *impl; 198*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 199*bd882c8aSJames Wright 200*bd882c8aSJames Wright switch (copy_mode) { 201*bd882c8aSJames Wright case CEED_COPY_VALUES: { 202*bd882c8aSJames Wright if (!impl->h_array_owned) { 203*bd882c8aSJames Wright CeedSize length; 204*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 205*bd882c8aSJames Wright CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 206*bd882c8aSJames Wright } 207*bd882c8aSJames Wright impl->h_array_borrowed = NULL; 208*bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 209*bd882c8aSJames Wright if (array) { 210*bd882c8aSJames Wright CeedSize length; 211*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 212*bd882c8aSJames Wright size_t bytes = length * sizeof(CeedScalar); 213*bd882c8aSJames Wright memcpy(impl->h_array, array, bytes); 214*bd882c8aSJames Wright } 215*bd882c8aSJames Wright } break; 216*bd882c8aSJames Wright case CEED_OWN_POINTER: 217*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 218*bd882c8aSJames Wright impl->h_array_owned = array; 219*bd882c8aSJames Wright impl->h_array_borrowed = NULL; 220*bd882c8aSJames Wright impl->h_array = array; 221*bd882c8aSJames Wright break; 222*bd882c8aSJames Wright case CEED_USE_POINTER: 223*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 224*bd882c8aSJames Wright impl->h_array_borrowed = array; 225*bd882c8aSJames Wright impl->h_array = array; 226*bd882c8aSJames Wright break; 227*bd882c8aSJames Wright } 228*bd882c8aSJames Wright 229*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 230*bd882c8aSJames Wright } 231*bd882c8aSJames Wright 232*bd882c8aSJames Wright //------------------------------------------------------------------------------ 233*bd882c8aSJames Wright // Set array from device 234*bd882c8aSJames Wright //------------------------------------------------------------------------------ 235*bd882c8aSJames Wright static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 236*bd882c8aSJames Wright Ceed ceed; 237*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 238*bd882c8aSJames Wright CeedVector_Sycl *impl; 239*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 240*bd882c8aSJames Wright Ceed_Sycl *data; 241*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 242*bd882c8aSJames Wright 243*bd882c8aSJames Wright // Order queue 244*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 245*bd882c8aSJames Wright 246*bd882c8aSJames Wright switch (copy_mode) { 247*bd882c8aSJames Wright case CEED_COPY_VALUES: { 248*bd882c8aSJames Wright CeedSize length; 249*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 250*bd882c8aSJames Wright if (!impl->d_array_owned) { 251*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 252*bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 253*bd882c8aSJames Wright } 254*bd882c8aSJames Wright if (array) { 255*bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 256*bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 257*bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 258*bd882c8aSJames Wright } 259*bd882c8aSJames Wright } break; 260*bd882c8aSJames Wright case CEED_OWN_POINTER: 261*bd882c8aSJames Wright if (impl->d_array_owned) { 262*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 263*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 264*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 265*bd882c8aSJames Wright } 266*bd882c8aSJames Wright impl->d_array_owned = array; 267*bd882c8aSJames Wright impl->d_array_borrowed = NULL; 268*bd882c8aSJames Wright impl->d_array = array; 269*bd882c8aSJames Wright break; 270*bd882c8aSJames Wright case CEED_USE_POINTER: 271*bd882c8aSJames Wright if (impl->d_array_owned) { 272*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 273*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 274*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 275*bd882c8aSJames Wright } 276*bd882c8aSJames Wright impl->d_array_owned = NULL; 277*bd882c8aSJames Wright impl->d_array_borrowed = array; 278*bd882c8aSJames Wright impl->d_array = array; 279*bd882c8aSJames Wright break; 280*bd882c8aSJames Wright } 281*bd882c8aSJames Wright 282*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 283*bd882c8aSJames Wright } 284*bd882c8aSJames Wright 285*bd882c8aSJames Wright //------------------------------------------------------------------------------ 286*bd882c8aSJames Wright // Set the array used by a vector, 287*bd882c8aSJames Wright // freeing any previously allocated array if applicable 288*bd882c8aSJames Wright //------------------------------------------------------------------------------ 289*bd882c8aSJames Wright static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 290*bd882c8aSJames Wright Ceed ceed; 291*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 292*bd882c8aSJames Wright CeedVector_Sycl *impl; 293*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 294*bd882c8aSJames Wright 295*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 296*bd882c8aSJames Wright switch (mem_type) { 297*bd882c8aSJames Wright case CEED_MEM_HOST: 298*bd882c8aSJames Wright return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 299*bd882c8aSJames Wright case CEED_MEM_DEVICE: 300*bd882c8aSJames Wright return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 301*bd882c8aSJames Wright } 302*bd882c8aSJames Wright 303*bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 304*bd882c8aSJames Wright } 305*bd882c8aSJames Wright 306*bd882c8aSJames Wright //------------------------------------------------------------------------------ 307*bd882c8aSJames Wright // Set host array to value 308*bd882c8aSJames Wright //------------------------------------------------------------------------------ 309*bd882c8aSJames Wright static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedInt length, CeedScalar val) { 310*bd882c8aSJames Wright for (int i = 0; i < length; i++) h_array[i] = val; 311*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 312*bd882c8aSJames Wright } 313*bd882c8aSJames Wright 314*bd882c8aSJames Wright //------------------------------------------------------------------------------ 315*bd882c8aSJames Wright // Set device array to value 316*bd882c8aSJames Wright //------------------------------------------------------------------------------ 317*bd882c8aSJames Wright static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length, CeedScalar val) { 318*bd882c8aSJames Wright // Order queue 319*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 320*bd882c8aSJames Wright sycl_queue.fill(d_array, val, length, {e}); 321*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 322*bd882c8aSJames Wright } 323*bd882c8aSJames Wright 324*bd882c8aSJames Wright //------------------------------------------------------------------------------ 325*bd882c8aSJames Wright // Set a vector to a value, 326*bd882c8aSJames Wright //------------------------------------------------------------------------------ 327*bd882c8aSJames Wright static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 328*bd882c8aSJames Wright Ceed ceed; 329*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 330*bd882c8aSJames Wright CeedVector_Sycl *impl; 331*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 332*bd882c8aSJames Wright CeedSize length; 333*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 334*bd882c8aSJames Wright Ceed_Sycl *data; 335*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 336*bd882c8aSJames Wright 337*bd882c8aSJames Wright // Set value for synced device/host array 338*bd882c8aSJames Wright if (!impl->d_array && !impl->h_array) { 339*bd882c8aSJames Wright if (impl->d_array_borrowed) { 340*bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 341*bd882c8aSJames Wright } else if (impl->h_array_borrowed) { 342*bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 343*bd882c8aSJames Wright } else if (impl->d_array_owned) { 344*bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 345*bd882c8aSJames Wright } else if (impl->h_array_owned) { 346*bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 347*bd882c8aSJames Wright } else { 348*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 349*bd882c8aSJames Wright } 350*bd882c8aSJames Wright } 351*bd882c8aSJames Wright if (impl->d_array) { 352*bd882c8aSJames Wright CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 353*bd882c8aSJames Wright impl->h_array = NULL; 354*bd882c8aSJames Wright } 355*bd882c8aSJames Wright if (impl->h_array) { 356*bd882c8aSJames Wright CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 357*bd882c8aSJames Wright impl->d_array = NULL; 358*bd882c8aSJames Wright } 359*bd882c8aSJames Wright 360*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 361*bd882c8aSJames Wright } 362*bd882c8aSJames Wright 363*bd882c8aSJames Wright //------------------------------------------------------------------------------ 364*bd882c8aSJames Wright // Vector Take Array 365*bd882c8aSJames Wright //------------------------------------------------------------------------------ 366*bd882c8aSJames Wright static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 367*bd882c8aSJames Wright Ceed ceed; 368*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 369*bd882c8aSJames Wright CeedVector_Sycl *impl; 370*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 371*bd882c8aSJames Wright 372*bd882c8aSJames Wright Ceed_Sycl *data; 373*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 374*bd882c8aSJames Wright 375*bd882c8aSJames Wright // Order queue 376*bd882c8aSJames Wright data->sycl_queue.ext_oneapi_submit_barrier(); 377*bd882c8aSJames Wright 378*bd882c8aSJames Wright // Sync array to requested mem_type 379*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 380*bd882c8aSJames Wright 381*bd882c8aSJames Wright // Update pointer 382*bd882c8aSJames Wright switch (mem_type) { 383*bd882c8aSJames Wright case CEED_MEM_HOST: 384*bd882c8aSJames Wright (*array) = impl->h_array_borrowed; 385*bd882c8aSJames Wright impl->h_array_borrowed = NULL; 386*bd882c8aSJames Wright impl->h_array = NULL; 387*bd882c8aSJames Wright break; 388*bd882c8aSJames Wright case CEED_MEM_DEVICE: 389*bd882c8aSJames Wright (*array) = impl->d_array_borrowed; 390*bd882c8aSJames Wright impl->d_array_borrowed = NULL; 391*bd882c8aSJames Wright impl->d_array = NULL; 392*bd882c8aSJames Wright break; 393*bd882c8aSJames Wright } 394*bd882c8aSJames Wright 395*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 396*bd882c8aSJames Wright } 397*bd882c8aSJames Wright 398*bd882c8aSJames Wright //------------------------------------------------------------------------------ 399*bd882c8aSJames Wright // Core logic for array syncronization for GetArray. 400*bd882c8aSJames Wright // If a different memory type is most up to date, this will perform a copy 401*bd882c8aSJames Wright //------------------------------------------------------------------------------ 402*bd882c8aSJames Wright static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 403*bd882c8aSJames Wright Ceed ceed; 404*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 405*bd882c8aSJames Wright CeedVector_Sycl *impl; 406*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 407*bd882c8aSJames Wright 408*bd882c8aSJames Wright // Sync array to requested mem_type 409*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 410*bd882c8aSJames Wright 411*bd882c8aSJames Wright // Update pointer 412*bd882c8aSJames Wright switch (mem_type) { 413*bd882c8aSJames Wright case CEED_MEM_HOST: 414*bd882c8aSJames Wright *array = impl->h_array; 415*bd882c8aSJames Wright break; 416*bd882c8aSJames Wright case CEED_MEM_DEVICE: 417*bd882c8aSJames Wright *array = impl->d_array; 418*bd882c8aSJames Wright break; 419*bd882c8aSJames Wright } 420*bd882c8aSJames Wright 421*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 422*bd882c8aSJames Wright } 423*bd882c8aSJames Wright //------------------------------------------------------------------------------ 424*bd882c8aSJames Wright // Get read-only access to a vector via the specified mem_type 425*bd882c8aSJames Wright //------------------------------------------------------------------------------ 426*bd882c8aSJames Wright static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 427*bd882c8aSJames Wright return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 428*bd882c8aSJames Wright } 429*bd882c8aSJames Wright 430*bd882c8aSJames Wright //------------------------------------------------------------------------------ 431*bd882c8aSJames Wright // Get read/write access to a vector via the specified mem_type 432*bd882c8aSJames Wright //------------------------------------------------------------------------------ 433*bd882c8aSJames Wright static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 434*bd882c8aSJames Wright CeedVector_Sycl *impl; 435*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 436*bd882c8aSJames Wright 437*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 438*bd882c8aSJames Wright 439*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 440*bd882c8aSJames Wright switch (mem_type) { 441*bd882c8aSJames Wright case CEED_MEM_HOST: 442*bd882c8aSJames Wright impl->h_array = *array; 443*bd882c8aSJames Wright break; 444*bd882c8aSJames Wright case CEED_MEM_DEVICE: 445*bd882c8aSJames Wright impl->d_array = *array; 446*bd882c8aSJames Wright break; 447*bd882c8aSJames Wright } 448*bd882c8aSJames Wright 449*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 450*bd882c8aSJames Wright } 451*bd882c8aSJames Wright 452*bd882c8aSJames Wright //------------------------------------------------------------------------------ 453*bd882c8aSJames Wright // Get write access to a vector via the specified mem_type 454*bd882c8aSJames Wright //------------------------------------------------------------------------------ 455*bd882c8aSJames Wright static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 456*bd882c8aSJames Wright CeedVector_Sycl *impl; 457*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 458*bd882c8aSJames Wright 459*bd882c8aSJames Wright bool has_array_of_type = true; 460*bd882c8aSJames Wright CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 461*bd882c8aSJames Wright if (!has_array_of_type) { 462*bd882c8aSJames Wright // Allocate if array is not yet allocated 463*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 464*bd882c8aSJames Wright } else { 465*bd882c8aSJames Wright // Select dirty array 466*bd882c8aSJames Wright switch (mem_type) { 467*bd882c8aSJames Wright case CEED_MEM_HOST: 468*bd882c8aSJames Wright if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 469*bd882c8aSJames Wright else impl->h_array = impl->h_array_owned; 470*bd882c8aSJames Wright break; 471*bd882c8aSJames Wright case CEED_MEM_DEVICE: 472*bd882c8aSJames Wright if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 473*bd882c8aSJames Wright else impl->d_array = impl->d_array_owned; 474*bd882c8aSJames Wright } 475*bd882c8aSJames Wright } 476*bd882c8aSJames Wright 477*bd882c8aSJames Wright return CeedVectorGetArray_Sycl(vec, mem_type, array); 478*bd882c8aSJames Wright } 479*bd882c8aSJames Wright 480*bd882c8aSJames Wright //------------------------------------------------------------------------------ 481*bd882c8aSJames Wright // Get the norm of a CeedVector 482*bd882c8aSJames Wright //------------------------------------------------------------------------------ 483*bd882c8aSJames Wright static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 484*bd882c8aSJames Wright Ceed ceed; 485*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 486*bd882c8aSJames Wright CeedVector_Sycl *impl; 487*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 488*bd882c8aSJames Wright CeedSize length; 489*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 490*bd882c8aSJames Wright Ceed_Sycl *data; 491*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 492*bd882c8aSJames Wright 493*bd882c8aSJames Wright // Compute norm 494*bd882c8aSJames Wright const CeedScalar *d_array; 495*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 496*bd882c8aSJames Wright 497*bd882c8aSJames Wright switch (type) { 498*bd882c8aSJames Wright case CEED_NORM_1: { 499*bd882c8aSJames Wright // Order queue 500*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 501*bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 502*bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 503*bd882c8aSJames Wright } break; 504*bd882c8aSJames Wright case CEED_NORM_2: { 505*bd882c8aSJames Wright // Order queue 506*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 507*bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 508*bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 509*bd882c8aSJames Wright } break; 510*bd882c8aSJames Wright case CEED_NORM_MAX: { 511*bd882c8aSJames Wright // Order queue 512*bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 513*bd882c8aSJames Wright auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 514*bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 515*bd882c8aSJames Wright } break; 516*bd882c8aSJames Wright } 517*bd882c8aSJames Wright // L2 norm - square root over reduced value 518*bd882c8aSJames Wright if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 519*bd882c8aSJames Wright else *norm = *impl->reduction_norm; 520*bd882c8aSJames Wright 521*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 522*bd882c8aSJames Wright 523*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 524*bd882c8aSJames Wright } 525*bd882c8aSJames Wright 526*bd882c8aSJames Wright //------------------------------------------------------------------------------ 527*bd882c8aSJames Wright // Take reciprocal of a vector on host 528*bd882c8aSJames Wright //------------------------------------------------------------------------------ 529*bd882c8aSJames Wright static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedInt length) { 530*bd882c8aSJames Wright for (int i = 0; i < length; i++) { 531*bd882c8aSJames Wright if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 532*bd882c8aSJames Wright } 533*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 534*bd882c8aSJames Wright } 535*bd882c8aSJames Wright 536*bd882c8aSJames Wright //------------------------------------------------------------------------------ 537*bd882c8aSJames Wright // Take reciprocal of a vector on device 538*bd882c8aSJames Wright //------------------------------------------------------------------------------ 539*bd882c8aSJames Wright static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length) { 540*bd882c8aSJames Wright // Order queue 541*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 542*bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 543*bd882c8aSJames Wright if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 544*bd882c8aSJames Wright }); 545*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 546*bd882c8aSJames Wright } 547*bd882c8aSJames Wright 548*bd882c8aSJames Wright //------------------------------------------------------------------------------ 549*bd882c8aSJames Wright // Take reciprocal of a vector 550*bd882c8aSJames Wright //------------------------------------------------------------------------------ 551*bd882c8aSJames Wright static int CeedVectorReciprocal_Sycl(CeedVector vec) { 552*bd882c8aSJames Wright Ceed ceed; 553*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 554*bd882c8aSJames Wright CeedVector_Sycl *impl; 555*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 556*bd882c8aSJames Wright CeedSize length; 557*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 558*bd882c8aSJames Wright Ceed_Sycl *data; 559*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 560*bd882c8aSJames Wright 561*bd882c8aSJames Wright // Set value for synced device/host array 562*bd882c8aSJames Wright if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 563*bd882c8aSJames Wright if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 564*bd882c8aSJames Wright 565*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 566*bd882c8aSJames Wright } 567*bd882c8aSJames Wright 568*bd882c8aSJames Wright //------------------------------------------------------------------------------ 569*bd882c8aSJames Wright // Compute x = alpha x on the host 570*bd882c8aSJames Wright //------------------------------------------------------------------------------ 571*bd882c8aSJames Wright static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 572*bd882c8aSJames Wright for (int i = 0; i < length; i++) x_array[i] *= alpha; 573*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 574*bd882c8aSJames Wright } 575*bd882c8aSJames Wright 576*bd882c8aSJames Wright //------------------------------------------------------------------------------ 577*bd882c8aSJames Wright // Compute x = alpha x on device 578*bd882c8aSJames Wright //------------------------------------------------------------------------------ 579*bd882c8aSJames Wright static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 580*bd882c8aSJames Wright // Order queue 581*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 582*bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 583*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 584*bd882c8aSJames Wright } 585*bd882c8aSJames Wright 586*bd882c8aSJames Wright //------------------------------------------------------------------------------ 587*bd882c8aSJames Wright // Compute x = alpha x 588*bd882c8aSJames Wright //------------------------------------------------------------------------------ 589*bd882c8aSJames Wright static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 590*bd882c8aSJames Wright Ceed ceed; 591*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 592*bd882c8aSJames Wright CeedVector_Sycl *x_impl; 593*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 594*bd882c8aSJames Wright CeedSize length; 595*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(x, &length)); 596*bd882c8aSJames Wright Ceed_Sycl *data; 597*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 598*bd882c8aSJames Wright 599*bd882c8aSJames Wright // Set value for synced device/host array 600*bd882c8aSJames Wright if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 601*bd882c8aSJames Wright if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 602*bd882c8aSJames Wright 603*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 604*bd882c8aSJames Wright } 605*bd882c8aSJames Wright 606*bd882c8aSJames Wright //------------------------------------------------------------------------------ 607*bd882c8aSJames Wright // Compute y = alpha x + y on the host 608*bd882c8aSJames Wright //------------------------------------------------------------------------------ 609*bd882c8aSJames Wright static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 610*bd882c8aSJames Wright for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 611*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 612*bd882c8aSJames Wright } 613*bd882c8aSJames Wright 614*bd882c8aSJames Wright //------------------------------------------------------------------------------ 615*bd882c8aSJames Wright // Compute y = alpha x + y on device 616*bd882c8aSJames Wright //------------------------------------------------------------------------------ 617*bd882c8aSJames Wright static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 618*bd882c8aSJames Wright // Order queue 619*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 620*bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 621*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 622*bd882c8aSJames Wright } 623*bd882c8aSJames Wright 624*bd882c8aSJames Wright //------------------------------------------------------------------------------ 625*bd882c8aSJames Wright // Compute y = alpha x + y 626*bd882c8aSJames Wright //------------------------------------------------------------------------------ 627*bd882c8aSJames Wright static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 628*bd882c8aSJames Wright Ceed ceed; 629*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 630*bd882c8aSJames Wright CeedVector_Sycl *y_impl, *x_impl; 631*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 632*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 633*bd882c8aSJames Wright CeedSize length; 634*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(y, &length)); 635*bd882c8aSJames Wright Ceed_Sycl *data; 636*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 637*bd882c8aSJames Wright 638*bd882c8aSJames Wright // Set value for synced device/host array 639*bd882c8aSJames Wright if (y_impl->d_array) { 640*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 641*bd882c8aSJames Wright CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 642*bd882c8aSJames Wright } 643*bd882c8aSJames Wright if (y_impl->h_array) { 644*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 645*bd882c8aSJames Wright CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 646*bd882c8aSJames Wright } 647*bd882c8aSJames Wright 648*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 649*bd882c8aSJames Wright } 650*bd882c8aSJames Wright 651*bd882c8aSJames Wright //------------------------------------------------------------------------------ 652*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on the host 653*bd882c8aSJames Wright //------------------------------------------------------------------------------ 654*bd882c8aSJames Wright static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 655*bd882c8aSJames Wright for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 656*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 657*bd882c8aSJames Wright } 658*bd882c8aSJames Wright 659*bd882c8aSJames Wright //------------------------------------------------------------------------------ 660*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 661*bd882c8aSJames Wright //------------------------------------------------------------------------------ 662*bd882c8aSJames Wright static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 663*bd882c8aSJames Wright // Order queue 664*bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 665*bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 666*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 667*bd882c8aSJames Wright } 668*bd882c8aSJames Wright 669*bd882c8aSJames Wright //------------------------------------------------------------------------------ 670*bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y 671*bd882c8aSJames Wright //------------------------------------------------------------------------------ 672*bd882c8aSJames Wright static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 673*bd882c8aSJames Wright Ceed ceed; 674*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 675*bd882c8aSJames Wright CeedVector_Sycl *w_impl, *x_impl, *y_impl; 676*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(w, &w_impl)); 677*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 678*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 679*bd882c8aSJames Wright CeedSize length; 680*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(w, &length)); 681*bd882c8aSJames Wright Ceed_Sycl *data; 682*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 683*bd882c8aSJames Wright 684*bd882c8aSJames Wright // Set value for synced device/host array 685*bd882c8aSJames Wright if (!w_impl->d_array && !w_impl->h_array) { 686*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetValue(w, 0.0)); 687*bd882c8aSJames Wright } 688*bd882c8aSJames Wright if (w_impl->d_array) { 689*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 690*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 691*bd882c8aSJames Wright CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 692*bd882c8aSJames Wright } 693*bd882c8aSJames Wright if (w_impl->h_array) { 694*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 695*bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 696*bd882c8aSJames Wright CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 697*bd882c8aSJames Wright } 698*bd882c8aSJames Wright 699*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 700*bd882c8aSJames Wright } 701*bd882c8aSJames Wright 702*bd882c8aSJames Wright //------------------------------------------------------------------------------ 703*bd882c8aSJames Wright // Destroy the vector 704*bd882c8aSJames Wright //------------------------------------------------------------------------------ 705*bd882c8aSJames Wright static int CeedVectorDestroy_Sycl(const CeedVector vec) { 706*bd882c8aSJames Wright Ceed ceed; 707*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 708*bd882c8aSJames Wright CeedVector_Sycl *impl; 709*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 710*bd882c8aSJames Wright Ceed_Sycl *data; 711*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 712*bd882c8aSJames Wright 713*bd882c8aSJames Wright // Wait for all work to finish before freeing memory 714*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 715*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 716*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 717*bd882c8aSJames Wright 718*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 719*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 720*bd882c8aSJames Wright 721*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 722*bd882c8aSJames Wright } 723*bd882c8aSJames Wright 724*bd882c8aSJames Wright //------------------------------------------------------------------------------ 725*bd882c8aSJames Wright // Create a vector of the specified length (does not allocate memory) 726*bd882c8aSJames Wright //------------------------------------------------------------------------------ 727*bd882c8aSJames Wright int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 728*bd882c8aSJames Wright CeedVector_Sycl *impl; 729*bd882c8aSJames Wright Ceed ceed; 730*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 731*bd882c8aSJames Wright Ceed_Sycl *data; 732*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 733*bd882c8aSJames Wright 734*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 735*bd882c8aSJames Wright CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 736*bd882c8aSJames Wright 737*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 738*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 739*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 740*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 741*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 742*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 743*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 744*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 745*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 746*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 747*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 748*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 749*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 750*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 751*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 752*bd882c8aSJames Wright 753*bd882c8aSJames Wright CeedCallBackend(CeedVectorSetData(vec, impl)); 754*bd882c8aSJames Wright 755*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 756*bd882c8aSJames Wright } 757