1bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright #include <ceed/backend.h> 9bd882c8aSJames Wright #include <ceed/ceed.h> 10bd882c8aSJames Wright 11bd882c8aSJames Wright #include <cmath> 12bd882c8aSJames Wright #include <string> 13bd882c8aSJames Wright #include <sycl/sycl.hpp> 14bd882c8aSJames Wright 15bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 16bd882c8aSJames Wright 17bd882c8aSJames Wright //------------------------------------------------------------------------------ 18bd882c8aSJames Wright // Check if host/device sync is needed 19bd882c8aSJames Wright //------------------------------------------------------------------------------ 20bd882c8aSJames Wright static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21bd882c8aSJames Wright bool has_valid_array = false; 22*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 23*dd64fc84SJeremy L Thompson 24*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 25bd882c8aSJames Wright CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26bd882c8aSJames Wright switch (mem_type) { 27bd882c8aSJames Wright case CEED_MEM_HOST: 28bd882c8aSJames Wright *need_sync = has_valid_array && !impl->h_array; 29bd882c8aSJames Wright break; 30bd882c8aSJames Wright case CEED_MEM_DEVICE: 31bd882c8aSJames Wright *need_sync = has_valid_array && !impl->d_array; 32bd882c8aSJames Wright break; 33bd882c8aSJames Wright } 34bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 35bd882c8aSJames Wright } 36bd882c8aSJames Wright 37bd882c8aSJames Wright //------------------------------------------------------------------------------ 38bd882c8aSJames Wright // Sync host to device 39bd882c8aSJames Wright //------------------------------------------------------------------------------ 40bd882c8aSJames Wright static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 41bd882c8aSJames Wright Ceed ceed; 42bd882c8aSJames Wright Ceed_Sycl *data; 43*dd64fc84SJeremy L Thompson CeedSize length; 44*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 45*dd64fc84SJeremy L Thompson 46*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 48bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 49bd882c8aSJames Wright 50bd882c8aSJames Wright if (!impl->h_array) { 51bd882c8aSJames Wright // LCOV_EXCL_START 52bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 53bd882c8aSJames Wright // LCOV_EXCL_STOP 54bd882c8aSJames Wright } 55bd882c8aSJames Wright 56bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 57bd882c8aSJames Wright if (impl->d_array_borrowed) { 58bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 59bd882c8aSJames Wright } else if (impl->d_array_owned) { 60bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 61bd882c8aSJames Wright } else { 62bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 63bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 64bd882c8aSJames Wright } 65bd882c8aSJames Wright 66bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 67bd882c8aSJames Wright // Copy from host to device 68bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 69bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 70bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 71bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 72bd882c8aSJames Wright } 73bd882c8aSJames Wright 74bd882c8aSJames Wright //------------------------------------------------------------------------------ 75bd882c8aSJames Wright // Sync device to host 76bd882c8aSJames Wright //------------------------------------------------------------------------------ 77bd882c8aSJames Wright static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 78bd882c8aSJames Wright Ceed ceed; 79bd882c8aSJames Wright Ceed_Sycl *data; 80*dd64fc84SJeremy L Thompson CeedSize length; 81*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 82*dd64fc84SJeremy L Thompson 83*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 84*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 85bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 86bd882c8aSJames Wright 87bd882c8aSJames Wright CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 88bd882c8aSJames Wright 89bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 90bd882c8aSJames Wright if (impl->h_array_borrowed) { 91bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 92bd882c8aSJames Wright } else if (impl->h_array_owned) { 93bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 94bd882c8aSJames Wright } else { 95bd882c8aSJames Wright CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 96bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 97bd882c8aSJames Wright } 98bd882c8aSJames Wright 99bd882c8aSJames Wright // Order queue 100bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 101bd882c8aSJames Wright // Copy from device to host 102bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 103bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 104bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 105bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 106bd882c8aSJames Wright } 107bd882c8aSJames Wright 108bd882c8aSJames Wright //------------------------------------------------------------------------------ 109bd882c8aSJames Wright // Sync arrays 110bd882c8aSJames Wright //------------------------------------------------------------------------------ 111bd882c8aSJames Wright static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 112bd882c8aSJames Wright bool need_sync = false; 113*dd64fc84SJeremy L Thompson 114*dd64fc84SJeremy L Thompson // Check whether device/host sync is needed 115bd882c8aSJames Wright CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 116bd882c8aSJames Wright if (!need_sync) return CEED_ERROR_SUCCESS; 117bd882c8aSJames Wright 118bd882c8aSJames Wright switch (mem_type) { 119bd882c8aSJames Wright case CEED_MEM_HOST: 120bd882c8aSJames Wright return CeedVectorSyncD2H_Sycl(vec); 121bd882c8aSJames Wright case CEED_MEM_DEVICE: 122bd882c8aSJames Wright return CeedVectorSyncH2D_Sycl(vec); 123bd882c8aSJames Wright } 124bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 125bd882c8aSJames Wright } 126bd882c8aSJames Wright 127bd882c8aSJames Wright //------------------------------------------------------------------------------ 128bd882c8aSJames Wright // Set all pointers as invalid 129bd882c8aSJames Wright //------------------------------------------------------------------------------ 130bd882c8aSJames Wright static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 131bd882c8aSJames Wright CeedVector_Sycl *impl; 132bd882c8aSJames Wright 133*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 134bd882c8aSJames Wright impl->h_array = NULL; 135bd882c8aSJames Wright impl->d_array = NULL; 136bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 137bd882c8aSJames Wright } 138bd882c8aSJames Wright 139bd882c8aSJames Wright //------------------------------------------------------------------------------ 140bd882c8aSJames Wright // Check if CeedVector has any valid pointer 141bd882c8aSJames Wright //------------------------------------------------------------------------------ 142bd882c8aSJames Wright static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 143bd882c8aSJames Wright CeedVector_Sycl *impl; 144*dd64fc84SJeremy L Thompson 145bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 1461c66c397SJeremy L Thompson *has_valid_array = impl->h_array || impl->d_array; 147bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 148bd882c8aSJames Wright } 149bd882c8aSJames Wright 150bd882c8aSJames Wright //------------------------------------------------------------------------------ 151bd882c8aSJames Wright // Check if has array of given type 152bd882c8aSJames Wright //------------------------------------------------------------------------------ 153bd882c8aSJames Wright static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 154bd882c8aSJames Wright CeedVector_Sycl *impl; 155bd882c8aSJames Wright 156*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 157bd882c8aSJames Wright switch (mem_type) { 158bd882c8aSJames Wright case CEED_MEM_HOST: 1591c66c397SJeremy L Thompson *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 160bd882c8aSJames Wright break; 161bd882c8aSJames Wright case CEED_MEM_DEVICE: 1621c66c397SJeremy L Thompson *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 163bd882c8aSJames Wright break; 164bd882c8aSJames Wright } 165bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 166bd882c8aSJames Wright } 167bd882c8aSJames Wright 168bd882c8aSJames Wright //------------------------------------------------------------------------------ 169bd882c8aSJames Wright // Check if has borrowed array of given type 170bd882c8aSJames Wright //------------------------------------------------------------------------------ 171bd882c8aSJames Wright static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 172bd882c8aSJames Wright CeedVector_Sycl *impl; 173bd882c8aSJames Wright 174*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 175bd882c8aSJames Wright switch (mem_type) { 176bd882c8aSJames Wright case CEED_MEM_HOST: 1771c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->h_array_borrowed; 178bd882c8aSJames Wright break; 179bd882c8aSJames Wright case CEED_MEM_DEVICE: 1801c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->d_array_borrowed; 181bd882c8aSJames Wright break; 182bd882c8aSJames Wright } 183bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 184bd882c8aSJames Wright } 185bd882c8aSJames Wright 186bd882c8aSJames Wright //------------------------------------------------------------------------------ 187bd882c8aSJames Wright // Set array from host 188bd882c8aSJames Wright //------------------------------------------------------------------------------ 189bd882c8aSJames Wright static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 190bd882c8aSJames Wright CeedVector_Sycl *impl; 191bd882c8aSJames Wright 192*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 193bd882c8aSJames Wright switch (copy_mode) { 194bd882c8aSJames Wright case CEED_COPY_VALUES: { 195bd882c8aSJames Wright if (!impl->h_array_owned) { 196bd882c8aSJames Wright CeedSize length; 197*dd64fc84SJeremy L Thompson 198bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 199bd882c8aSJames Wright CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 200bd882c8aSJames Wright } 201bd882c8aSJames Wright impl->h_array_borrowed = NULL; 202bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 203bd882c8aSJames Wright if (array) { 204bd882c8aSJames Wright CeedSize length; 205*dd64fc84SJeremy L Thompson 206bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 207bd882c8aSJames Wright size_t bytes = length * sizeof(CeedScalar); 208*dd64fc84SJeremy L Thompson 209bd882c8aSJames Wright memcpy(impl->h_array, array, bytes); 210bd882c8aSJames Wright } 211bd882c8aSJames Wright } break; 212bd882c8aSJames Wright case CEED_OWN_POINTER: 213bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 214bd882c8aSJames Wright impl->h_array_owned = array; 215bd882c8aSJames Wright impl->h_array_borrowed = NULL; 216bd882c8aSJames Wright impl->h_array = array; 217bd882c8aSJames Wright break; 218bd882c8aSJames Wright case CEED_USE_POINTER: 219bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 220bd882c8aSJames Wright impl->h_array_borrowed = array; 221bd882c8aSJames Wright impl->h_array = array; 222bd882c8aSJames Wright break; 223bd882c8aSJames Wright } 224bd882c8aSJames Wright 225bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 226bd882c8aSJames Wright } 227bd882c8aSJames Wright 228bd882c8aSJames Wright //------------------------------------------------------------------------------ 229bd882c8aSJames Wright // Set array from device 230bd882c8aSJames Wright //------------------------------------------------------------------------------ 231bd882c8aSJames Wright static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 232bd882c8aSJames Wright Ceed ceed; 233bd882c8aSJames Wright Ceed_Sycl *data; 234*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 235*dd64fc84SJeremy L Thompson 236*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 237*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 238bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 239bd882c8aSJames Wright 240bd882c8aSJames Wright // Order queue 241bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 242bd882c8aSJames Wright 243bd882c8aSJames Wright switch (copy_mode) { 244bd882c8aSJames Wright case CEED_COPY_VALUES: { 245bd882c8aSJames Wright CeedSize length; 246*dd64fc84SJeremy L Thompson 247bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 248bd882c8aSJames Wright if (!impl->d_array_owned) { 249bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 250bd882c8aSJames Wright } 2513ce2313bSJeremy L Thompson impl->d_array = impl->d_array_owned; 252bd882c8aSJames Wright if (array) { 253bd882c8aSJames Wright sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 254bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 255bd882c8aSJames Wright CeedCallSycl(ceed, copy_event.wait_and_throw()); 256bd882c8aSJames Wright } 257bd882c8aSJames Wright } break; 258bd882c8aSJames Wright case CEED_OWN_POINTER: 259bd882c8aSJames Wright if (impl->d_array_owned) { 260bd882c8aSJames Wright // Wait for all work to finish before freeing memory 261bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 262bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 263bd882c8aSJames Wright } 264bd882c8aSJames Wright impl->d_array_owned = array; 265bd882c8aSJames Wright impl->d_array_borrowed = NULL; 266bd882c8aSJames Wright impl->d_array = array; 267bd882c8aSJames Wright break; 268bd882c8aSJames Wright case CEED_USE_POINTER: 269bd882c8aSJames Wright if (impl->d_array_owned) { 270bd882c8aSJames Wright // Wait for all work to finish before freeing memory 271bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 272bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 273bd882c8aSJames Wright } 274bd882c8aSJames Wright impl->d_array_owned = NULL; 275bd882c8aSJames Wright impl->d_array_borrowed = array; 276bd882c8aSJames Wright impl->d_array = array; 277bd882c8aSJames Wright break; 278bd882c8aSJames Wright } 279bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 280bd882c8aSJames Wright } 281bd882c8aSJames Wright 282bd882c8aSJames Wright //------------------------------------------------------------------------------ 283bd882c8aSJames Wright // Set the array used by a vector, 284bd882c8aSJames Wright // freeing any previously allocated array if applicable 285bd882c8aSJames Wright //------------------------------------------------------------------------------ 286bd882c8aSJames Wright static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 287bd882c8aSJames Wright Ceed ceed; 288bd882c8aSJames Wright CeedVector_Sycl *impl; 289*dd64fc84SJeremy L Thompson 290*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 291bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 292bd882c8aSJames Wright 293bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 294bd882c8aSJames Wright switch (mem_type) { 295bd882c8aSJames Wright case CEED_MEM_HOST: 296bd882c8aSJames Wright return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 297bd882c8aSJames Wright case CEED_MEM_DEVICE: 298bd882c8aSJames Wright return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 299bd882c8aSJames Wright } 300bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 301bd882c8aSJames Wright } 302bd882c8aSJames Wright 303bd882c8aSJames Wright //------------------------------------------------------------------------------ 304bd882c8aSJames Wright // Set host array to value 305bd882c8aSJames Wright //------------------------------------------------------------------------------ 3066ca0f394SUmesh Unnikrishnan static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 3076ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) h_array[i] = val; 308bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 309bd882c8aSJames Wright } 310bd882c8aSJames Wright 311bd882c8aSJames Wright //------------------------------------------------------------------------------ 312bd882c8aSJames Wright // Set device array to value 313bd882c8aSJames Wright //------------------------------------------------------------------------------ 3146ca0f394SUmesh Unnikrishnan static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 315bd882c8aSJames Wright // Order queue 316bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 317bd882c8aSJames Wright sycl_queue.fill(d_array, val, length, {e}); 318bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 319bd882c8aSJames Wright } 320bd882c8aSJames Wright 321bd882c8aSJames Wright //------------------------------------------------------------------------------ 322bd882c8aSJames Wright // Set a vector to a value, 323bd882c8aSJames Wright //------------------------------------------------------------------------------ 324bd882c8aSJames Wright static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 325bd882c8aSJames Wright Ceed ceed; 326bd882c8aSJames Wright Ceed_Sycl *data; 327*dd64fc84SJeremy L Thompson CeedSize length; 328*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 329*dd64fc84SJeremy L Thompson 330*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 331*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 332*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 333bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 334bd882c8aSJames Wright 335bd882c8aSJames Wright // Set value for synced device/host array 336bd882c8aSJames Wright if (!impl->d_array && !impl->h_array) { 337bd882c8aSJames Wright if (impl->d_array_borrowed) { 338bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 339bd882c8aSJames Wright } else if (impl->h_array_borrowed) { 340bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 341bd882c8aSJames Wright } else if (impl->d_array_owned) { 342bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 343bd882c8aSJames Wright } else if (impl->h_array_owned) { 344bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 345bd882c8aSJames Wright } else { 346bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 347bd882c8aSJames Wright } 348bd882c8aSJames Wright } 349bd882c8aSJames Wright if (impl->d_array) { 350bd882c8aSJames Wright CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 351bd882c8aSJames Wright impl->h_array = NULL; 352bd882c8aSJames Wright } 353bd882c8aSJames Wright if (impl->h_array) { 354bd882c8aSJames Wright CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 355bd882c8aSJames Wright impl->d_array = NULL; 356bd882c8aSJames Wright } 357bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 358bd882c8aSJames Wright } 359bd882c8aSJames Wright 360bd882c8aSJames Wright //------------------------------------------------------------------------------ 361bd882c8aSJames Wright // Vector Take Array 362bd882c8aSJames Wright //------------------------------------------------------------------------------ 363bd882c8aSJames Wright static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 364bd882c8aSJames Wright Ceed ceed; 365bd882c8aSJames Wright Ceed_Sycl *data; 366*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 367*dd64fc84SJeremy L Thompson 368*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 369*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 370bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 371bd882c8aSJames Wright 372bd882c8aSJames Wright // Order queue 373bd882c8aSJames Wright data->sycl_queue.ext_oneapi_submit_barrier(); 374bd882c8aSJames Wright 375bd882c8aSJames Wright // Sync array to requested mem_type 376bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 377bd882c8aSJames Wright 378bd882c8aSJames Wright // Update pointer 379bd882c8aSJames Wright switch (mem_type) { 380bd882c8aSJames Wright case CEED_MEM_HOST: 381bd882c8aSJames Wright (*array) = impl->h_array_borrowed; 382bd882c8aSJames Wright impl->h_array_borrowed = NULL; 383bd882c8aSJames Wright impl->h_array = NULL; 384bd882c8aSJames Wright break; 385bd882c8aSJames Wright case CEED_MEM_DEVICE: 386bd882c8aSJames Wright (*array) = impl->d_array_borrowed; 387bd882c8aSJames Wright impl->d_array_borrowed = NULL; 388bd882c8aSJames Wright impl->d_array = NULL; 389bd882c8aSJames Wright break; 390bd882c8aSJames Wright } 391bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 392bd882c8aSJames Wright } 393bd882c8aSJames Wright 394bd882c8aSJames Wright //------------------------------------------------------------------------------ 395bd882c8aSJames Wright // Core logic for array syncronization for GetArray. 396bd882c8aSJames Wright // If a different memory type is most up to date, this will perform a copy 397bd882c8aSJames Wright //------------------------------------------------------------------------------ 398bd882c8aSJames Wright static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 399bd882c8aSJames Wright Ceed ceed; 400bd882c8aSJames Wright CeedVector_Sycl *impl; 401*dd64fc84SJeremy L Thompson 402*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 403bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 404bd882c8aSJames Wright 405bd882c8aSJames Wright // Sync array to requested mem_type 406bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 407bd882c8aSJames Wright 408bd882c8aSJames Wright // Update pointer 409bd882c8aSJames Wright switch (mem_type) { 410bd882c8aSJames Wright case CEED_MEM_HOST: 411bd882c8aSJames Wright *array = impl->h_array; 412bd882c8aSJames Wright break; 413bd882c8aSJames Wright case CEED_MEM_DEVICE: 414bd882c8aSJames Wright *array = impl->d_array; 415bd882c8aSJames Wright break; 416bd882c8aSJames Wright } 417bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 418bd882c8aSJames Wright } 419ff1e7120SSebastian Grimberg 420bd882c8aSJames Wright //------------------------------------------------------------------------------ 421bd882c8aSJames Wright // Get read-only access to a vector via the specified mem_type 422bd882c8aSJames Wright //------------------------------------------------------------------------------ 423bd882c8aSJames Wright static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 424bd882c8aSJames Wright return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 425bd882c8aSJames Wright } 426bd882c8aSJames Wright 427bd882c8aSJames Wright //------------------------------------------------------------------------------ 428bd882c8aSJames Wright // Get read/write access to a vector via the specified mem_type 429bd882c8aSJames Wright //------------------------------------------------------------------------------ 430bd882c8aSJames Wright static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 431bd882c8aSJames Wright CeedVector_Sycl *impl; 432*dd64fc84SJeremy L Thompson 433bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 434bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 435bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 436bd882c8aSJames Wright switch (mem_type) { 437bd882c8aSJames Wright case CEED_MEM_HOST: 438bd882c8aSJames Wright impl->h_array = *array; 439bd882c8aSJames Wright break; 440bd882c8aSJames Wright case CEED_MEM_DEVICE: 441bd882c8aSJames Wright impl->d_array = *array; 442bd882c8aSJames Wright break; 443bd882c8aSJames Wright } 444bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 445bd882c8aSJames Wright } 446bd882c8aSJames Wright 447bd882c8aSJames Wright //------------------------------------------------------------------------------ 448bd882c8aSJames Wright // Get write access to a vector via the specified mem_type 449bd882c8aSJames Wright //------------------------------------------------------------------------------ 450bd882c8aSJames Wright static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 451bd882c8aSJames Wright bool has_array_of_type = true; 452*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 453*dd64fc84SJeremy L Thompson 454*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 455bd882c8aSJames Wright CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 456bd882c8aSJames Wright if (!has_array_of_type) { 457bd882c8aSJames Wright // Allocate if array is not yet allocated 458bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 459bd882c8aSJames Wright } else { 460bd882c8aSJames Wright // Select dirty array 461bd882c8aSJames Wright switch (mem_type) { 462bd882c8aSJames Wright case CEED_MEM_HOST: 463bd882c8aSJames Wright if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 464bd882c8aSJames Wright else impl->h_array = impl->h_array_owned; 465bd882c8aSJames Wright break; 466bd882c8aSJames Wright case CEED_MEM_DEVICE: 467bd882c8aSJames Wright if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 468bd882c8aSJames Wright else impl->d_array = impl->d_array_owned; 469bd882c8aSJames Wright } 470bd882c8aSJames Wright } 471bd882c8aSJames Wright return CeedVectorGetArray_Sycl(vec, mem_type, array); 472bd882c8aSJames Wright } 473bd882c8aSJames Wright 474bd882c8aSJames Wright //------------------------------------------------------------------------------ 475bd882c8aSJames Wright // Get the norm of a CeedVector 476bd882c8aSJames Wright //------------------------------------------------------------------------------ 477bd882c8aSJames Wright static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 478bd882c8aSJames Wright Ceed ceed; 479bd882c8aSJames Wright Ceed_Sycl *data; 480*dd64fc84SJeremy L Thompson CeedSize length; 481*dd64fc84SJeremy L Thompson const CeedScalar *d_array; 482*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 483*dd64fc84SJeremy L Thompson 484*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 485*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 486*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 487bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 488bd882c8aSJames Wright 489bd882c8aSJames Wright // Compute norm 490bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 491bd882c8aSJames Wright switch (type) { 492bd882c8aSJames Wright case CEED_NORM_1: { 493bd882c8aSJames Wright // Order queue 494bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 495bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 496bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 497bd882c8aSJames Wright } break; 498bd882c8aSJames Wright case CEED_NORM_2: { 499bd882c8aSJames Wright // Order queue 500bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 501bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 502bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 503bd882c8aSJames Wright } break; 504bd882c8aSJames Wright case CEED_NORM_MAX: { 505bd882c8aSJames Wright // Order queue 506bd882c8aSJames Wright sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 507bd882c8aSJames Wright auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 508bd882c8aSJames Wright data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 509bd882c8aSJames Wright } break; 510bd882c8aSJames Wright } 511bd882c8aSJames Wright // L2 norm - square root over reduced value 512bd882c8aSJames Wright if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 513bd882c8aSJames Wright else *norm = *impl->reduction_norm; 514bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 515bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 516bd882c8aSJames Wright } 517bd882c8aSJames Wright 518bd882c8aSJames Wright //------------------------------------------------------------------------------ 519bd882c8aSJames Wright // Take reciprocal of a vector on host 520bd882c8aSJames Wright //------------------------------------------------------------------------------ 5216ca0f394SUmesh Unnikrishnan static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 5226ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) { 523bd882c8aSJames Wright if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 524bd882c8aSJames Wright } 525bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 526bd882c8aSJames Wright } 527bd882c8aSJames Wright 528bd882c8aSJames Wright //------------------------------------------------------------------------------ 529bd882c8aSJames Wright // Take reciprocal of a vector on device 530bd882c8aSJames Wright //------------------------------------------------------------------------------ 5316ca0f394SUmesh Unnikrishnan static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 532bd882c8aSJames Wright // Order queue 533bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 534bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 535bd882c8aSJames Wright if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 536bd882c8aSJames Wright }); 537bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 538bd882c8aSJames Wright } 539bd882c8aSJames Wright 540bd882c8aSJames Wright //------------------------------------------------------------------------------ 541bd882c8aSJames Wright // Take reciprocal of a vector 542bd882c8aSJames Wright //------------------------------------------------------------------------------ 543bd882c8aSJames Wright static int CeedVectorReciprocal_Sycl(CeedVector vec) { 544bd882c8aSJames Wright Ceed ceed; 545bd882c8aSJames Wright Ceed_Sycl *data; 546*dd64fc84SJeremy L Thompson CeedSize length; 547*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 548*dd64fc84SJeremy L Thompson 549*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 550*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 551*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 552bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 553bd882c8aSJames Wright 554bd882c8aSJames Wright // Set value for synced device/host array 555bd882c8aSJames Wright if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 556bd882c8aSJames Wright if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 557bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 558bd882c8aSJames Wright } 559bd882c8aSJames Wright 560bd882c8aSJames Wright //------------------------------------------------------------------------------ 561bd882c8aSJames Wright // Compute x = alpha x on the host 562bd882c8aSJames Wright //------------------------------------------------------------------------------ 5636ca0f394SUmesh Unnikrishnan static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 5646ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 565bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 566bd882c8aSJames Wright } 567bd882c8aSJames Wright 568bd882c8aSJames Wright //------------------------------------------------------------------------------ 569bd882c8aSJames Wright // Compute x = alpha x on device 570bd882c8aSJames Wright //------------------------------------------------------------------------------ 5716ca0f394SUmesh Unnikrishnan static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 572bd882c8aSJames Wright // Order queue 573bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 574bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 575bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 576bd882c8aSJames Wright } 577bd882c8aSJames Wright 578bd882c8aSJames Wright //------------------------------------------------------------------------------ 579bd882c8aSJames Wright // Compute x = alpha x 580bd882c8aSJames Wright //------------------------------------------------------------------------------ 581bd882c8aSJames Wright static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 582bd882c8aSJames Wright Ceed ceed; 583bd882c8aSJames Wright Ceed_Sycl *data; 584*dd64fc84SJeremy L Thompson CeedSize length; 585*dd64fc84SJeremy L Thompson CeedVector_Sycl *x_impl; 586*dd64fc84SJeremy L Thompson 587*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 588*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(x, &x_impl)); 589*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(x, &length)); 590bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 591bd882c8aSJames Wright 592bd882c8aSJames Wright // Set value for synced device/host array 593bd882c8aSJames Wright if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 594bd882c8aSJames Wright if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 595bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 596bd882c8aSJames Wright } 597bd882c8aSJames Wright 598bd882c8aSJames Wright //------------------------------------------------------------------------------ 599bd882c8aSJames Wright // Compute y = alpha x + y on the host 600bd882c8aSJames Wright //------------------------------------------------------------------------------ 6016ca0f394SUmesh Unnikrishnan static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 6026ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 603bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 604bd882c8aSJames Wright } 605bd882c8aSJames Wright 606bd882c8aSJames Wright //------------------------------------------------------------------------------ 607bd882c8aSJames Wright // Compute y = alpha x + y on device 608bd882c8aSJames Wright //------------------------------------------------------------------------------ 6096ca0f394SUmesh Unnikrishnan static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 610bd882c8aSJames Wright // Order queue 611bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 612bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 613bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 614bd882c8aSJames Wright } 615bd882c8aSJames Wright 616bd882c8aSJames Wright //------------------------------------------------------------------------------ 617bd882c8aSJames Wright // Compute y = alpha x + y 618bd882c8aSJames Wright //------------------------------------------------------------------------------ 619bd882c8aSJames Wright static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 620bd882c8aSJames Wright Ceed ceed; 621*dd64fc84SJeremy L Thompson Ceed_Sycl *data; 622*dd64fc84SJeremy L Thompson CeedSize length; 623bd882c8aSJames Wright CeedVector_Sycl *y_impl, *x_impl; 624*dd64fc84SJeremy L Thompson 625*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 626bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 627bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 628bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(y, &length)); 629bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 630bd882c8aSJames Wright 631bd882c8aSJames Wright // Set value for synced device/host array 632bd882c8aSJames Wright if (y_impl->d_array) { 633bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 634bd882c8aSJames Wright CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 635bd882c8aSJames Wright } 636bd882c8aSJames Wright if (y_impl->h_array) { 637bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 638bd882c8aSJames Wright CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 639bd882c8aSJames Wright } 640bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 641bd882c8aSJames Wright } 642bd882c8aSJames Wright 643bd882c8aSJames Wright //------------------------------------------------------------------------------ 644bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on the host 645bd882c8aSJames Wright //------------------------------------------------------------------------------ 6466ca0f394SUmesh Unnikrishnan static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 6476ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 648bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 649bd882c8aSJames Wright } 650bd882c8aSJames Wright 651bd882c8aSJames Wright //------------------------------------------------------------------------------ 652bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 653bd882c8aSJames Wright //------------------------------------------------------------------------------ 6546ca0f394SUmesh Unnikrishnan static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 655bd882c8aSJames Wright // Order queue 656bd882c8aSJames Wright sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 657bd882c8aSJames Wright sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 658bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 659bd882c8aSJames Wright } 660bd882c8aSJames Wright 661bd882c8aSJames Wright //------------------------------------------------------------------------------ 662bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y 663bd882c8aSJames Wright //------------------------------------------------------------------------------ 664bd882c8aSJames Wright static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 665bd882c8aSJames Wright Ceed ceed; 666*dd64fc84SJeremy L Thompson Ceed_Sycl *data; 667*dd64fc84SJeremy L Thompson CeedSize length; 668bd882c8aSJames Wright CeedVector_Sycl *w_impl, *x_impl, *y_impl; 669*dd64fc84SJeremy L Thompson 670*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 671bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(w, &w_impl)); 672bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 673bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 674bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(w, &length)); 675bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 676bd882c8aSJames Wright 677bd882c8aSJames Wright // Set value for synced device/host array 678bd882c8aSJames Wright if (!w_impl->d_array && !w_impl->h_array) { 679bd882c8aSJames Wright CeedCallBackend(CeedVectorSetValue(w, 0.0)); 680bd882c8aSJames Wright } 681bd882c8aSJames Wright if (w_impl->d_array) { 682bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 683bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 684bd882c8aSJames Wright CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 685bd882c8aSJames Wright } 686bd882c8aSJames Wright if (w_impl->h_array) { 687bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 688bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 689bd882c8aSJames Wright CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 690bd882c8aSJames Wright } 691bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 692bd882c8aSJames Wright } 693bd882c8aSJames Wright 694bd882c8aSJames Wright //------------------------------------------------------------------------------ 695bd882c8aSJames Wright // Destroy the vector 696bd882c8aSJames Wright //------------------------------------------------------------------------------ 697bd882c8aSJames Wright static int CeedVectorDestroy_Sycl(const CeedVector vec) { 698bd882c8aSJames Wright Ceed ceed; 699bd882c8aSJames Wright Ceed_Sycl *data; 700*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 701*dd64fc84SJeremy L Thompson 702*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 703*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 704bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 705bd882c8aSJames Wright 706bd882c8aSJames Wright // Wait for all work to finish before freeing memory 707bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 708bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 709bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 710bd882c8aSJames Wright 711bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 712bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 713bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 714bd882c8aSJames Wright } 715bd882c8aSJames Wright 716bd882c8aSJames Wright //------------------------------------------------------------------------------ 717bd882c8aSJames Wright // Create a vector of the specified length (does not allocate memory) 718bd882c8aSJames Wright //------------------------------------------------------------------------------ 719bd882c8aSJames Wright int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 720bd882c8aSJames Wright Ceed ceed; 721bd882c8aSJames Wright Ceed_Sycl *data; 722*dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 723bd882c8aSJames Wright 724*dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 725*dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 726bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 727bd882c8aSJames Wright CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 728bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 729bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 730bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 731bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 732bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 733bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 734bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 735bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 736bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 737bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 738bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 739bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 740bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 741bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 742bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 743bd882c8aSJames Wright CeedCallBackend(CeedVectorSetData(vec, impl)); 744bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 745bd882c8aSJames Wright } 746