1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <cmath> 12 #include <string> 13 #include <sycl/sycl.hpp> 14 15 #include "ceed-sycl-ref.hpp" 16 17 //------------------------------------------------------------------------------ 18 // Check if host/device sync is needed 19 //------------------------------------------------------------------------------ 20 static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21 bool has_valid_array = false; 22 CeedVector_Sycl *impl; 23 24 CeedCallBackend(CeedVectorGetData(vec, &impl)); 25 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26 switch (mem_type) { 27 case CEED_MEM_HOST: 28 *need_sync = has_valid_array && !impl->h_array; 29 break; 30 case CEED_MEM_DEVICE: 31 *need_sync = has_valid_array && !impl->d_array; 32 break; 33 } 34 return CEED_ERROR_SUCCESS; 35 } 36 37 //------------------------------------------------------------------------------ 38 // Sync host to device 39 //------------------------------------------------------------------------------ 40 static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 41 Ceed ceed; 42 Ceed_Sycl *data; 43 CeedSize length; 44 CeedVector_Sycl *impl; 45 46 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47 CeedCallBackend(CeedVectorGetData(vec, &impl)); 48 CeedCallBackend(CeedGetData(ceed, &data)); 49 50 if (!impl->h_array) { 51 // LCOV_EXCL_START 52 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 53 // LCOV_EXCL_STOP 54 } 55 56 CeedCallBackend(CeedVectorGetLength(vec, &length)); 57 if (impl->d_array_borrowed) { 58 impl->d_array = impl->d_array_borrowed; 59 } else if (impl->d_array_owned) { 60 impl->d_array = impl->d_array_owned; 61 } else { 62 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 63 impl->d_array = impl->d_array_owned; 64 } 65 66 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 67 // Copy from host to device 68 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 69 // Wait for copy to finish and handle exceptions. 70 CeedCallSycl(ceed, copy_event.wait_and_throw()); 71 return CEED_ERROR_SUCCESS; 72 } 73 74 //------------------------------------------------------------------------------ 75 // Sync device to host 76 //------------------------------------------------------------------------------ 77 static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 78 Ceed ceed; 79 Ceed_Sycl *data; 80 CeedSize length; 81 CeedVector_Sycl *impl; 82 83 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 84 CeedCallBackend(CeedVectorGetData(vec, &impl)); 85 CeedCallBackend(CeedGetData(ceed, &data)); 86 87 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 88 89 CeedCallBackend(CeedVectorGetLength(vec, &length)); 90 if (impl->h_array_borrowed) { 91 impl->h_array = impl->h_array_borrowed; 92 } else if (impl->h_array_owned) { 93 impl->h_array = impl->h_array_owned; 94 } else { 95 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 96 impl->h_array = impl->h_array_owned; 97 } 98 99 // Order queue 100 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 101 // Copy from device to host 102 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 103 // Wait for copy to finish and handle exceptions. 104 CeedCallSycl(ceed, copy_event.wait_and_throw()); 105 return CEED_ERROR_SUCCESS; 106 } 107 108 //------------------------------------------------------------------------------ 109 // Sync arrays 110 //------------------------------------------------------------------------------ 111 static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 112 bool need_sync = false; 113 114 // Check whether device/host sync is needed 115 CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 116 if (!need_sync) return CEED_ERROR_SUCCESS; 117 118 switch (mem_type) { 119 case CEED_MEM_HOST: 120 return CeedVectorSyncD2H_Sycl(vec); 121 case CEED_MEM_DEVICE: 122 return CeedVectorSyncH2D_Sycl(vec); 123 } 124 return CEED_ERROR_UNSUPPORTED; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Set all pointers as invalid 129 //------------------------------------------------------------------------------ 130 static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 131 CeedVector_Sycl *impl; 132 133 CeedCallBackend(CeedVectorGetData(vec, &impl)); 134 impl->h_array = NULL; 135 impl->d_array = NULL; 136 return CEED_ERROR_SUCCESS; 137 } 138 139 //------------------------------------------------------------------------------ 140 // Check if CeedVector has any valid pointer 141 //------------------------------------------------------------------------------ 142 static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 143 CeedVector_Sycl *impl; 144 145 CeedCallBackend(CeedVectorGetData(vec, &impl)); 146 *has_valid_array = impl->h_array || impl->d_array; 147 return CEED_ERROR_SUCCESS; 148 } 149 150 //------------------------------------------------------------------------------ 151 // Check if has array of given type 152 //------------------------------------------------------------------------------ 153 static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 154 CeedVector_Sycl *impl; 155 156 CeedCallBackend(CeedVectorGetData(vec, &impl)); 157 switch (mem_type) { 158 case CEED_MEM_HOST: 159 *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 160 break; 161 case CEED_MEM_DEVICE: 162 *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 163 break; 164 } 165 return CEED_ERROR_SUCCESS; 166 } 167 168 //------------------------------------------------------------------------------ 169 // Check if has borrowed array of given type 170 //------------------------------------------------------------------------------ 171 static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 172 CeedVector_Sycl *impl; 173 174 CeedCallBackend(CeedVectorGetData(vec, &impl)); 175 switch (mem_type) { 176 case CEED_MEM_HOST: 177 *has_borrowed_array_of_type = impl->h_array_borrowed; 178 break; 179 case CEED_MEM_DEVICE: 180 *has_borrowed_array_of_type = impl->d_array_borrowed; 181 break; 182 } 183 return CEED_ERROR_SUCCESS; 184 } 185 186 //------------------------------------------------------------------------------ 187 // Set array from host 188 //------------------------------------------------------------------------------ 189 static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 190 CeedVector_Sycl *impl; 191 192 CeedCallBackend(CeedVectorGetData(vec, &impl)); 193 switch (copy_mode) { 194 case CEED_COPY_VALUES: { 195 if (!impl->h_array_owned) { 196 CeedSize length; 197 198 CeedCallBackend(CeedVectorGetLength(vec, &length)); 199 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 200 } 201 impl->h_array_borrowed = NULL; 202 impl->h_array = impl->h_array_owned; 203 if (array) { 204 CeedSize length; 205 size_t bytes; 206 207 CeedCallBackend(CeedVectorGetLength(vec, &length)); 208 bytes = length * sizeof(CeedScalar); 209 memcpy(impl->h_array, array, bytes); 210 } 211 } break; 212 case CEED_OWN_POINTER: 213 CeedCallBackend(CeedFree(&impl->h_array_owned)); 214 impl->h_array_owned = array; 215 impl->h_array_borrowed = NULL; 216 impl->h_array = array; 217 break; 218 case CEED_USE_POINTER: 219 CeedCallBackend(CeedFree(&impl->h_array_owned)); 220 impl->h_array_borrowed = array; 221 impl->h_array = array; 222 break; 223 } 224 return CEED_ERROR_SUCCESS; 225 } 226 227 //------------------------------------------------------------------------------ 228 // Set array from device 229 //------------------------------------------------------------------------------ 230 static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 231 Ceed ceed; 232 Ceed_Sycl *data; 233 CeedVector_Sycl *impl; 234 235 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 236 CeedCallBackend(CeedVectorGetData(vec, &impl)); 237 CeedCallBackend(CeedGetData(ceed, &data)); 238 239 // Order queue 240 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 241 242 switch (copy_mode) { 243 case CEED_COPY_VALUES: { 244 CeedSize length; 245 246 CeedCallBackend(CeedVectorGetLength(vec, &length)); 247 if (!impl->d_array_owned) { 248 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 249 } 250 impl->d_array = impl->d_array_owned; 251 if (array) { 252 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 253 // Wait for copy to finish and handle exceptions. 254 CeedCallSycl(ceed, copy_event.wait_and_throw()); 255 } 256 } break; 257 case CEED_OWN_POINTER: 258 if (impl->d_array_owned) { 259 // Wait for all work to finish before freeing memory 260 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 261 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 262 } 263 impl->d_array_owned = array; 264 impl->d_array_borrowed = NULL; 265 impl->d_array = array; 266 break; 267 case CEED_USE_POINTER: 268 if (impl->d_array_owned) { 269 // Wait for all work to finish before freeing memory 270 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 271 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 272 } 273 impl->d_array_owned = NULL; 274 impl->d_array_borrowed = array; 275 impl->d_array = array; 276 break; 277 } 278 return CEED_ERROR_SUCCESS; 279 } 280 281 //------------------------------------------------------------------------------ 282 // Set the array used by a vector, 283 // freeing any previously allocated array if applicable 284 //------------------------------------------------------------------------------ 285 static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 286 Ceed ceed; 287 CeedVector_Sycl *impl; 288 289 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 290 CeedCallBackend(CeedVectorGetData(vec, &impl)); 291 292 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 293 switch (mem_type) { 294 case CEED_MEM_HOST: 295 return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 296 case CEED_MEM_DEVICE: 297 return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 298 } 299 return CEED_ERROR_UNSUPPORTED; 300 } 301 302 //------------------------------------------------------------------------------ 303 // Set host array to value 304 //------------------------------------------------------------------------------ 305 static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 306 for (CeedSize i = 0; i < length; i++) h_array[i] = val; 307 return CEED_ERROR_SUCCESS; 308 } 309 310 //------------------------------------------------------------------------------ 311 // Set device array to value 312 //------------------------------------------------------------------------------ 313 static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 314 // Order queue 315 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 316 sycl_queue.fill(d_array, val, length, {e}); 317 return CEED_ERROR_SUCCESS; 318 } 319 320 //------------------------------------------------------------------------------ 321 // Set a vector to a value, 322 //------------------------------------------------------------------------------ 323 static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 324 Ceed ceed; 325 Ceed_Sycl *data; 326 CeedSize length; 327 CeedVector_Sycl *impl; 328 329 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 330 CeedCallBackend(CeedVectorGetData(vec, &impl)); 331 CeedCallBackend(CeedVectorGetLength(vec, &length)); 332 CeedCallBackend(CeedGetData(ceed, &data)); 333 334 // Set value for synced device/host array 335 if (!impl->d_array && !impl->h_array) { 336 if (impl->d_array_borrowed) { 337 impl->d_array = impl->d_array_borrowed; 338 } else if (impl->h_array_borrowed) { 339 impl->h_array = impl->h_array_borrowed; 340 } else if (impl->d_array_owned) { 341 impl->d_array = impl->d_array_owned; 342 } else if (impl->h_array_owned) { 343 impl->h_array = impl->h_array_owned; 344 } else { 345 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 346 } 347 } 348 if (impl->d_array) { 349 CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 350 impl->h_array = NULL; 351 } 352 if (impl->h_array) { 353 CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 354 impl->d_array = NULL; 355 } 356 return CEED_ERROR_SUCCESS; 357 } 358 359 //------------------------------------------------------------------------------ 360 // Vector Take Array 361 //------------------------------------------------------------------------------ 362 static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 363 Ceed ceed; 364 Ceed_Sycl *data; 365 CeedVector_Sycl *impl; 366 367 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 368 CeedCallBackend(CeedVectorGetData(vec, &impl)); 369 CeedCallBackend(CeedGetData(ceed, &data)); 370 371 // Order queue 372 data->sycl_queue.ext_oneapi_submit_barrier(); 373 374 // Sync array to requested mem_type 375 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 376 377 // Update pointer 378 switch (mem_type) { 379 case CEED_MEM_HOST: 380 (*array) = impl->h_array_borrowed; 381 impl->h_array_borrowed = NULL; 382 impl->h_array = NULL; 383 break; 384 case CEED_MEM_DEVICE: 385 (*array) = impl->d_array_borrowed; 386 impl->d_array_borrowed = NULL; 387 impl->d_array = NULL; 388 break; 389 } 390 return CEED_ERROR_SUCCESS; 391 } 392 393 //------------------------------------------------------------------------------ 394 // Core logic for array syncronization for GetArray. 395 // If a different memory type is most up to date, this will perform a copy 396 //------------------------------------------------------------------------------ 397 static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 398 Ceed ceed; 399 CeedVector_Sycl *impl; 400 401 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 402 CeedCallBackend(CeedVectorGetData(vec, &impl)); 403 404 // Sync array to requested mem_type 405 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 406 407 // Update pointer 408 switch (mem_type) { 409 case CEED_MEM_HOST: 410 *array = impl->h_array; 411 break; 412 case CEED_MEM_DEVICE: 413 *array = impl->d_array; 414 break; 415 } 416 return CEED_ERROR_SUCCESS; 417 } 418 419 //------------------------------------------------------------------------------ 420 // Get read-only access to a vector via the specified mem_type 421 //------------------------------------------------------------------------------ 422 static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 423 return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 424 } 425 426 //------------------------------------------------------------------------------ 427 // Get read/write access to a vector via the specified mem_type 428 //------------------------------------------------------------------------------ 429 static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 430 CeedVector_Sycl *impl; 431 432 CeedCallBackend(CeedVectorGetData(vec, &impl)); 433 CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 434 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 435 switch (mem_type) { 436 case CEED_MEM_HOST: 437 impl->h_array = *array; 438 break; 439 case CEED_MEM_DEVICE: 440 impl->d_array = *array; 441 break; 442 } 443 return CEED_ERROR_SUCCESS; 444 } 445 446 //------------------------------------------------------------------------------ 447 // Get write access to a vector via the specified mem_type 448 //------------------------------------------------------------------------------ 449 static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 450 bool has_array_of_type = true; 451 CeedVector_Sycl *impl; 452 453 CeedCallBackend(CeedVectorGetData(vec, &impl)); 454 CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 455 if (!has_array_of_type) { 456 // Allocate if array is not yet allocated 457 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 458 } else { 459 // Select dirty array 460 switch (mem_type) { 461 case CEED_MEM_HOST: 462 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 463 else impl->h_array = impl->h_array_owned; 464 break; 465 case CEED_MEM_DEVICE: 466 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 467 else impl->d_array = impl->d_array_owned; 468 } 469 } 470 return CeedVectorGetArray_Sycl(vec, mem_type, array); 471 } 472 473 //------------------------------------------------------------------------------ 474 // Get the norm of a CeedVector 475 //------------------------------------------------------------------------------ 476 static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 477 Ceed ceed; 478 Ceed_Sycl *data; 479 CeedSize length; 480 const CeedScalar *d_array; 481 CeedVector_Sycl *impl; 482 483 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 484 CeedCallBackend(CeedVectorGetData(vec, &impl)); 485 CeedCallBackend(CeedVectorGetLength(vec, &length)); 486 CeedCallBackend(CeedGetData(ceed, &data)); 487 488 // Compute norm 489 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 490 switch (type) { 491 case CEED_NORM_1: { 492 // Order queue 493 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 494 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 495 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 496 } break; 497 case CEED_NORM_2: { 498 // Order queue 499 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 500 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 501 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 502 } break; 503 case CEED_NORM_MAX: { 504 // Order queue 505 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 506 auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 507 data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 508 } break; 509 } 510 // L2 norm - square root over reduced value 511 if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 512 else *norm = *impl->reduction_norm; 513 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 514 return CEED_ERROR_SUCCESS; 515 } 516 517 //------------------------------------------------------------------------------ 518 // Take reciprocal of a vector on host 519 //------------------------------------------------------------------------------ 520 static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 521 for (CeedSize i = 0; i < length; i++) { 522 if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 523 } 524 return CEED_ERROR_SUCCESS; 525 } 526 527 //------------------------------------------------------------------------------ 528 // Take reciprocal of a vector on device 529 //------------------------------------------------------------------------------ 530 static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 531 // Order queue 532 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 533 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 534 if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 535 }); 536 return CEED_ERROR_SUCCESS; 537 } 538 539 //------------------------------------------------------------------------------ 540 // Take reciprocal of a vector 541 //------------------------------------------------------------------------------ 542 static int CeedVectorReciprocal_Sycl(CeedVector vec) { 543 Ceed ceed; 544 Ceed_Sycl *data; 545 CeedSize length; 546 CeedVector_Sycl *impl; 547 548 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 549 CeedCallBackend(CeedVectorGetData(vec, &impl)); 550 CeedCallBackend(CeedVectorGetLength(vec, &length)); 551 CeedCallBackend(CeedGetData(ceed, &data)); 552 553 // Set value for synced device/host array 554 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 555 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 556 return CEED_ERROR_SUCCESS; 557 } 558 559 //------------------------------------------------------------------------------ 560 // Compute x = alpha x on the host 561 //------------------------------------------------------------------------------ 562 static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 563 for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 564 return CEED_ERROR_SUCCESS; 565 } 566 567 //------------------------------------------------------------------------------ 568 // Compute x = alpha x on device 569 //------------------------------------------------------------------------------ 570 static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 571 // Order queue 572 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 573 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 574 return CEED_ERROR_SUCCESS; 575 } 576 577 //------------------------------------------------------------------------------ 578 // Compute x = alpha x 579 //------------------------------------------------------------------------------ 580 static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 581 Ceed ceed; 582 Ceed_Sycl *data; 583 CeedSize length; 584 CeedVector_Sycl *x_impl; 585 586 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 587 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 588 CeedCallBackend(CeedVectorGetLength(x, &length)); 589 CeedCallBackend(CeedGetData(ceed, &data)); 590 591 // Set value for synced device/host array 592 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 593 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 594 return CEED_ERROR_SUCCESS; 595 } 596 597 //------------------------------------------------------------------------------ 598 // Compute y = alpha x + y on the host 599 //------------------------------------------------------------------------------ 600 static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 601 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 602 return CEED_ERROR_SUCCESS; 603 } 604 605 //------------------------------------------------------------------------------ 606 // Compute y = alpha x + y on device 607 //------------------------------------------------------------------------------ 608 static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 609 // Order queue 610 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 611 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 612 return CEED_ERROR_SUCCESS; 613 } 614 615 //------------------------------------------------------------------------------ 616 // Compute y = alpha x + y 617 //------------------------------------------------------------------------------ 618 static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 619 Ceed ceed; 620 Ceed_Sycl *data; 621 CeedSize length; 622 CeedVector_Sycl *y_impl, *x_impl; 623 624 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 625 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 626 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 627 CeedCallBackend(CeedVectorGetLength(y, &length)); 628 CeedCallBackend(CeedGetData(ceed, &data)); 629 630 // Set value for synced device/host array 631 if (y_impl->d_array) { 632 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 633 CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 634 } 635 if (y_impl->h_array) { 636 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 637 CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 638 } 639 return CEED_ERROR_SUCCESS; 640 } 641 642 //------------------------------------------------------------------------------ 643 // Compute the pointwise multiplication w = x .* y on the host 644 //------------------------------------------------------------------------------ 645 static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 646 for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 647 return CEED_ERROR_SUCCESS; 648 } 649 650 //------------------------------------------------------------------------------ 651 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 652 //------------------------------------------------------------------------------ 653 static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 654 // Order queue 655 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 656 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 657 return CEED_ERROR_SUCCESS; 658 } 659 660 //------------------------------------------------------------------------------ 661 // Compute the pointwise multiplication w = x .* y 662 //------------------------------------------------------------------------------ 663 static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 664 Ceed ceed; 665 Ceed_Sycl *data; 666 CeedSize length; 667 CeedVector_Sycl *w_impl, *x_impl, *y_impl; 668 669 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 670 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 671 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 672 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 673 CeedCallBackend(CeedVectorGetLength(w, &length)); 674 CeedCallBackend(CeedGetData(ceed, &data)); 675 676 // Set value for synced device/host array 677 if (!w_impl->d_array && !w_impl->h_array) { 678 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 679 } 680 if (w_impl->d_array) { 681 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 682 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 683 CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 684 } 685 if (w_impl->h_array) { 686 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 687 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 688 CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 689 } 690 return CEED_ERROR_SUCCESS; 691 } 692 693 //------------------------------------------------------------------------------ 694 // Destroy the vector 695 //------------------------------------------------------------------------------ 696 static int CeedVectorDestroy_Sycl(const CeedVector vec) { 697 Ceed ceed; 698 Ceed_Sycl *data; 699 CeedVector_Sycl *impl; 700 701 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 702 CeedCallBackend(CeedVectorGetData(vec, &impl)); 703 CeedCallBackend(CeedGetData(ceed, &data)); 704 705 // Wait for all work to finish before freeing memory 706 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 707 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 708 CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 709 710 CeedCallBackend(CeedFree(&impl->h_array_owned)); 711 CeedCallBackend(CeedFree(&impl)); 712 return CEED_ERROR_SUCCESS; 713 } 714 715 //------------------------------------------------------------------------------ 716 // Create a vector of the specified length (does not allocate memory) 717 //------------------------------------------------------------------------------ 718 int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 719 Ceed ceed; 720 Ceed_Sycl *data; 721 CeedVector_Sycl *impl; 722 723 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 724 CeedCallBackend(CeedGetData(ceed, &data)); 725 CeedCallBackend(CeedCalloc(1, &impl)); 726 CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 727 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 728 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 729 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 730 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 731 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 732 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 733 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 734 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 735 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 736 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 737 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 738 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 739 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 740 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 741 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 742 CeedCallBackend(CeedVectorSetData(vec, impl)); 743 return CEED_ERROR_SUCCESS; 744 } 745