1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <cmath> 12 #include <string> 13 #include <sycl/sycl.hpp> 14 15 #include "ceed-sycl-ref.hpp" 16 17 //------------------------------------------------------------------------------ 18 // Check if host/device sync is needed 19 //------------------------------------------------------------------------------ 20 static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21 CeedVector_Sycl *impl; 22 CeedCallBackend(CeedVectorGetData(vec, &impl)); 23 24 bool has_valid_array = false; 25 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26 switch (mem_type) { 27 case CEED_MEM_HOST: 28 *need_sync = has_valid_array && !impl->h_array; 29 break; 30 case CEED_MEM_DEVICE: 31 *need_sync = has_valid_array && !impl->d_array; 32 break; 33 } 34 35 return CEED_ERROR_SUCCESS; 36 } 37 38 //------------------------------------------------------------------------------ 39 // Sync host to device 40 //------------------------------------------------------------------------------ 41 static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 42 Ceed ceed; 43 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 44 CeedVector_Sycl *impl; 45 CeedCallBackend(CeedVectorGetData(vec, &impl)); 46 Ceed_Sycl *data; 47 CeedCallBackend(CeedGetData(ceed, &data)); 48 49 if (!impl->h_array) { 50 // LCOV_EXCL_START 51 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 52 // LCOV_EXCL_STOP 53 } 54 55 CeedSize length; 56 CeedCallBackend(CeedVectorGetLength(vec, &length)); 57 58 if (impl->d_array_borrowed) { 59 impl->d_array = impl->d_array_borrowed; 60 } else if (impl->d_array_owned) { 61 impl->d_array = impl->d_array_owned; 62 } else { 63 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 64 impl->d_array = impl->d_array_owned; 65 } 66 67 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 68 // Copy from host to device 69 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 70 // Wait for copy to finish and handle exceptions. 71 CeedCallSycl(ceed, copy_event.wait_and_throw()); 72 73 return CEED_ERROR_SUCCESS; 74 } 75 76 //------------------------------------------------------------------------------ 77 // Sync device to host 78 //------------------------------------------------------------------------------ 79 static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 80 Ceed ceed; 81 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 82 CeedVector_Sycl *impl; 83 CeedCallBackend(CeedVectorGetData(vec, &impl)); 84 Ceed_Sycl *data; 85 CeedCallBackend(CeedGetData(ceed, &data)); 86 87 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 88 89 CeedSize length; 90 CeedCallBackend(CeedVectorGetLength(vec, &length)); 91 92 if (impl->h_array_borrowed) { 93 impl->h_array = impl->h_array_borrowed; 94 } else if (impl->h_array_owned) { 95 impl->h_array = impl->h_array_owned; 96 } else { 97 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 98 impl->h_array = impl->h_array_owned; 99 } 100 101 // Order queue 102 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 103 // Copy from device to host 104 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 105 106 // Wait for copy to finish and handle exceptions. 107 CeedCallSycl(ceed, copy_event.wait_and_throw()); 108 109 return CEED_ERROR_SUCCESS; 110 } 111 112 //------------------------------------------------------------------------------ 113 // Sync arrays 114 //------------------------------------------------------------------------------ 115 static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 116 // Check whether device/host sync is needed 117 bool need_sync = false; 118 CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 119 if (!need_sync) return CEED_ERROR_SUCCESS; 120 121 switch (mem_type) { 122 case CEED_MEM_HOST: 123 return CeedVectorSyncD2H_Sycl(vec); 124 case CEED_MEM_DEVICE: 125 return CeedVectorSyncH2D_Sycl(vec); 126 } 127 return CEED_ERROR_UNSUPPORTED; 128 } 129 130 //------------------------------------------------------------------------------ 131 // Set all pointers as invalid 132 //------------------------------------------------------------------------------ 133 static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 134 CeedVector_Sycl *impl; 135 CeedCallBackend(CeedVectorGetData(vec, &impl)); 136 137 impl->h_array = NULL; 138 impl->d_array = NULL; 139 140 return CEED_ERROR_SUCCESS; 141 } 142 143 //------------------------------------------------------------------------------ 144 // Check if CeedVector has any valid pointer 145 //------------------------------------------------------------------------------ 146 static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 147 CeedVector_Sycl *impl; 148 CeedCallBackend(CeedVectorGetData(vec, &impl)); 149 150 *has_valid_array = impl->h_array || impl->d_array; 151 152 return CEED_ERROR_SUCCESS; 153 } 154 155 //------------------------------------------------------------------------------ 156 // Check if has array of given type 157 //------------------------------------------------------------------------------ 158 static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 159 CeedVector_Sycl *impl; 160 CeedCallBackend(CeedVectorGetData(vec, &impl)); 161 162 switch (mem_type) { 163 case CEED_MEM_HOST: 164 *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 165 break; 166 case CEED_MEM_DEVICE: 167 *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 168 break; 169 } 170 171 return CEED_ERROR_SUCCESS; 172 } 173 174 //------------------------------------------------------------------------------ 175 // Check if has borrowed array of given type 176 //------------------------------------------------------------------------------ 177 static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 178 CeedVector_Sycl *impl; 179 CeedCallBackend(CeedVectorGetData(vec, &impl)); 180 181 switch (mem_type) { 182 case CEED_MEM_HOST: 183 *has_borrowed_array_of_type = impl->h_array_borrowed; 184 break; 185 case CEED_MEM_DEVICE: 186 *has_borrowed_array_of_type = impl->d_array_borrowed; 187 break; 188 } 189 190 return CEED_ERROR_SUCCESS; 191 } 192 193 //------------------------------------------------------------------------------ 194 // Set array from host 195 //------------------------------------------------------------------------------ 196 static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 197 CeedVector_Sycl *impl; 198 CeedCallBackend(CeedVectorGetData(vec, &impl)); 199 200 switch (copy_mode) { 201 case CEED_COPY_VALUES: { 202 if (!impl->h_array_owned) { 203 CeedSize length; 204 CeedCallBackend(CeedVectorGetLength(vec, &length)); 205 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 206 } 207 impl->h_array_borrowed = NULL; 208 impl->h_array = impl->h_array_owned; 209 if (array) { 210 CeedSize length; 211 CeedCallBackend(CeedVectorGetLength(vec, &length)); 212 size_t bytes = length * sizeof(CeedScalar); 213 memcpy(impl->h_array, array, bytes); 214 } 215 } break; 216 case CEED_OWN_POINTER: 217 CeedCallBackend(CeedFree(&impl->h_array_owned)); 218 impl->h_array_owned = array; 219 impl->h_array_borrowed = NULL; 220 impl->h_array = array; 221 break; 222 case CEED_USE_POINTER: 223 CeedCallBackend(CeedFree(&impl->h_array_owned)); 224 impl->h_array_borrowed = array; 225 impl->h_array = array; 226 break; 227 } 228 229 return CEED_ERROR_SUCCESS; 230 } 231 232 //------------------------------------------------------------------------------ 233 // Set array from device 234 //------------------------------------------------------------------------------ 235 static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 236 Ceed ceed; 237 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 238 CeedVector_Sycl *impl; 239 CeedCallBackend(CeedVectorGetData(vec, &impl)); 240 Ceed_Sycl *data; 241 CeedCallBackend(CeedGetData(ceed, &data)); 242 243 // Order queue 244 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 245 246 switch (copy_mode) { 247 case CEED_COPY_VALUES: { 248 CeedSize length; 249 CeedCallBackend(CeedVectorGetLength(vec, &length)); 250 if (!impl->d_array_owned) { 251 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 252 } 253 impl->d_array = impl->d_array_owned; 254 if (array) { 255 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 256 // Wait for copy to finish and handle exceptions. 257 CeedCallSycl(ceed, copy_event.wait_and_throw()); 258 } 259 } break; 260 case CEED_OWN_POINTER: 261 if (impl->d_array_owned) { 262 // Wait for all work to finish before freeing memory 263 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 264 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 265 } 266 impl->d_array_owned = array; 267 impl->d_array_borrowed = NULL; 268 impl->d_array = array; 269 break; 270 case CEED_USE_POINTER: 271 if (impl->d_array_owned) { 272 // Wait for all work to finish before freeing memory 273 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 274 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 275 } 276 impl->d_array_owned = NULL; 277 impl->d_array_borrowed = array; 278 impl->d_array = array; 279 break; 280 } 281 282 return CEED_ERROR_SUCCESS; 283 } 284 285 //------------------------------------------------------------------------------ 286 // Set the array used by a vector, 287 // freeing any previously allocated array if applicable 288 //------------------------------------------------------------------------------ 289 static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 290 Ceed ceed; 291 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 292 CeedVector_Sycl *impl; 293 CeedCallBackend(CeedVectorGetData(vec, &impl)); 294 295 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 296 switch (mem_type) { 297 case CEED_MEM_HOST: 298 return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 299 case CEED_MEM_DEVICE: 300 return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 301 } 302 303 return CEED_ERROR_UNSUPPORTED; 304 } 305 306 //------------------------------------------------------------------------------ 307 // Set host array to value 308 //------------------------------------------------------------------------------ 309 static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 310 for (CeedSize i = 0; i < length; i++) h_array[i] = val; 311 return CEED_ERROR_SUCCESS; 312 } 313 314 //------------------------------------------------------------------------------ 315 // Set device array to value 316 //------------------------------------------------------------------------------ 317 static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 318 // Order queue 319 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 320 sycl_queue.fill(d_array, val, length, {e}); 321 return CEED_ERROR_SUCCESS; 322 } 323 324 //------------------------------------------------------------------------------ 325 // Set a vector to a value, 326 //------------------------------------------------------------------------------ 327 static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 328 Ceed ceed; 329 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 330 CeedVector_Sycl *impl; 331 CeedCallBackend(CeedVectorGetData(vec, &impl)); 332 CeedSize length; 333 CeedCallBackend(CeedVectorGetLength(vec, &length)); 334 Ceed_Sycl *data; 335 CeedCallBackend(CeedGetData(ceed, &data)); 336 337 // Set value for synced device/host array 338 if (!impl->d_array && !impl->h_array) { 339 if (impl->d_array_borrowed) { 340 impl->d_array = impl->d_array_borrowed; 341 } else if (impl->h_array_borrowed) { 342 impl->h_array = impl->h_array_borrowed; 343 } else if (impl->d_array_owned) { 344 impl->d_array = impl->d_array_owned; 345 } else if (impl->h_array_owned) { 346 impl->h_array = impl->h_array_owned; 347 } else { 348 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 349 } 350 } 351 if (impl->d_array) { 352 CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 353 impl->h_array = NULL; 354 } 355 if (impl->h_array) { 356 CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 357 impl->d_array = NULL; 358 } 359 360 return CEED_ERROR_SUCCESS; 361 } 362 363 //------------------------------------------------------------------------------ 364 // Vector Take Array 365 //------------------------------------------------------------------------------ 366 static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 367 Ceed ceed; 368 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 369 CeedVector_Sycl *impl; 370 CeedCallBackend(CeedVectorGetData(vec, &impl)); 371 372 Ceed_Sycl *data; 373 CeedCallBackend(CeedGetData(ceed, &data)); 374 375 // Order queue 376 data->sycl_queue.ext_oneapi_submit_barrier(); 377 378 // Sync array to requested mem_type 379 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 380 381 // Update pointer 382 switch (mem_type) { 383 case CEED_MEM_HOST: 384 (*array) = impl->h_array_borrowed; 385 impl->h_array_borrowed = NULL; 386 impl->h_array = NULL; 387 break; 388 case CEED_MEM_DEVICE: 389 (*array) = impl->d_array_borrowed; 390 impl->d_array_borrowed = NULL; 391 impl->d_array = NULL; 392 break; 393 } 394 395 return CEED_ERROR_SUCCESS; 396 } 397 398 //------------------------------------------------------------------------------ 399 // Core logic for array syncronization for GetArray. 400 // If a different memory type is most up to date, this will perform a copy 401 //------------------------------------------------------------------------------ 402 static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 403 Ceed ceed; 404 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 405 CeedVector_Sycl *impl; 406 CeedCallBackend(CeedVectorGetData(vec, &impl)); 407 408 // Sync array to requested mem_type 409 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 410 411 // Update pointer 412 switch (mem_type) { 413 case CEED_MEM_HOST: 414 *array = impl->h_array; 415 break; 416 case CEED_MEM_DEVICE: 417 *array = impl->d_array; 418 break; 419 } 420 421 return CEED_ERROR_SUCCESS; 422 } 423 424 //------------------------------------------------------------------------------ 425 // Get read-only access to a vector via the specified mem_type 426 //------------------------------------------------------------------------------ 427 static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 428 return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 429 } 430 431 //------------------------------------------------------------------------------ 432 // Get read/write access to a vector via the specified mem_type 433 //------------------------------------------------------------------------------ 434 static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 435 CeedVector_Sycl *impl; 436 CeedCallBackend(CeedVectorGetData(vec, &impl)); 437 438 CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 439 440 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 441 switch (mem_type) { 442 case CEED_MEM_HOST: 443 impl->h_array = *array; 444 break; 445 case CEED_MEM_DEVICE: 446 impl->d_array = *array; 447 break; 448 } 449 450 return CEED_ERROR_SUCCESS; 451 } 452 453 //------------------------------------------------------------------------------ 454 // Get write access to a vector via the specified mem_type 455 //------------------------------------------------------------------------------ 456 static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 457 CeedVector_Sycl *impl; 458 CeedCallBackend(CeedVectorGetData(vec, &impl)); 459 460 bool has_array_of_type = true; 461 CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 462 if (!has_array_of_type) { 463 // Allocate if array is not yet allocated 464 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 465 } else { 466 // Select dirty array 467 switch (mem_type) { 468 case CEED_MEM_HOST: 469 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 470 else impl->h_array = impl->h_array_owned; 471 break; 472 case CEED_MEM_DEVICE: 473 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 474 else impl->d_array = impl->d_array_owned; 475 } 476 } 477 478 return CeedVectorGetArray_Sycl(vec, mem_type, array); 479 } 480 481 //------------------------------------------------------------------------------ 482 // Get the norm of a CeedVector 483 //------------------------------------------------------------------------------ 484 static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 485 Ceed ceed; 486 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 487 CeedVector_Sycl *impl; 488 CeedCallBackend(CeedVectorGetData(vec, &impl)); 489 CeedSize length; 490 CeedCallBackend(CeedVectorGetLength(vec, &length)); 491 Ceed_Sycl *data; 492 CeedCallBackend(CeedGetData(ceed, &data)); 493 494 // Compute norm 495 const CeedScalar *d_array; 496 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 497 498 switch (type) { 499 case CEED_NORM_1: { 500 // Order queue 501 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 502 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 503 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 504 } break; 505 case CEED_NORM_2: { 506 // Order queue 507 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 508 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 509 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 510 } break; 511 case CEED_NORM_MAX: { 512 // Order queue 513 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 514 auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 515 data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 516 } break; 517 } 518 // L2 norm - square root over reduced value 519 if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 520 else *norm = *impl->reduction_norm; 521 522 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 523 524 return CEED_ERROR_SUCCESS; 525 } 526 527 //------------------------------------------------------------------------------ 528 // Take reciprocal of a vector on host 529 //------------------------------------------------------------------------------ 530 static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 531 for (CeedSize i = 0; i < length; i++) { 532 if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 533 } 534 return CEED_ERROR_SUCCESS; 535 } 536 537 //------------------------------------------------------------------------------ 538 // Take reciprocal of a vector on device 539 //------------------------------------------------------------------------------ 540 static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 541 // Order queue 542 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 543 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 544 if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 545 }); 546 return CEED_ERROR_SUCCESS; 547 } 548 549 //------------------------------------------------------------------------------ 550 // Take reciprocal of a vector 551 //------------------------------------------------------------------------------ 552 static int CeedVectorReciprocal_Sycl(CeedVector vec) { 553 Ceed ceed; 554 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 555 CeedVector_Sycl *impl; 556 CeedCallBackend(CeedVectorGetData(vec, &impl)); 557 CeedSize length; 558 CeedCallBackend(CeedVectorGetLength(vec, &length)); 559 Ceed_Sycl *data; 560 CeedCallBackend(CeedGetData(ceed, &data)); 561 562 // Set value for synced device/host array 563 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 564 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 565 566 return CEED_ERROR_SUCCESS; 567 } 568 569 //------------------------------------------------------------------------------ 570 // Compute x = alpha x on the host 571 //------------------------------------------------------------------------------ 572 static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 573 for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 574 return CEED_ERROR_SUCCESS; 575 } 576 577 //------------------------------------------------------------------------------ 578 // Compute x = alpha x on device 579 //------------------------------------------------------------------------------ 580 static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 581 // Order queue 582 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 583 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 584 return CEED_ERROR_SUCCESS; 585 } 586 587 //------------------------------------------------------------------------------ 588 // Compute x = alpha x 589 //------------------------------------------------------------------------------ 590 static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 591 Ceed ceed; 592 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 593 CeedVector_Sycl *x_impl; 594 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 595 CeedSize length; 596 CeedCallBackend(CeedVectorGetLength(x, &length)); 597 Ceed_Sycl *data; 598 CeedCallBackend(CeedGetData(ceed, &data)); 599 600 // Set value for synced device/host array 601 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 602 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 603 604 return CEED_ERROR_SUCCESS; 605 } 606 607 //------------------------------------------------------------------------------ 608 // Compute y = alpha x + y on the host 609 //------------------------------------------------------------------------------ 610 static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 611 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 612 return CEED_ERROR_SUCCESS; 613 } 614 615 //------------------------------------------------------------------------------ 616 // Compute y = alpha x + y on device 617 //------------------------------------------------------------------------------ 618 static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 619 // Order queue 620 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 621 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 622 return CEED_ERROR_SUCCESS; 623 } 624 625 //------------------------------------------------------------------------------ 626 // Compute y = alpha x + y 627 //------------------------------------------------------------------------------ 628 static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 629 Ceed ceed; 630 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 631 CeedVector_Sycl *y_impl, *x_impl; 632 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 633 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 634 CeedSize length; 635 CeedCallBackend(CeedVectorGetLength(y, &length)); 636 Ceed_Sycl *data; 637 CeedCallBackend(CeedGetData(ceed, &data)); 638 639 // Set value for synced device/host array 640 if (y_impl->d_array) { 641 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 642 CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 643 } 644 if (y_impl->h_array) { 645 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 646 CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 647 } 648 649 return CEED_ERROR_SUCCESS; 650 } 651 652 //------------------------------------------------------------------------------ 653 // Compute the pointwise multiplication w = x .* y on the host 654 //------------------------------------------------------------------------------ 655 static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 656 for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 657 return CEED_ERROR_SUCCESS; 658 } 659 660 //------------------------------------------------------------------------------ 661 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 662 //------------------------------------------------------------------------------ 663 static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 664 // Order queue 665 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 666 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 667 return CEED_ERROR_SUCCESS; 668 } 669 670 //------------------------------------------------------------------------------ 671 // Compute the pointwise multiplication w = x .* y 672 //------------------------------------------------------------------------------ 673 static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 674 Ceed ceed; 675 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 676 CeedVector_Sycl *w_impl, *x_impl, *y_impl; 677 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 678 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 679 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 680 CeedSize length; 681 CeedCallBackend(CeedVectorGetLength(w, &length)); 682 Ceed_Sycl *data; 683 CeedCallBackend(CeedGetData(ceed, &data)); 684 685 // Set value for synced device/host array 686 if (!w_impl->d_array && !w_impl->h_array) { 687 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 688 } 689 if (w_impl->d_array) { 690 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 691 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 692 CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 693 } 694 if (w_impl->h_array) { 695 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 696 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 697 CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 698 } 699 700 return CEED_ERROR_SUCCESS; 701 } 702 703 //------------------------------------------------------------------------------ 704 // Destroy the vector 705 //------------------------------------------------------------------------------ 706 static int CeedVectorDestroy_Sycl(const CeedVector vec) { 707 Ceed ceed; 708 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 709 CeedVector_Sycl *impl; 710 CeedCallBackend(CeedVectorGetData(vec, &impl)); 711 Ceed_Sycl *data; 712 CeedCallBackend(CeedGetData(ceed, &data)); 713 714 // Wait for all work to finish before freeing memory 715 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 716 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 717 CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 718 719 CeedCallBackend(CeedFree(&impl->h_array_owned)); 720 CeedCallBackend(CeedFree(&impl)); 721 722 return CEED_ERROR_SUCCESS; 723 } 724 725 //------------------------------------------------------------------------------ 726 // Create a vector of the specified length (does not allocate memory) 727 //------------------------------------------------------------------------------ 728 int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 729 CeedVector_Sycl *impl; 730 Ceed ceed; 731 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 732 Ceed_Sycl *data; 733 CeedCallBackend(CeedGetData(ceed, &data)); 734 735 CeedCallBackend(CeedCalloc(1, &impl)); 736 CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 737 738 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 739 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 740 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 741 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 742 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 743 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 744 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 745 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 746 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 747 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 748 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 749 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 750 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 751 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 752 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 753 754 CeedCallBackend(CeedVectorSetData(vec, impl)); 755 756 return CEED_ERROR_SUCCESS; 757 } 758