15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright #include <ceed/backend.h> 9bd882c8aSJames Wright #include <ceed/ceed.h> 10bd882c8aSJames Wright 11bd882c8aSJames Wright #include <cmath> 12bd882c8aSJames Wright #include <string> 13bd882c8aSJames Wright #include <sycl/sycl.hpp> 14bd882c8aSJames Wright 15bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 16bd882c8aSJames Wright 17bd882c8aSJames Wright //------------------------------------------------------------------------------ 18bd882c8aSJames Wright // Check if host/device sync is needed 19bd882c8aSJames Wright //------------------------------------------------------------------------------ 20bd882c8aSJames Wright static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21bd882c8aSJames Wright bool has_valid_array = false; 22dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 23dd64fc84SJeremy L Thompson 24dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 25bd882c8aSJames Wright CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26bd882c8aSJames Wright switch (mem_type) { 27bd882c8aSJames Wright case CEED_MEM_HOST: 28bd882c8aSJames Wright *need_sync = has_valid_array && !impl->h_array; 29bd882c8aSJames Wright break; 30bd882c8aSJames Wright case CEED_MEM_DEVICE: 31bd882c8aSJames Wright *need_sync = has_valid_array && !impl->d_array; 32bd882c8aSJames Wright break; 33bd882c8aSJames Wright } 34bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 35bd882c8aSJames Wright } 36bd882c8aSJames Wright 37bd882c8aSJames Wright //------------------------------------------------------------------------------ 38bd882c8aSJames Wright // Sync host to device 39bd882c8aSJames Wright //------------------------------------------------------------------------------ 40bd882c8aSJames Wright static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 41bd882c8aSJames Wright Ceed ceed; 42bd882c8aSJames Wright Ceed_Sycl *data; 43dd64fc84SJeremy L Thompson CeedSize length; 44dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 45dd64fc84SJeremy L Thompson 46dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 48*9bc66399SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 49*9bc66399SJeremy L Thompson 504e3038a5SJeremy L Thompson CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 51bd882c8aSJames Wright 52bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 53bd882c8aSJames Wright if (impl->d_array_borrowed) { 54bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 55bd882c8aSJames Wright } else if (impl->d_array_owned) { 56bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 57bd882c8aSJames Wright } else { 58bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 59bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 60bd882c8aSJames Wright } 61bd882c8aSJames Wright 62bd882c8aSJames Wright // Copy from host to device 631f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 641f4b1b45SUmesh Unnikrishnan 651f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 661f4b1b45SUmesh Unnikrishnan CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, e).wait_and_throw()); 67*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 68bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 69bd882c8aSJames Wright } 70bd882c8aSJames Wright 71bd882c8aSJames Wright //------------------------------------------------------------------------------ 72bd882c8aSJames Wright // Sync device to host 73bd882c8aSJames Wright //------------------------------------------------------------------------------ 74bd882c8aSJames Wright static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 75bd882c8aSJames Wright Ceed ceed; 76bd882c8aSJames Wright Ceed_Sycl *data; 77dd64fc84SJeremy L Thompson CeedSize length; 78dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 79dd64fc84SJeremy L Thompson 80dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 81bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 82*9bc66399SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 83bd882c8aSJames Wright 84bd882c8aSJames Wright CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 85bd882c8aSJames Wright 86bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(vec, &length)); 87bd882c8aSJames Wright if (impl->h_array_borrowed) { 88bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 89bd882c8aSJames Wright } else if (impl->h_array_owned) { 90bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 91bd882c8aSJames Wright } else { 92bd882c8aSJames Wright CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 93bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 94bd882c8aSJames Wright } 95bd882c8aSJames Wright 96bd882c8aSJames Wright // Copy from device to host 971f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 981f4b1b45SUmesh Unnikrishnan 991f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 1001f4b1b45SUmesh Unnikrishnan CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, e).wait_and_throw()); 101*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 102bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 103bd882c8aSJames Wright } 104bd882c8aSJames Wright 105bd882c8aSJames Wright //------------------------------------------------------------------------------ 106bd882c8aSJames Wright // Sync arrays 107bd882c8aSJames Wright //------------------------------------------------------------------------------ 108bd882c8aSJames Wright static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 109bd882c8aSJames Wright bool need_sync = false; 110dd64fc84SJeremy L Thompson 111dd64fc84SJeremy L Thompson // Check whether device/host sync is needed 112bd882c8aSJames Wright CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 113bd882c8aSJames Wright if (!need_sync) return CEED_ERROR_SUCCESS; 114bd882c8aSJames Wright 115bd882c8aSJames Wright switch (mem_type) { 116bd882c8aSJames Wright case CEED_MEM_HOST: 117bd882c8aSJames Wright return CeedVectorSyncD2H_Sycl(vec); 118bd882c8aSJames Wright case CEED_MEM_DEVICE: 119bd882c8aSJames Wright return CeedVectorSyncH2D_Sycl(vec); 120bd882c8aSJames Wright } 121bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 122bd882c8aSJames Wright } 123bd882c8aSJames Wright 124bd882c8aSJames Wright //------------------------------------------------------------------------------ 125bd882c8aSJames Wright // Set all pointers as invalid 126bd882c8aSJames Wright //------------------------------------------------------------------------------ 127bd882c8aSJames Wright static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 128bd882c8aSJames Wright CeedVector_Sycl *impl; 129bd882c8aSJames Wright 130dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 131bd882c8aSJames Wright impl->h_array = NULL; 132bd882c8aSJames Wright impl->d_array = NULL; 133bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 134bd882c8aSJames Wright } 135bd882c8aSJames Wright 136bd882c8aSJames Wright //------------------------------------------------------------------------------ 137bd882c8aSJames Wright // Check if CeedVector has any valid pointer 138bd882c8aSJames Wright //------------------------------------------------------------------------------ 139bd882c8aSJames Wright static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 140bd882c8aSJames Wright CeedVector_Sycl *impl; 141dd64fc84SJeremy L Thompson 142bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 1431c66c397SJeremy L Thompson *has_valid_array = impl->h_array || impl->d_array; 144bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 145bd882c8aSJames Wright } 146bd882c8aSJames Wright 147bd882c8aSJames Wright //------------------------------------------------------------------------------ 148bd882c8aSJames Wright // Check if has array of given type 149bd882c8aSJames Wright //------------------------------------------------------------------------------ 150bd882c8aSJames Wright static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 151bd882c8aSJames Wright CeedVector_Sycl *impl; 152bd882c8aSJames Wright 153dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 154bd882c8aSJames Wright switch (mem_type) { 155bd882c8aSJames Wright case CEED_MEM_HOST: 1561c66c397SJeremy L Thompson *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 157bd882c8aSJames Wright break; 158bd882c8aSJames Wright case CEED_MEM_DEVICE: 1591c66c397SJeremy L Thompson *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 160bd882c8aSJames Wright break; 161bd882c8aSJames Wright } 162bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 163bd882c8aSJames Wright } 164bd882c8aSJames Wright 165bd882c8aSJames Wright //------------------------------------------------------------------------------ 166bd882c8aSJames Wright // Check if has borrowed array of given type 167bd882c8aSJames Wright //------------------------------------------------------------------------------ 168bd882c8aSJames Wright static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 169bd882c8aSJames Wright CeedVector_Sycl *impl; 170bd882c8aSJames Wright 171dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 172bd882c8aSJames Wright switch (mem_type) { 173bd882c8aSJames Wright case CEED_MEM_HOST: 1741c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->h_array_borrowed; 175bd882c8aSJames Wright break; 176bd882c8aSJames Wright case CEED_MEM_DEVICE: 1771c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->d_array_borrowed; 178bd882c8aSJames Wright break; 179bd882c8aSJames Wright } 180bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 181bd882c8aSJames Wright } 182bd882c8aSJames Wright 183bd882c8aSJames Wright //------------------------------------------------------------------------------ 184bd882c8aSJames Wright // Set array from host 185bd882c8aSJames Wright //------------------------------------------------------------------------------ 186bd882c8aSJames Wright static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 187f59ebe5eSJeremy L Thompson CeedSize length; 188bd882c8aSJames Wright CeedVector_Sycl *impl; 189bd882c8aSJames Wright 190dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 191f59ebe5eSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 192f59ebe5eSJeremy L Thompson 193f5d1e504SJeremy L Thompson CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned, 194f5d1e504SJeremy L Thompson (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array)); 195bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 196bd882c8aSJames Wright } 197bd882c8aSJames Wright 198bd882c8aSJames Wright //------------------------------------------------------------------------------ 199bd882c8aSJames Wright // Set array from device 200bd882c8aSJames Wright //------------------------------------------------------------------------------ 201bd882c8aSJames Wright static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 202f59ebe5eSJeremy L Thompson CeedSize length; 203bd882c8aSJames Wright Ceed ceed; 204bd882c8aSJames Wright Ceed_Sycl *data; 205dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 206dd64fc84SJeremy L Thompson 207dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 208dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 209bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 210f59ebe5eSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 211bd882c8aSJames Wright 2121f4b1b45SUmesh Unnikrishnan // Order queue if needed. 2131f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 2141f4b1b45SUmesh Unnikrishnan 2151f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 216bd882c8aSJames Wright 217bd882c8aSJames Wright switch (copy_mode) { 218bd882c8aSJames Wright case CEED_COPY_VALUES: { 219e588e9b3SJeremy L Thompson if (!impl->d_array_owned) { 220bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 221e588e9b3SJeremy L Thompson } 222bd882c8aSJames Wright if (array) { 223bd882c8aSJames Wright // Wait for copy to finish and handle exceptions. 2241f4b1b45SUmesh Unnikrishnan CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(array, impl->d_array_owned, length, e).wait_and_throw()); 225bd882c8aSJames Wright } 226f59ebe5eSJeremy L Thompson impl->d_array_borrowed = NULL; 227f59ebe5eSJeremy L Thompson impl->d_array = impl->d_array_owned; 228bd882c8aSJames Wright } break; 229bd882c8aSJames Wright case CEED_OWN_POINTER: 230bd882c8aSJames Wright if (impl->d_array_owned) { 231bd882c8aSJames Wright // Wait for all work to finish before freeing memory 232bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 233bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 234bd882c8aSJames Wright } 235bd882c8aSJames Wright impl->d_array_owned = array; 236bd882c8aSJames Wright impl->d_array_borrowed = NULL; 237f59ebe5eSJeremy L Thompson impl->d_array = impl->d_array_owned; 238bd882c8aSJames Wright break; 239bd882c8aSJames Wright case CEED_USE_POINTER: 240bd882c8aSJames Wright if (impl->d_array_owned) { 241bd882c8aSJames Wright // Wait for all work to finish before freeing memory 242bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 243bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 244bd882c8aSJames Wright } 245bd882c8aSJames Wright impl->d_array_owned = NULL; 246bd882c8aSJames Wright impl->d_array_borrowed = array; 247f59ebe5eSJeremy L Thompson impl->d_array = impl->d_array_borrowed; 248bd882c8aSJames Wright break; 249bd882c8aSJames Wright } 250*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 251bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 252bd882c8aSJames Wright } 253bd882c8aSJames Wright 254bd882c8aSJames Wright //------------------------------------------------------------------------------ 255bd882c8aSJames Wright // Set the array used by a vector, 256bd882c8aSJames Wright // freeing any previously allocated array if applicable 257bd882c8aSJames Wright //------------------------------------------------------------------------------ 258bd882c8aSJames Wright static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 259bd882c8aSJames Wright CeedVector_Sycl *impl; 260dd64fc84SJeremy L Thompson 261bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 262bd882c8aSJames Wright 263bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 264bd882c8aSJames Wright switch (mem_type) { 265bd882c8aSJames Wright case CEED_MEM_HOST: 266bd882c8aSJames Wright return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 267bd882c8aSJames Wright case CEED_MEM_DEVICE: 268bd882c8aSJames Wright return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 269bd882c8aSJames Wright } 270bd882c8aSJames Wright return CEED_ERROR_UNSUPPORTED; 271bd882c8aSJames Wright } 272bd882c8aSJames Wright 273bd882c8aSJames Wright //------------------------------------------------------------------------------ 274bd882c8aSJames Wright // Set host array to value 275bd882c8aSJames Wright //------------------------------------------------------------------------------ 2766ca0f394SUmesh Unnikrishnan static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 2776ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) h_array[i] = val; 278bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 279bd882c8aSJames Wright } 280bd882c8aSJames Wright 281bd882c8aSJames Wright //------------------------------------------------------------------------------ 282bd882c8aSJames Wright // Set device array to value 283bd882c8aSJames Wright //------------------------------------------------------------------------------ 2846ca0f394SUmesh Unnikrishnan static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 2851f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 2861f4b1b45SUmesh Unnikrishnan 2871f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 2881f4b1b45SUmesh Unnikrishnan sycl_queue.fill(d_array, val, length, e); 289bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 290bd882c8aSJames Wright } 291bd882c8aSJames Wright 292bd882c8aSJames Wright //------------------------------------------------------------------------------ 293bd882c8aSJames Wright // Set a vector to a value, 294bd882c8aSJames Wright //------------------------------------------------------------------------------ 295bd882c8aSJames Wright static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 296bd882c8aSJames Wright Ceed ceed; 297bd882c8aSJames Wright Ceed_Sycl *data; 298dd64fc84SJeremy L Thompson CeedSize length; 299dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 300dd64fc84SJeremy L Thompson 301dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 302*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 303*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 304dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 305dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 306bd882c8aSJames Wright 307bd882c8aSJames Wright // Set value for synced device/host array 308bd882c8aSJames Wright if (!impl->d_array && !impl->h_array) { 309bd882c8aSJames Wright if (impl->d_array_borrowed) { 310bd882c8aSJames Wright impl->d_array = impl->d_array_borrowed; 311bd882c8aSJames Wright } else if (impl->h_array_borrowed) { 312bd882c8aSJames Wright impl->h_array = impl->h_array_borrowed; 313bd882c8aSJames Wright } else if (impl->d_array_owned) { 314bd882c8aSJames Wright impl->d_array = impl->d_array_owned; 315bd882c8aSJames Wright } else if (impl->h_array_owned) { 316bd882c8aSJames Wright impl->h_array = impl->h_array_owned; 317bd882c8aSJames Wright } else { 318bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 319bd882c8aSJames Wright } 320bd882c8aSJames Wright } 321bd882c8aSJames Wright if (impl->d_array) { 322bd882c8aSJames Wright CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 323bd882c8aSJames Wright impl->h_array = NULL; 324bd882c8aSJames Wright } 325bd882c8aSJames Wright if (impl->h_array) { 326bd882c8aSJames Wright CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 327bd882c8aSJames Wright impl->d_array = NULL; 328bd882c8aSJames Wright } 329bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 330bd882c8aSJames Wright } 331bd882c8aSJames Wright 332bd882c8aSJames Wright //------------------------------------------------------------------------------ 333bd882c8aSJames Wright // Vector Take Array 334bd882c8aSJames Wright //------------------------------------------------------------------------------ 335bd882c8aSJames Wright static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 336bd882c8aSJames Wright Ceed ceed; 337bd882c8aSJames Wright Ceed_Sycl *data; 338dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 339dd64fc84SJeremy L Thompson 340dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 341bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 342*9bc66399SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 343*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 344*9bc66399SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 345bd882c8aSJames Wright 3461f4b1b45SUmesh Unnikrishnan // Order queue if needed 3471f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) data->sycl_queue.ext_oneapi_submit_barrier(); 348bd882c8aSJames Wright 349bd882c8aSJames Wright // Sync array to requested mem_type 350bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 351bd882c8aSJames Wright 352bd882c8aSJames Wright // Update pointer 353bd882c8aSJames Wright switch (mem_type) { 354bd882c8aSJames Wright case CEED_MEM_HOST: 355bd882c8aSJames Wright (*array) = impl->h_array_borrowed; 356bd882c8aSJames Wright impl->h_array_borrowed = NULL; 357bd882c8aSJames Wright impl->h_array = NULL; 358bd882c8aSJames Wright break; 359bd882c8aSJames Wright case CEED_MEM_DEVICE: 360bd882c8aSJames Wright (*array) = impl->d_array_borrowed; 361bd882c8aSJames Wright impl->d_array_borrowed = NULL; 362bd882c8aSJames Wright impl->d_array = NULL; 363bd882c8aSJames Wright break; 364bd882c8aSJames Wright } 365bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 366bd882c8aSJames Wright } 367bd882c8aSJames Wright 368bd882c8aSJames Wright //------------------------------------------------------------------------------ 369bd882c8aSJames Wright // Core logic for array syncronization for GetArray. 370bd882c8aSJames Wright // If a different memory type is most up to date, this will perform a copy 371bd882c8aSJames Wright //------------------------------------------------------------------------------ 372bd882c8aSJames Wright static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 373bd882c8aSJames Wright CeedVector_Sycl *impl; 374dd64fc84SJeremy L Thompson 375bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 376bd882c8aSJames Wright 377bd882c8aSJames Wright // Sync array to requested mem_type 378bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 379bd882c8aSJames Wright 380bd882c8aSJames Wright // Update pointer 381bd882c8aSJames Wright switch (mem_type) { 382bd882c8aSJames Wright case CEED_MEM_HOST: 383bd882c8aSJames Wright *array = impl->h_array; 384bd882c8aSJames Wright break; 385bd882c8aSJames Wright case CEED_MEM_DEVICE: 386bd882c8aSJames Wright *array = impl->d_array; 387bd882c8aSJames Wright break; 388bd882c8aSJames Wright } 389bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 390bd882c8aSJames Wright } 391ff1e7120SSebastian Grimberg 392bd882c8aSJames Wright //------------------------------------------------------------------------------ 393bd882c8aSJames Wright // Get read-only access to a vector via the specified mem_type 394bd882c8aSJames Wright //------------------------------------------------------------------------------ 395bd882c8aSJames Wright static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 396bd882c8aSJames Wright return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 397bd882c8aSJames Wright } 398bd882c8aSJames Wright 399bd882c8aSJames Wright //------------------------------------------------------------------------------ 400bd882c8aSJames Wright // Get read/write access to a vector via the specified mem_type 401bd882c8aSJames Wright //------------------------------------------------------------------------------ 402bd882c8aSJames Wright static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 403bd882c8aSJames Wright CeedVector_Sycl *impl; 404dd64fc84SJeremy L Thompson 405bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(vec, &impl)); 406bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 407bd882c8aSJames Wright CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 408bd882c8aSJames Wright switch (mem_type) { 409bd882c8aSJames Wright case CEED_MEM_HOST: 410bd882c8aSJames Wright impl->h_array = *array; 411bd882c8aSJames Wright break; 412bd882c8aSJames Wright case CEED_MEM_DEVICE: 413bd882c8aSJames Wright impl->d_array = *array; 414bd882c8aSJames Wright break; 415bd882c8aSJames Wright } 416bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 417bd882c8aSJames Wright } 418bd882c8aSJames Wright 419bd882c8aSJames Wright //------------------------------------------------------------------------------ 420bd882c8aSJames Wright // Get write access to a vector via the specified mem_type 421bd882c8aSJames Wright //------------------------------------------------------------------------------ 422bd882c8aSJames Wright static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 423bd882c8aSJames Wright bool has_array_of_type = true; 424dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 425dd64fc84SJeremy L Thompson 426dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 427bd882c8aSJames Wright CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 428bd882c8aSJames Wright if (!has_array_of_type) { 429bd882c8aSJames Wright // Allocate if array is not yet allocated 430bd882c8aSJames Wright CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 431bd882c8aSJames Wright } else { 432bd882c8aSJames Wright // Select dirty array 433bd882c8aSJames Wright switch (mem_type) { 434bd882c8aSJames Wright case CEED_MEM_HOST: 435bd882c8aSJames Wright if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 436bd882c8aSJames Wright else impl->h_array = impl->h_array_owned; 437bd882c8aSJames Wright break; 438bd882c8aSJames Wright case CEED_MEM_DEVICE: 439bd882c8aSJames Wright if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 440bd882c8aSJames Wright else impl->d_array = impl->d_array_owned; 441bd882c8aSJames Wright } 442bd882c8aSJames Wright } 443bd882c8aSJames Wright return CeedVectorGetArray_Sycl(vec, mem_type, array); 444bd882c8aSJames Wright } 445bd882c8aSJames Wright 446bd882c8aSJames Wright //------------------------------------------------------------------------------ 447bd882c8aSJames Wright // Get the norm of a CeedVector 448bd882c8aSJames Wright //------------------------------------------------------------------------------ 449bd882c8aSJames Wright static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 450bd882c8aSJames Wright Ceed ceed; 451bd882c8aSJames Wright Ceed_Sycl *data; 452dd64fc84SJeremy L Thompson CeedSize length; 453dd64fc84SJeremy L Thompson const CeedScalar *d_array; 454dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 455dd64fc84SJeremy L Thompson 456dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 457*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 458*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 459dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 460dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 461bd882c8aSJames Wright 462bd882c8aSJames Wright // Compute norm 463bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 4641f4b1b45SUmesh Unnikrishnan 4651f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 4661f4b1b45SUmesh Unnikrishnan 4671f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 4681f4b1b45SUmesh Unnikrishnan 469bd882c8aSJames Wright switch (type) { 470bd882c8aSJames Wright case CEED_NORM_1: { 471bd882c8aSJames Wright // Order queue 472bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 4731f4b1b45SUmesh Unnikrishnan data->sycl_queue.parallel_for(length, e, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 474bd882c8aSJames Wright } break; 475bd882c8aSJames Wright case CEED_NORM_2: { 476bd882c8aSJames Wright // Order queue 477bd882c8aSJames Wright auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 4781f4b1b45SUmesh Unnikrishnan data->sycl_queue.parallel_for(length, e, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 479bd882c8aSJames Wright } break; 480bd882c8aSJames Wright case CEED_NORM_MAX: { 481bd882c8aSJames Wright // Order queue 482bd882c8aSJames Wright auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 4831f4b1b45SUmesh Unnikrishnan data->sycl_queue.parallel_for(length, e, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 484bd882c8aSJames Wright } break; 485bd882c8aSJames Wright } 486bd882c8aSJames Wright // L2 norm - square root over reduced value 487bd882c8aSJames Wright if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 488bd882c8aSJames Wright else *norm = *impl->reduction_norm; 489bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 490bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 491bd882c8aSJames Wright } 492bd882c8aSJames Wright 493bd882c8aSJames Wright //------------------------------------------------------------------------------ 494bd882c8aSJames Wright // Take reciprocal of a vector on host 495bd882c8aSJames Wright //------------------------------------------------------------------------------ 4966ca0f394SUmesh Unnikrishnan static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 4976ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) { 498bd882c8aSJames Wright if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 499bd882c8aSJames Wright } 500bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 501bd882c8aSJames Wright } 502bd882c8aSJames Wright 503bd882c8aSJames Wright //------------------------------------------------------------------------------ 504bd882c8aSJames Wright // Take reciprocal of a vector on device 505bd882c8aSJames Wright //------------------------------------------------------------------------------ 5066ca0f394SUmesh Unnikrishnan static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 5071f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 5081f4b1b45SUmesh Unnikrishnan 5091f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 5101f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { 511bd882c8aSJames Wright if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 512bd882c8aSJames Wright }); 513bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 514bd882c8aSJames Wright } 515bd882c8aSJames Wright 516bd882c8aSJames Wright //------------------------------------------------------------------------------ 517bd882c8aSJames Wright // Take reciprocal of a vector 518bd882c8aSJames Wright //------------------------------------------------------------------------------ 519bd882c8aSJames Wright static int CeedVectorReciprocal_Sycl(CeedVector vec) { 520bd882c8aSJames Wright Ceed ceed; 521bd882c8aSJames Wright Ceed_Sycl *data; 522dd64fc84SJeremy L Thompson CeedSize length; 523dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 524dd64fc84SJeremy L Thompson 525dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 526*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 527*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 528dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 529dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 530bd882c8aSJames Wright 531bd882c8aSJames Wright // Set value for synced device/host array 532bd882c8aSJames Wright if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 533bd882c8aSJames Wright if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 534bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 535bd882c8aSJames Wright } 536bd882c8aSJames Wright 537bd882c8aSJames Wright //------------------------------------------------------------------------------ 538bd882c8aSJames Wright // Compute x = alpha x on the host 539bd882c8aSJames Wright //------------------------------------------------------------------------------ 5406ca0f394SUmesh Unnikrishnan static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 5416ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 542bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 543bd882c8aSJames Wright } 544bd882c8aSJames Wright 545bd882c8aSJames Wright //------------------------------------------------------------------------------ 546bd882c8aSJames Wright // Compute x = alpha x on device 547bd882c8aSJames Wright //------------------------------------------------------------------------------ 5486ca0f394SUmesh Unnikrishnan static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 5491f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 5501f4b1b45SUmesh Unnikrishnan 5511f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 5521f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 553bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 554bd882c8aSJames Wright } 555bd882c8aSJames Wright 556bd882c8aSJames Wright //------------------------------------------------------------------------------ 557bd882c8aSJames Wright // Compute x = alpha x 558bd882c8aSJames Wright //------------------------------------------------------------------------------ 559bd882c8aSJames Wright static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 560bd882c8aSJames Wright Ceed ceed; 561bd882c8aSJames Wright Ceed_Sycl *data; 562dd64fc84SJeremy L Thompson CeedSize length; 563dd64fc84SJeremy L Thompson CeedVector_Sycl *x_impl; 564dd64fc84SJeremy L Thompson 565dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 566*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 567*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 568dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(x, &x_impl)); 569dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(x, &length)); 570bd882c8aSJames Wright 571bd882c8aSJames Wright // Set value for synced device/host array 572bd882c8aSJames Wright if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 573bd882c8aSJames Wright if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 574bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 575bd882c8aSJames Wright } 576bd882c8aSJames Wright 577bd882c8aSJames Wright //------------------------------------------------------------------------------ 578bd882c8aSJames Wright // Compute y = alpha x + y on the host 579bd882c8aSJames Wright //------------------------------------------------------------------------------ 5806ca0f394SUmesh Unnikrishnan static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 5816ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 582bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 583bd882c8aSJames Wright } 584bd882c8aSJames Wright 585bd882c8aSJames Wright //------------------------------------------------------------------------------ 586bd882c8aSJames Wright // Compute y = alpha x + y on device 587bd882c8aSJames Wright //------------------------------------------------------------------------------ 5886ca0f394SUmesh Unnikrishnan static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 5891f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 5901f4b1b45SUmesh Unnikrishnan 5911f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 5921f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 593bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 594bd882c8aSJames Wright } 595bd882c8aSJames Wright 596bd882c8aSJames Wright //------------------------------------------------------------------------------ 597bd882c8aSJames Wright // Compute y = alpha x + y 598bd882c8aSJames Wright //------------------------------------------------------------------------------ 599bd882c8aSJames Wright static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 600bd882c8aSJames Wright Ceed ceed; 601dd64fc84SJeremy L Thompson Ceed_Sycl *data; 602dd64fc84SJeremy L Thompson CeedSize length; 603bd882c8aSJames Wright CeedVector_Sycl *y_impl, *x_impl; 604dd64fc84SJeremy L Thompson 605dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 606*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 607*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 608bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 609bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 610bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(y, &length)); 611bd882c8aSJames Wright 612bd882c8aSJames Wright // Set value for synced device/host array 613bd882c8aSJames Wright if (y_impl->d_array) { 614bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 615bd882c8aSJames Wright CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 616bd882c8aSJames Wright } 617bd882c8aSJames Wright if (y_impl->h_array) { 618bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 619bd882c8aSJames Wright CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 620bd882c8aSJames Wright } 621bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 622bd882c8aSJames Wright } 623bd882c8aSJames Wright 624bd882c8aSJames Wright //------------------------------------------------------------------------------ 625bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on the host 626bd882c8aSJames Wright //------------------------------------------------------------------------------ 6276ca0f394SUmesh Unnikrishnan static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 6286ca0f394SUmesh Unnikrishnan for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 629bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 630bd882c8aSJames Wright } 631bd882c8aSJames Wright 632bd882c8aSJames Wright //------------------------------------------------------------------------------ 633bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 634bd882c8aSJames Wright //------------------------------------------------------------------------------ 6356ca0f394SUmesh Unnikrishnan static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 6361f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 6371f4b1b45SUmesh Unnikrishnan 6381f4b1b45SUmesh Unnikrishnan if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 6391f4b1b45SUmesh Unnikrishnan sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 640bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 641bd882c8aSJames Wright } 642bd882c8aSJames Wright 643bd882c8aSJames Wright //------------------------------------------------------------------------------ 644bd882c8aSJames Wright // Compute the pointwise multiplication w = x .* y 645bd882c8aSJames Wright //------------------------------------------------------------------------------ 646bd882c8aSJames Wright static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 647bd882c8aSJames Wright Ceed ceed; 648dd64fc84SJeremy L Thompson Ceed_Sycl *data; 649dd64fc84SJeremy L Thompson CeedSize length; 650bd882c8aSJames Wright CeedVector_Sycl *w_impl, *x_impl, *y_impl; 651dd64fc84SJeremy L Thompson 652dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 653*9bc66399SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 654*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 655bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(w, &w_impl)); 656bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(x, &x_impl)); 657bd882c8aSJames Wright CeedCallBackend(CeedVectorGetData(y, &y_impl)); 658bd882c8aSJames Wright CeedCallBackend(CeedVectorGetLength(w, &length)); 659bd882c8aSJames Wright 660bd882c8aSJames Wright // Set value for synced device/host array 661bd882c8aSJames Wright if (!w_impl->d_array && !w_impl->h_array) { 662bd882c8aSJames Wright CeedCallBackend(CeedVectorSetValue(w, 0.0)); 663bd882c8aSJames Wright } 664bd882c8aSJames Wright if (w_impl->d_array) { 665bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 666bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 667bd882c8aSJames Wright CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 668bd882c8aSJames Wright } 669bd882c8aSJames Wright if (w_impl->h_array) { 670bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 671bd882c8aSJames Wright CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 672bd882c8aSJames Wright CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 673bd882c8aSJames Wright } 674bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 675bd882c8aSJames Wright } 676bd882c8aSJames Wright 677bd882c8aSJames Wright //------------------------------------------------------------------------------ 678bd882c8aSJames Wright // Destroy the vector 679bd882c8aSJames Wright //------------------------------------------------------------------------------ 680bd882c8aSJames Wright static int CeedVectorDestroy_Sycl(const CeedVector vec) { 681bd882c8aSJames Wright Ceed ceed; 682bd882c8aSJames Wright Ceed_Sycl *data; 683dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 684dd64fc84SJeremy L Thompson 685dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 686dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 687bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 688bd882c8aSJames Wright 689bd882c8aSJames Wright // Wait for all work to finish before freeing memory 690bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 691bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 692bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 693bd882c8aSJames Wright 694bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl->h_array_owned)); 695bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 696*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 697bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 698bd882c8aSJames Wright } 699bd882c8aSJames Wright 700bd882c8aSJames Wright //------------------------------------------------------------------------------ 701bd882c8aSJames Wright // Create a vector of the specified length (does not allocate memory) 702bd882c8aSJames Wright //------------------------------------------------------------------------------ 703bd882c8aSJames Wright int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 704bd882c8aSJames Wright Ceed ceed; 705bd882c8aSJames Wright Ceed_Sycl *data; 706dd64fc84SJeremy L Thompson CeedVector_Sycl *impl; 707bd882c8aSJames Wright 708dd64fc84SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 709dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 710bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 711bd882c8aSJames Wright CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 712bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 713bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 714bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 715bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 716bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 717bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 718bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 719bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 720bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 721bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 722bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 723bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 724bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 725bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 726bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 727*9bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 728bd882c8aSJames Wright CeedCallBackend(CeedVectorSetData(vec, impl)); 729bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 730bd882c8aSJames Wright } 731