1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <cmath> 12 #include <string> 13 #include <sycl/sycl.hpp> 14 15 #include "ceed-sycl-ref.hpp" 16 17 //------------------------------------------------------------------------------ 18 // Check if host/device sync is needed 19 //------------------------------------------------------------------------------ 20 static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21 bool has_valid_array = false; 22 CeedVector_Sycl *impl; 23 24 CeedCallBackend(CeedVectorGetData(vec, &impl)); 25 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26 switch (mem_type) { 27 case CEED_MEM_HOST: 28 *need_sync = has_valid_array && !impl->h_array; 29 break; 30 case CEED_MEM_DEVICE: 31 *need_sync = has_valid_array && !impl->d_array; 32 break; 33 } 34 return CEED_ERROR_SUCCESS; 35 } 36 37 //------------------------------------------------------------------------------ 38 // Sync host to device 39 //------------------------------------------------------------------------------ 40 static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 41 Ceed ceed; 42 Ceed_Sycl *data; 43 CeedSize length; 44 CeedVector_Sycl *impl; 45 46 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47 CeedCallBackend(CeedVectorGetData(vec, &impl)); 48 CeedCallBackend(CeedGetData(ceed, &data)); 49 CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 50 51 CeedCallBackend(CeedVectorGetLength(vec, &length)); 52 if (impl->d_array_borrowed) { 53 impl->d_array = impl->d_array_borrowed; 54 } else if (impl->d_array_owned) { 55 impl->d_array = impl->d_array_owned; 56 } else { 57 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 58 impl->d_array = impl->d_array_owned; 59 } 60 61 // Copy from host to device 62 std::vector<sycl::event> e; 63 64 if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 65 CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, e).wait_and_throw()); 66 return CEED_ERROR_SUCCESS; 67 } 68 69 //------------------------------------------------------------------------------ 70 // Sync device to host 71 //------------------------------------------------------------------------------ 72 static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 73 Ceed ceed; 74 Ceed_Sycl *data; 75 CeedSize length; 76 CeedVector_Sycl *impl; 77 78 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 79 CeedCallBackend(CeedVectorGetData(vec, &impl)); 80 CeedCallBackend(CeedGetData(ceed, &data)); 81 82 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 83 84 CeedCallBackend(CeedVectorGetLength(vec, &length)); 85 if (impl->h_array_borrowed) { 86 impl->h_array = impl->h_array_borrowed; 87 } else if (impl->h_array_owned) { 88 impl->h_array = impl->h_array_owned; 89 } else { 90 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 91 impl->h_array = impl->h_array_owned; 92 } 93 94 // Copy from device to host 95 std::vector<sycl::event> e; 96 97 if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 98 CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, e).wait_and_throw()); 99 return CEED_ERROR_SUCCESS; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Sync arrays 104 //------------------------------------------------------------------------------ 105 static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 106 bool need_sync = false; 107 108 // Check whether device/host sync is needed 109 CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 110 if (!need_sync) return CEED_ERROR_SUCCESS; 111 112 switch (mem_type) { 113 case CEED_MEM_HOST: 114 return CeedVectorSyncD2H_Sycl(vec); 115 case CEED_MEM_DEVICE: 116 return CeedVectorSyncH2D_Sycl(vec); 117 } 118 return CEED_ERROR_UNSUPPORTED; 119 } 120 121 //------------------------------------------------------------------------------ 122 // Set all pointers as invalid 123 //------------------------------------------------------------------------------ 124 static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 125 CeedVector_Sycl *impl; 126 127 CeedCallBackend(CeedVectorGetData(vec, &impl)); 128 impl->h_array = NULL; 129 impl->d_array = NULL; 130 return CEED_ERROR_SUCCESS; 131 } 132 133 //------------------------------------------------------------------------------ 134 // Check if CeedVector has any valid pointer 135 //------------------------------------------------------------------------------ 136 static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 137 CeedVector_Sycl *impl; 138 139 CeedCallBackend(CeedVectorGetData(vec, &impl)); 140 *has_valid_array = impl->h_array || impl->d_array; 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Check if has array of given type 146 //------------------------------------------------------------------------------ 147 static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 148 CeedVector_Sycl *impl; 149 150 CeedCallBackend(CeedVectorGetData(vec, &impl)); 151 switch (mem_type) { 152 case CEED_MEM_HOST: 153 *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 154 break; 155 case CEED_MEM_DEVICE: 156 *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 157 break; 158 } 159 return CEED_ERROR_SUCCESS; 160 } 161 162 //------------------------------------------------------------------------------ 163 // Check if has borrowed array of given type 164 //------------------------------------------------------------------------------ 165 static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 166 CeedVector_Sycl *impl; 167 168 CeedCallBackend(CeedVectorGetData(vec, &impl)); 169 switch (mem_type) { 170 case CEED_MEM_HOST: 171 *has_borrowed_array_of_type = impl->h_array_borrowed; 172 break; 173 case CEED_MEM_DEVICE: 174 *has_borrowed_array_of_type = impl->d_array_borrowed; 175 break; 176 } 177 return CEED_ERROR_SUCCESS; 178 } 179 180 //------------------------------------------------------------------------------ 181 // Set array from host 182 //------------------------------------------------------------------------------ 183 static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 184 CeedSize length; 185 CeedVector_Sycl *impl; 186 187 CeedCallBackend(CeedVectorGetData(vec, &impl)); 188 CeedCallBackend(CeedVectorGetLength(vec, &length)); 189 190 CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned, 191 (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array)); 192 return CEED_ERROR_SUCCESS; 193 } 194 195 //------------------------------------------------------------------------------ 196 // Set array from device 197 //------------------------------------------------------------------------------ 198 static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 199 CeedSize length; 200 Ceed ceed; 201 Ceed_Sycl *data; 202 CeedVector_Sycl *impl; 203 204 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 205 CeedCallBackend(CeedVectorGetData(vec, &impl)); 206 CeedCallBackend(CeedGetData(ceed, &data)); 207 CeedCallBackend(CeedVectorGetLength(vec, &length)); 208 209 // Order queue if needed. 210 std::vector<sycl::event> e; 211 212 if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 213 214 switch (copy_mode) { 215 case CEED_COPY_VALUES: { 216 if (!impl->d_array_owned) { 217 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 218 } 219 if (array) { 220 // Wait for copy to finish and handle exceptions. 221 CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(array, impl->d_array_owned, length, e).wait_and_throw()); 222 } 223 impl->d_array_borrowed = NULL; 224 impl->d_array = impl->d_array_owned; 225 } break; 226 case CEED_OWN_POINTER: 227 if (impl->d_array_owned) { 228 // Wait for all work to finish before freeing memory 229 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 230 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 231 } 232 impl->d_array_owned = array; 233 impl->d_array_borrowed = NULL; 234 impl->d_array = impl->d_array_owned; 235 break; 236 case CEED_USE_POINTER: 237 if (impl->d_array_owned) { 238 // Wait for all work to finish before freeing memory 239 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 240 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 241 } 242 impl->d_array_owned = NULL; 243 impl->d_array_borrowed = array; 244 impl->d_array = impl->d_array_borrowed; 245 break; 246 } 247 return CEED_ERROR_SUCCESS; 248 } 249 250 //------------------------------------------------------------------------------ 251 // Set the array used by a vector, 252 // freeing any previously allocated array if applicable 253 //------------------------------------------------------------------------------ 254 static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 255 CeedVector_Sycl *impl; 256 257 CeedCallBackend(CeedVectorGetData(vec, &impl)); 258 259 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 260 switch (mem_type) { 261 case CEED_MEM_HOST: 262 return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 263 case CEED_MEM_DEVICE: 264 return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 265 } 266 return CEED_ERROR_UNSUPPORTED; 267 } 268 269 //------------------------------------------------------------------------------ 270 // Set host array to value 271 //------------------------------------------------------------------------------ 272 static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 273 for (CeedSize i = 0; i < length; i++) h_array[i] = val; 274 return CEED_ERROR_SUCCESS; 275 } 276 277 //------------------------------------------------------------------------------ 278 // Set device array to value 279 //------------------------------------------------------------------------------ 280 static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 281 std::vector<sycl::event> e; 282 283 if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 284 sycl_queue.fill(d_array, val, length, e); 285 return CEED_ERROR_SUCCESS; 286 } 287 288 //------------------------------------------------------------------------------ 289 // Set a vector to a value, 290 //------------------------------------------------------------------------------ 291 static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 292 Ceed ceed; 293 Ceed_Sycl *data; 294 CeedSize length; 295 CeedVector_Sycl *impl; 296 297 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 298 CeedCallBackend(CeedVectorGetData(vec, &impl)); 299 CeedCallBackend(CeedVectorGetLength(vec, &length)); 300 CeedCallBackend(CeedGetData(ceed, &data)); 301 302 // Set value for synced device/host array 303 if (!impl->d_array && !impl->h_array) { 304 if (impl->d_array_borrowed) { 305 impl->d_array = impl->d_array_borrowed; 306 } else if (impl->h_array_borrowed) { 307 impl->h_array = impl->h_array_borrowed; 308 } else if (impl->d_array_owned) { 309 impl->d_array = impl->d_array_owned; 310 } else if (impl->h_array_owned) { 311 impl->h_array = impl->h_array_owned; 312 } else { 313 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 314 } 315 } 316 if (impl->d_array) { 317 CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 318 impl->h_array = NULL; 319 } 320 if (impl->h_array) { 321 CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 322 impl->d_array = NULL; 323 } 324 return CEED_ERROR_SUCCESS; 325 } 326 327 //------------------------------------------------------------------------------ 328 // Vector Take Array 329 //------------------------------------------------------------------------------ 330 static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 331 Ceed ceed; 332 Ceed_Sycl *data; 333 CeedVector_Sycl *impl; 334 335 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 336 CeedCallBackend(CeedVectorGetData(vec, &impl)); 337 CeedCallBackend(CeedGetData(ceed, &data)); 338 339 // Order queue if needed 340 if (!data->sycl_queue.is_in_order()) data->sycl_queue.ext_oneapi_submit_barrier(); 341 342 // Sync array to requested mem_type 343 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 344 345 // Update pointer 346 switch (mem_type) { 347 case CEED_MEM_HOST: 348 (*array) = impl->h_array_borrowed; 349 impl->h_array_borrowed = NULL; 350 impl->h_array = NULL; 351 break; 352 case CEED_MEM_DEVICE: 353 (*array) = impl->d_array_borrowed; 354 impl->d_array_borrowed = NULL; 355 impl->d_array = NULL; 356 break; 357 } 358 return CEED_ERROR_SUCCESS; 359 } 360 361 //------------------------------------------------------------------------------ 362 // Core logic for array syncronization for GetArray. 363 // If a different memory type is most up to date, this will perform a copy 364 //------------------------------------------------------------------------------ 365 static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 366 CeedVector_Sycl *impl; 367 368 CeedCallBackend(CeedVectorGetData(vec, &impl)); 369 370 // Sync array to requested mem_type 371 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 372 373 // Update pointer 374 switch (mem_type) { 375 case CEED_MEM_HOST: 376 *array = impl->h_array; 377 break; 378 case CEED_MEM_DEVICE: 379 *array = impl->d_array; 380 break; 381 } 382 return CEED_ERROR_SUCCESS; 383 } 384 385 //------------------------------------------------------------------------------ 386 // Get read-only access to a vector via the specified mem_type 387 //------------------------------------------------------------------------------ 388 static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 389 return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 390 } 391 392 //------------------------------------------------------------------------------ 393 // Get read/write access to a vector via the specified mem_type 394 //------------------------------------------------------------------------------ 395 static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 396 CeedVector_Sycl *impl; 397 398 CeedCallBackend(CeedVectorGetData(vec, &impl)); 399 CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 400 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 401 switch (mem_type) { 402 case CEED_MEM_HOST: 403 impl->h_array = *array; 404 break; 405 case CEED_MEM_DEVICE: 406 impl->d_array = *array; 407 break; 408 } 409 return CEED_ERROR_SUCCESS; 410 } 411 412 //------------------------------------------------------------------------------ 413 // Get write access to a vector via the specified mem_type 414 //------------------------------------------------------------------------------ 415 static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 416 bool has_array_of_type = true; 417 CeedVector_Sycl *impl; 418 419 CeedCallBackend(CeedVectorGetData(vec, &impl)); 420 CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 421 if (!has_array_of_type) { 422 // Allocate if array is not yet allocated 423 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 424 } else { 425 // Select dirty array 426 switch (mem_type) { 427 case CEED_MEM_HOST: 428 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 429 else impl->h_array = impl->h_array_owned; 430 break; 431 case CEED_MEM_DEVICE: 432 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 433 else impl->d_array = impl->d_array_owned; 434 } 435 } 436 return CeedVectorGetArray_Sycl(vec, mem_type, array); 437 } 438 439 //------------------------------------------------------------------------------ 440 // Get the norm of a CeedVector 441 //------------------------------------------------------------------------------ 442 static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 443 Ceed ceed; 444 Ceed_Sycl *data; 445 CeedSize length; 446 const CeedScalar *d_array; 447 CeedVector_Sycl *impl; 448 449 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 450 CeedCallBackend(CeedVectorGetData(vec, &impl)); 451 CeedCallBackend(CeedVectorGetLength(vec, &length)); 452 CeedCallBackend(CeedGetData(ceed, &data)); 453 454 // Compute norm 455 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 456 457 std::vector<sycl::event> e; 458 459 if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 460 461 switch (type) { 462 case CEED_NORM_1: { 463 // Order queue 464 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 465 data->sycl_queue.parallel_for(length, e, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 466 } break; 467 case CEED_NORM_2: { 468 // Order queue 469 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 470 data->sycl_queue.parallel_for(length, e, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 471 } break; 472 case CEED_NORM_MAX: { 473 // Order queue 474 auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 475 data->sycl_queue.parallel_for(length, e, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 476 } break; 477 } 478 // L2 norm - square root over reduced value 479 if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 480 else *norm = *impl->reduction_norm; 481 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 482 return CEED_ERROR_SUCCESS; 483 } 484 485 //------------------------------------------------------------------------------ 486 // Take reciprocal of a vector on host 487 //------------------------------------------------------------------------------ 488 static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 489 for (CeedSize i = 0; i < length; i++) { 490 if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 491 } 492 return CEED_ERROR_SUCCESS; 493 } 494 495 //------------------------------------------------------------------------------ 496 // Take reciprocal of a vector on device 497 //------------------------------------------------------------------------------ 498 static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 499 std::vector<sycl::event> e; 500 501 if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 502 sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { 503 if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 504 }); 505 return CEED_ERROR_SUCCESS; 506 } 507 508 //------------------------------------------------------------------------------ 509 // Take reciprocal of a vector 510 //------------------------------------------------------------------------------ 511 static int CeedVectorReciprocal_Sycl(CeedVector vec) { 512 Ceed ceed; 513 Ceed_Sycl *data; 514 CeedSize length; 515 CeedVector_Sycl *impl; 516 517 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 518 CeedCallBackend(CeedVectorGetData(vec, &impl)); 519 CeedCallBackend(CeedVectorGetLength(vec, &length)); 520 CeedCallBackend(CeedGetData(ceed, &data)); 521 522 // Set value for synced device/host array 523 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 524 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 525 return CEED_ERROR_SUCCESS; 526 } 527 528 //------------------------------------------------------------------------------ 529 // Compute x = alpha x on the host 530 //------------------------------------------------------------------------------ 531 static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 532 for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 533 return CEED_ERROR_SUCCESS; 534 } 535 536 //------------------------------------------------------------------------------ 537 // Compute x = alpha x on device 538 //------------------------------------------------------------------------------ 539 static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 540 std::vector<sycl::event> e; 541 542 if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 543 sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 544 return CEED_ERROR_SUCCESS; 545 } 546 547 //------------------------------------------------------------------------------ 548 // Compute x = alpha x 549 //------------------------------------------------------------------------------ 550 static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 551 Ceed ceed; 552 Ceed_Sycl *data; 553 CeedSize length; 554 CeedVector_Sycl *x_impl; 555 556 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 557 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 558 CeedCallBackend(CeedVectorGetLength(x, &length)); 559 CeedCallBackend(CeedGetData(ceed, &data)); 560 561 // Set value for synced device/host array 562 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 563 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 564 return CEED_ERROR_SUCCESS; 565 } 566 567 //------------------------------------------------------------------------------ 568 // Compute y = alpha x + y on the host 569 //------------------------------------------------------------------------------ 570 static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 571 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 572 return CEED_ERROR_SUCCESS; 573 } 574 575 //------------------------------------------------------------------------------ 576 // Compute y = alpha x + y on device 577 //------------------------------------------------------------------------------ 578 static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 579 std::vector<sycl::event> e; 580 581 if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 582 sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 583 return CEED_ERROR_SUCCESS; 584 } 585 586 //------------------------------------------------------------------------------ 587 // Compute y = alpha x + y 588 //------------------------------------------------------------------------------ 589 static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 590 Ceed ceed; 591 Ceed_Sycl *data; 592 CeedSize length; 593 CeedVector_Sycl *y_impl, *x_impl; 594 595 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 596 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 597 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 598 CeedCallBackend(CeedVectorGetLength(y, &length)); 599 CeedCallBackend(CeedGetData(ceed, &data)); 600 601 // Set value for synced device/host array 602 if (y_impl->d_array) { 603 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 604 CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 605 } 606 if (y_impl->h_array) { 607 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 608 CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 609 } 610 return CEED_ERROR_SUCCESS; 611 } 612 613 //------------------------------------------------------------------------------ 614 // Compute the pointwise multiplication w = x .* y on the host 615 //------------------------------------------------------------------------------ 616 static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 617 for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 618 return CEED_ERROR_SUCCESS; 619 } 620 621 //------------------------------------------------------------------------------ 622 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 623 //------------------------------------------------------------------------------ 624 static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 625 std::vector<sycl::event> e; 626 627 if (!sycl_queue.is_in_order()) e = {sycl_queue.ext_oneapi_submit_barrier()}; 628 sycl_queue.parallel_for(length, e, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 629 return CEED_ERROR_SUCCESS; 630 } 631 632 //------------------------------------------------------------------------------ 633 // Compute the pointwise multiplication w = x .* y 634 //------------------------------------------------------------------------------ 635 static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 636 Ceed ceed; 637 Ceed_Sycl *data; 638 CeedSize length; 639 CeedVector_Sycl *w_impl, *x_impl, *y_impl; 640 641 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 642 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 643 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 644 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 645 CeedCallBackend(CeedVectorGetLength(w, &length)); 646 CeedCallBackend(CeedGetData(ceed, &data)); 647 648 // Set value for synced device/host array 649 if (!w_impl->d_array && !w_impl->h_array) { 650 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 651 } 652 if (w_impl->d_array) { 653 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 654 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 655 CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 656 } 657 if (w_impl->h_array) { 658 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 659 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 660 CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 661 } 662 return CEED_ERROR_SUCCESS; 663 } 664 665 //------------------------------------------------------------------------------ 666 // Destroy the vector 667 //------------------------------------------------------------------------------ 668 static int CeedVectorDestroy_Sycl(const CeedVector vec) { 669 Ceed ceed; 670 Ceed_Sycl *data; 671 CeedVector_Sycl *impl; 672 673 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 674 CeedCallBackend(CeedVectorGetData(vec, &impl)); 675 CeedCallBackend(CeedGetData(ceed, &data)); 676 677 // Wait for all work to finish before freeing memory 678 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 679 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 680 CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 681 682 CeedCallBackend(CeedFree(&impl->h_array_owned)); 683 CeedCallBackend(CeedFree(&impl)); 684 return CEED_ERROR_SUCCESS; 685 } 686 687 //------------------------------------------------------------------------------ 688 // Create a vector of the specified length (does not allocate memory) 689 //------------------------------------------------------------------------------ 690 int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 691 Ceed ceed; 692 Ceed_Sycl *data; 693 CeedVector_Sycl *impl; 694 695 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 696 CeedCallBackend(CeedGetData(ceed, &data)); 697 CeedCallBackend(CeedCalloc(1, &impl)); 698 CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 699 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 700 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 701 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 702 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 703 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 704 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 705 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 706 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 707 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 708 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 709 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 710 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 711 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 712 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 713 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 714 CeedCallBackend(CeedVectorSetData(vec, impl)); 715 return CEED_ERROR_SUCCESS; 716 } 717