1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <cmath> 12 #include <string> 13 #include <sycl/sycl.hpp> 14 15 #include "ceed-sycl-ref.hpp" 16 17 //------------------------------------------------------------------------------ 18 // Check if host/device sync is needed 19 //------------------------------------------------------------------------------ 20 static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21 CeedVector_Sycl *impl; 22 CeedCallBackend(CeedVectorGetData(vec, &impl)); 23 24 bool has_valid_array = false; 25 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26 switch (mem_type) { 27 case CEED_MEM_HOST: 28 *need_sync = has_valid_array && !impl->h_array; 29 break; 30 case CEED_MEM_DEVICE: 31 *need_sync = has_valid_array && !impl->d_array; 32 break; 33 } 34 35 return CEED_ERROR_SUCCESS; 36 } 37 38 //------------------------------------------------------------------------------ 39 // Sync host to device 40 //------------------------------------------------------------------------------ 41 static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 42 Ceed ceed; 43 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 44 CeedVector_Sycl *impl; 45 CeedCallBackend(CeedVectorGetData(vec, &impl)); 46 Ceed_Sycl *data; 47 CeedCallBackend(CeedGetData(ceed, &data)); 48 49 if (!impl->h_array) { 50 // LCOV_EXCL_START 51 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 52 // LCOV_EXCL_STOP 53 } 54 55 CeedSize length; 56 CeedCallBackend(CeedVectorGetLength(vec, &length)); 57 58 if (impl->d_array_borrowed) { 59 impl->d_array = impl->d_array_borrowed; 60 } else if (impl->d_array_owned) { 61 impl->d_array = impl->d_array_owned; 62 } else { 63 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 64 impl->d_array = impl->d_array_owned; 65 } 66 67 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 68 // Copy from host to device 69 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 70 // Wait for copy to finish and handle exceptions. 71 CeedCallSycl(ceed, copy_event.wait_and_throw()); 72 73 return CEED_ERROR_SUCCESS; 74 } 75 76 //------------------------------------------------------------------------------ 77 // Sync device to host 78 //------------------------------------------------------------------------------ 79 static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 80 Ceed ceed; 81 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 82 CeedVector_Sycl *impl; 83 CeedCallBackend(CeedVectorGetData(vec, &impl)); 84 Ceed_Sycl *data; 85 CeedCallBackend(CeedGetData(ceed, &data)); 86 87 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 88 89 CeedSize length; 90 CeedCallBackend(CeedVectorGetLength(vec, &length)); 91 92 if (impl->h_array_borrowed) { 93 impl->h_array = impl->h_array_borrowed; 94 } else if (impl->h_array_owned) { 95 impl->h_array = impl->h_array_owned; 96 } else { 97 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 98 impl->h_array = impl->h_array_owned; 99 } 100 101 // Order queue 102 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 103 // Copy from device to host 104 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 105 106 // Wait for copy to finish and handle exceptions. 107 CeedCallSycl(ceed, copy_event.wait_and_throw()); 108 109 return CEED_ERROR_SUCCESS; 110 } 111 112 //------------------------------------------------------------------------------ 113 // Sync arrays 114 //------------------------------------------------------------------------------ 115 static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 116 // Check whether device/host sync is needed 117 bool need_sync = false; 118 CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 119 if (!need_sync) return CEED_ERROR_SUCCESS; 120 121 switch (mem_type) { 122 case CEED_MEM_HOST: 123 return CeedVectorSyncD2H_Sycl(vec); 124 case CEED_MEM_DEVICE: 125 return CeedVectorSyncH2D_Sycl(vec); 126 } 127 return CEED_ERROR_UNSUPPORTED; 128 } 129 130 //------------------------------------------------------------------------------ 131 // Set all pointers as invalid 132 //------------------------------------------------------------------------------ 133 static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 134 CeedVector_Sycl *impl; 135 CeedCallBackend(CeedVectorGetData(vec, &impl)); 136 137 impl->h_array = NULL; 138 impl->d_array = NULL; 139 140 return CEED_ERROR_SUCCESS; 141 } 142 143 //------------------------------------------------------------------------------ 144 // Check if CeedVector has any valid pointer 145 //------------------------------------------------------------------------------ 146 static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 147 CeedVector_Sycl *impl; 148 CeedCallBackend(CeedVectorGetData(vec, &impl)); 149 150 *has_valid_array = !!impl->h_array || !!impl->d_array; 151 152 return CEED_ERROR_SUCCESS; 153 } 154 155 //------------------------------------------------------------------------------ 156 // Check if has array of given type 157 //------------------------------------------------------------------------------ 158 static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 159 CeedVector_Sycl *impl; 160 CeedCallBackend(CeedVectorGetData(vec, &impl)); 161 162 switch (mem_type) { 163 case CEED_MEM_HOST: 164 *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 165 break; 166 case CEED_MEM_DEVICE: 167 *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 168 break; 169 } 170 171 return CEED_ERROR_SUCCESS; 172 } 173 174 //------------------------------------------------------------------------------ 175 // Check if has borrowed array of given type 176 //------------------------------------------------------------------------------ 177 static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 178 CeedVector_Sycl *impl; 179 CeedCallBackend(CeedVectorGetData(vec, &impl)); 180 181 switch (mem_type) { 182 case CEED_MEM_HOST: 183 *has_borrowed_array_of_type = !!impl->h_array_borrowed; 184 break; 185 case CEED_MEM_DEVICE: 186 *has_borrowed_array_of_type = !!impl->d_array_borrowed; 187 break; 188 } 189 190 return CEED_ERROR_SUCCESS; 191 } 192 193 //------------------------------------------------------------------------------ 194 // Set array from host 195 //------------------------------------------------------------------------------ 196 static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 197 CeedVector_Sycl *impl; 198 CeedCallBackend(CeedVectorGetData(vec, &impl)); 199 200 switch (copy_mode) { 201 case CEED_COPY_VALUES: { 202 if (!impl->h_array_owned) { 203 CeedSize length; 204 CeedCallBackend(CeedVectorGetLength(vec, &length)); 205 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 206 } 207 impl->h_array_borrowed = NULL; 208 impl->h_array = impl->h_array_owned; 209 if (array) { 210 CeedSize length; 211 CeedCallBackend(CeedVectorGetLength(vec, &length)); 212 size_t bytes = length * sizeof(CeedScalar); 213 memcpy(impl->h_array, array, bytes); 214 } 215 } break; 216 case CEED_OWN_POINTER: 217 CeedCallBackend(CeedFree(&impl->h_array_owned)); 218 impl->h_array_owned = array; 219 impl->h_array_borrowed = NULL; 220 impl->h_array = array; 221 break; 222 case CEED_USE_POINTER: 223 CeedCallBackend(CeedFree(&impl->h_array_owned)); 224 impl->h_array_borrowed = array; 225 impl->h_array = array; 226 break; 227 } 228 229 return CEED_ERROR_SUCCESS; 230 } 231 232 //------------------------------------------------------------------------------ 233 // Set array from device 234 //------------------------------------------------------------------------------ 235 static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 236 Ceed ceed; 237 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 238 CeedVector_Sycl *impl; 239 CeedCallBackend(CeedVectorGetData(vec, &impl)); 240 Ceed_Sycl *data; 241 CeedCallBackend(CeedGetData(ceed, &data)); 242 243 // Order queue 244 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 245 246 switch (copy_mode) { 247 case CEED_COPY_VALUES: { 248 CeedSize length; 249 CeedCallBackend(CeedVectorGetLength(vec, &length)); 250 if (!impl->d_array_owned) { 251 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 252 impl->d_array = impl->d_array_owned; 253 } 254 if (array) { 255 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 256 // Wait for copy to finish and handle exceptions. 257 CeedCallSycl(ceed, copy_event.wait_and_throw()); 258 } 259 } break; 260 case CEED_OWN_POINTER: 261 if (impl->d_array_owned) { 262 // Wait for all work to finish before freeing memory 263 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 264 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 265 } 266 impl->d_array_owned = array; 267 impl->d_array_borrowed = NULL; 268 impl->d_array = array; 269 break; 270 case CEED_USE_POINTER: 271 if (impl->d_array_owned) { 272 // Wait for all work to finish before freeing memory 273 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 274 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 275 } 276 impl->d_array_owned = NULL; 277 impl->d_array_borrowed = array; 278 impl->d_array = array; 279 break; 280 } 281 282 return CEED_ERROR_SUCCESS; 283 } 284 285 //------------------------------------------------------------------------------ 286 // Set the array used by a vector, 287 // freeing any previously allocated array if applicable 288 //------------------------------------------------------------------------------ 289 static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 290 Ceed ceed; 291 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 292 CeedVector_Sycl *impl; 293 CeedCallBackend(CeedVectorGetData(vec, &impl)); 294 295 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 296 switch (mem_type) { 297 case CEED_MEM_HOST: 298 return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 299 case CEED_MEM_DEVICE: 300 return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 301 } 302 303 return CEED_ERROR_UNSUPPORTED; 304 } 305 306 //------------------------------------------------------------------------------ 307 // Set host array to value 308 //------------------------------------------------------------------------------ 309 static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedInt length, CeedScalar val) { 310 for (int i = 0; i < length; i++) h_array[i] = val; 311 return CEED_ERROR_SUCCESS; 312 } 313 314 //------------------------------------------------------------------------------ 315 // Set device array to value 316 //------------------------------------------------------------------------------ 317 static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length, CeedScalar val) { 318 // Order queue 319 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 320 sycl_queue.fill(d_array, val, length, {e}); 321 return CEED_ERROR_SUCCESS; 322 } 323 324 //------------------------------------------------------------------------------ 325 // Set a vector to a value, 326 //------------------------------------------------------------------------------ 327 static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 328 Ceed ceed; 329 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 330 CeedVector_Sycl *impl; 331 CeedCallBackend(CeedVectorGetData(vec, &impl)); 332 CeedSize length; 333 CeedCallBackend(CeedVectorGetLength(vec, &length)); 334 Ceed_Sycl *data; 335 CeedCallBackend(CeedGetData(ceed, &data)); 336 337 // Set value for synced device/host array 338 if (!impl->d_array && !impl->h_array) { 339 if (impl->d_array_borrowed) { 340 impl->d_array = impl->d_array_borrowed; 341 } else if (impl->h_array_borrowed) { 342 impl->h_array = impl->h_array_borrowed; 343 } else if (impl->d_array_owned) { 344 impl->d_array = impl->d_array_owned; 345 } else if (impl->h_array_owned) { 346 impl->h_array = impl->h_array_owned; 347 } else { 348 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 349 } 350 } 351 if (impl->d_array) { 352 CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 353 impl->h_array = NULL; 354 } 355 if (impl->h_array) { 356 CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 357 impl->d_array = NULL; 358 } 359 360 return CEED_ERROR_SUCCESS; 361 } 362 363 //------------------------------------------------------------------------------ 364 // Vector Take Array 365 //------------------------------------------------------------------------------ 366 static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 367 Ceed ceed; 368 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 369 CeedVector_Sycl *impl; 370 CeedCallBackend(CeedVectorGetData(vec, &impl)); 371 372 Ceed_Sycl *data; 373 CeedCallBackend(CeedGetData(ceed, &data)); 374 375 // Order queue 376 data->sycl_queue.ext_oneapi_submit_barrier(); 377 378 // Sync array to requested mem_type 379 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 380 381 // Update pointer 382 switch (mem_type) { 383 case CEED_MEM_HOST: 384 (*array) = impl->h_array_borrowed; 385 impl->h_array_borrowed = NULL; 386 impl->h_array = NULL; 387 break; 388 case CEED_MEM_DEVICE: 389 (*array) = impl->d_array_borrowed; 390 impl->d_array_borrowed = NULL; 391 impl->d_array = NULL; 392 break; 393 } 394 395 return CEED_ERROR_SUCCESS; 396 } 397 398 //------------------------------------------------------------------------------ 399 // Core logic for array syncronization for GetArray. 400 // If a different memory type is most up to date, this will perform a copy 401 //------------------------------------------------------------------------------ 402 static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 403 Ceed ceed; 404 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 405 CeedVector_Sycl *impl; 406 CeedCallBackend(CeedVectorGetData(vec, &impl)); 407 408 // Sync array to requested mem_type 409 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 410 411 // Update pointer 412 switch (mem_type) { 413 case CEED_MEM_HOST: 414 *array = impl->h_array; 415 break; 416 case CEED_MEM_DEVICE: 417 *array = impl->d_array; 418 break; 419 } 420 421 return CEED_ERROR_SUCCESS; 422 } 423 //------------------------------------------------------------------------------ 424 // Get read-only access to a vector via the specified mem_type 425 //------------------------------------------------------------------------------ 426 static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 427 return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 428 } 429 430 //------------------------------------------------------------------------------ 431 // Get read/write access to a vector via the specified mem_type 432 //------------------------------------------------------------------------------ 433 static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 434 CeedVector_Sycl *impl; 435 CeedCallBackend(CeedVectorGetData(vec, &impl)); 436 437 CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 438 439 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 440 switch (mem_type) { 441 case CEED_MEM_HOST: 442 impl->h_array = *array; 443 break; 444 case CEED_MEM_DEVICE: 445 impl->d_array = *array; 446 break; 447 } 448 449 return CEED_ERROR_SUCCESS; 450 } 451 452 //------------------------------------------------------------------------------ 453 // Get write access to a vector via the specified mem_type 454 //------------------------------------------------------------------------------ 455 static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 456 CeedVector_Sycl *impl; 457 CeedCallBackend(CeedVectorGetData(vec, &impl)); 458 459 bool has_array_of_type = true; 460 CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 461 if (!has_array_of_type) { 462 // Allocate if array is not yet allocated 463 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 464 } else { 465 // Select dirty array 466 switch (mem_type) { 467 case CEED_MEM_HOST: 468 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 469 else impl->h_array = impl->h_array_owned; 470 break; 471 case CEED_MEM_DEVICE: 472 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 473 else impl->d_array = impl->d_array_owned; 474 } 475 } 476 477 return CeedVectorGetArray_Sycl(vec, mem_type, array); 478 } 479 480 //------------------------------------------------------------------------------ 481 // Get the norm of a CeedVector 482 //------------------------------------------------------------------------------ 483 static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 484 Ceed ceed; 485 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 486 CeedVector_Sycl *impl; 487 CeedCallBackend(CeedVectorGetData(vec, &impl)); 488 CeedSize length; 489 CeedCallBackend(CeedVectorGetLength(vec, &length)); 490 Ceed_Sycl *data; 491 CeedCallBackend(CeedGetData(ceed, &data)); 492 493 // Compute norm 494 const CeedScalar *d_array; 495 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 496 497 switch (type) { 498 case CEED_NORM_1: { 499 // Order queue 500 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 501 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 502 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 503 } break; 504 case CEED_NORM_2: { 505 // Order queue 506 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 507 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 508 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 509 } break; 510 case CEED_NORM_MAX: { 511 // Order queue 512 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 513 auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 514 data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 515 } break; 516 } 517 // L2 norm - square root over reduced value 518 if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 519 else *norm = *impl->reduction_norm; 520 521 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 522 523 return CEED_ERROR_SUCCESS; 524 } 525 526 //------------------------------------------------------------------------------ 527 // Take reciprocal of a vector on host 528 //------------------------------------------------------------------------------ 529 static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedInt length) { 530 for (int i = 0; i < length; i++) { 531 if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 532 } 533 return CEED_ERROR_SUCCESS; 534 } 535 536 //------------------------------------------------------------------------------ 537 // Take reciprocal of a vector on device 538 //------------------------------------------------------------------------------ 539 static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedInt length) { 540 // Order queue 541 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 542 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 543 if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 544 }); 545 return CEED_ERROR_SUCCESS; 546 } 547 548 //------------------------------------------------------------------------------ 549 // Take reciprocal of a vector 550 //------------------------------------------------------------------------------ 551 static int CeedVectorReciprocal_Sycl(CeedVector vec) { 552 Ceed ceed; 553 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 554 CeedVector_Sycl *impl; 555 CeedCallBackend(CeedVectorGetData(vec, &impl)); 556 CeedSize length; 557 CeedCallBackend(CeedVectorGetLength(vec, &length)); 558 Ceed_Sycl *data; 559 CeedCallBackend(CeedGetData(ceed, &data)); 560 561 // Set value for synced device/host array 562 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 563 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 564 565 return CEED_ERROR_SUCCESS; 566 } 567 568 //------------------------------------------------------------------------------ 569 // Compute x = alpha x on the host 570 //------------------------------------------------------------------------------ 571 static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 572 for (int i = 0; i < length; i++) x_array[i] *= alpha; 573 return CEED_ERROR_SUCCESS; 574 } 575 576 //------------------------------------------------------------------------------ 577 // Compute x = alpha x on device 578 //------------------------------------------------------------------------------ 579 static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 580 // Order queue 581 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 582 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 583 return CEED_ERROR_SUCCESS; 584 } 585 586 //------------------------------------------------------------------------------ 587 // Compute x = alpha x 588 //------------------------------------------------------------------------------ 589 static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 590 Ceed ceed; 591 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 592 CeedVector_Sycl *x_impl; 593 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 594 CeedSize length; 595 CeedCallBackend(CeedVectorGetLength(x, &length)); 596 Ceed_Sycl *data; 597 CeedCallBackend(CeedGetData(ceed, &data)); 598 599 // Set value for synced device/host array 600 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 601 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 602 603 return CEED_ERROR_SUCCESS; 604 } 605 606 //------------------------------------------------------------------------------ 607 // Compute y = alpha x + y on the host 608 //------------------------------------------------------------------------------ 609 static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 610 for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 611 return CEED_ERROR_SUCCESS; 612 } 613 614 //------------------------------------------------------------------------------ 615 // Compute y = alpha x + y on device 616 //------------------------------------------------------------------------------ 617 static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 618 // Order queue 619 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 620 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 621 return CEED_ERROR_SUCCESS; 622 } 623 624 //------------------------------------------------------------------------------ 625 // Compute y = alpha x + y 626 //------------------------------------------------------------------------------ 627 static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 628 Ceed ceed; 629 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 630 CeedVector_Sycl *y_impl, *x_impl; 631 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 632 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 633 CeedSize length; 634 CeedCallBackend(CeedVectorGetLength(y, &length)); 635 Ceed_Sycl *data; 636 CeedCallBackend(CeedGetData(ceed, &data)); 637 638 // Set value for synced device/host array 639 if (y_impl->d_array) { 640 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 641 CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 642 } 643 if (y_impl->h_array) { 644 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 645 CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 646 } 647 648 return CEED_ERROR_SUCCESS; 649 } 650 651 //------------------------------------------------------------------------------ 652 // Compute the pointwise multiplication w = x .* y on the host 653 //------------------------------------------------------------------------------ 654 static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 655 for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 656 return CEED_ERROR_SUCCESS; 657 } 658 659 //------------------------------------------------------------------------------ 660 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 661 //------------------------------------------------------------------------------ 662 static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 663 // Order queue 664 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 665 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 666 return CEED_ERROR_SUCCESS; 667 } 668 669 //------------------------------------------------------------------------------ 670 // Compute the pointwise multiplication w = x .* y 671 //------------------------------------------------------------------------------ 672 static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 673 Ceed ceed; 674 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 675 CeedVector_Sycl *w_impl, *x_impl, *y_impl; 676 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 677 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 678 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 679 CeedSize length; 680 CeedCallBackend(CeedVectorGetLength(w, &length)); 681 Ceed_Sycl *data; 682 CeedCallBackend(CeedGetData(ceed, &data)); 683 684 // Set value for synced device/host array 685 if (!w_impl->d_array && !w_impl->h_array) { 686 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 687 } 688 if (w_impl->d_array) { 689 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 690 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 691 CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 692 } 693 if (w_impl->h_array) { 694 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 695 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 696 CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 697 } 698 699 return CEED_ERROR_SUCCESS; 700 } 701 702 //------------------------------------------------------------------------------ 703 // Destroy the vector 704 //------------------------------------------------------------------------------ 705 static int CeedVectorDestroy_Sycl(const CeedVector vec) { 706 Ceed ceed; 707 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 708 CeedVector_Sycl *impl; 709 CeedCallBackend(CeedVectorGetData(vec, &impl)); 710 Ceed_Sycl *data; 711 CeedCallBackend(CeedGetData(ceed, &data)); 712 713 // Wait for all work to finish before freeing memory 714 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 715 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 716 CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 717 718 CeedCallBackend(CeedFree(&impl->h_array_owned)); 719 CeedCallBackend(CeedFree(&impl)); 720 721 return CEED_ERROR_SUCCESS; 722 } 723 724 //------------------------------------------------------------------------------ 725 // Create a vector of the specified length (does not allocate memory) 726 //------------------------------------------------------------------------------ 727 int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 728 CeedVector_Sycl *impl; 729 Ceed ceed; 730 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 731 Ceed_Sycl *data; 732 CeedCallBackend(CeedGetData(ceed, &data)); 733 734 CeedCallBackend(CeedCalloc(1, &impl)); 735 CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 736 737 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 738 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 739 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 740 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 741 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 742 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 743 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 744 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 745 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 746 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 747 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 748 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 749 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 750 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 751 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 752 753 CeedCallBackend(CeedVectorSetData(vec, impl)); 754 755 return CEED_ERROR_SUCCESS; 756 } 757