1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <cmath> 12 #include <string> 13 #include <sycl/sycl.hpp> 14 15 #include "ceed-sycl-ref.hpp" 16 17 //------------------------------------------------------------------------------ 18 // Check if host/device sync is needed 19 //------------------------------------------------------------------------------ 20 static inline int CeedVectorNeedSync_Sycl(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 21 bool has_valid_array = false; 22 CeedVector_Sycl *impl; 23 24 CeedCallBackend(CeedVectorGetData(vec, &impl)); 25 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 26 switch (mem_type) { 27 case CEED_MEM_HOST: 28 *need_sync = has_valid_array && !impl->h_array; 29 break; 30 case CEED_MEM_DEVICE: 31 *need_sync = has_valid_array && !impl->d_array; 32 break; 33 } 34 return CEED_ERROR_SUCCESS; 35 } 36 37 //------------------------------------------------------------------------------ 38 // Sync host to device 39 //------------------------------------------------------------------------------ 40 static inline int CeedVectorSyncH2D_Sycl(const CeedVector vec) { 41 Ceed ceed; 42 Ceed_Sycl *data; 43 CeedSize length; 44 CeedVector_Sycl *impl; 45 46 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47 CeedCallBackend(CeedVectorGetData(vec, &impl)); 48 CeedCallBackend(CeedGetData(ceed, &data)); 49 CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 50 51 CeedCallBackend(CeedVectorGetLength(vec, &length)); 52 if (impl->d_array_borrowed) { 53 impl->d_array = impl->d_array_borrowed; 54 } else if (impl->d_array_owned) { 55 impl->d_array = impl->d_array_owned; 56 } else { 57 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 58 impl->d_array = impl->d_array_owned; 59 } 60 61 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 62 // Copy from host to device 63 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->h_array, impl->d_array, length, {e}); 64 // Wait for copy to finish and handle exceptions. 65 CeedCallSycl(ceed, copy_event.wait_and_throw()); 66 return CEED_ERROR_SUCCESS; 67 } 68 69 //------------------------------------------------------------------------------ 70 // Sync device to host 71 //------------------------------------------------------------------------------ 72 static inline int CeedVectorSyncD2H_Sycl(const CeedVector vec) { 73 Ceed ceed; 74 Ceed_Sycl *data; 75 CeedSize length; 76 CeedVector_Sycl *impl; 77 78 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 79 CeedCallBackend(CeedVectorGetData(vec, &impl)); 80 CeedCallBackend(CeedGetData(ceed, &data)); 81 82 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 83 84 CeedCallBackend(CeedVectorGetLength(vec, &length)); 85 if (impl->h_array_borrowed) { 86 impl->h_array = impl->h_array_borrowed; 87 } else if (impl->h_array_owned) { 88 impl->h_array = impl->h_array_owned; 89 } else { 90 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 91 impl->h_array = impl->h_array_owned; 92 } 93 94 // Order queue 95 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 96 // Copy from device to host 97 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(impl->d_array, impl->h_array, length, {e}); 98 // Wait for copy to finish and handle exceptions. 99 CeedCallSycl(ceed, copy_event.wait_and_throw()); 100 return CEED_ERROR_SUCCESS; 101 } 102 103 //------------------------------------------------------------------------------ 104 // Sync arrays 105 //------------------------------------------------------------------------------ 106 static int CeedVectorSyncArray_Sycl(const CeedVector vec, CeedMemType mem_type) { 107 bool need_sync = false; 108 109 // Check whether device/host sync is needed 110 CeedCallBackend(CeedVectorNeedSync_Sycl(vec, mem_type, &need_sync)); 111 if (!need_sync) return CEED_ERROR_SUCCESS; 112 113 switch (mem_type) { 114 case CEED_MEM_HOST: 115 return CeedVectorSyncD2H_Sycl(vec); 116 case CEED_MEM_DEVICE: 117 return CeedVectorSyncH2D_Sycl(vec); 118 } 119 return CEED_ERROR_UNSUPPORTED; 120 } 121 122 //------------------------------------------------------------------------------ 123 // Set all pointers as invalid 124 //------------------------------------------------------------------------------ 125 static inline int CeedVectorSetAllInvalid_Sycl(const CeedVector vec) { 126 CeedVector_Sycl *impl; 127 128 CeedCallBackend(CeedVectorGetData(vec, &impl)); 129 impl->h_array = NULL; 130 impl->d_array = NULL; 131 return CEED_ERROR_SUCCESS; 132 } 133 134 //------------------------------------------------------------------------------ 135 // Check if CeedVector has any valid pointer 136 //------------------------------------------------------------------------------ 137 static inline int CeedVectorHasValidArray_Sycl(const CeedVector vec, bool *has_valid_array) { 138 CeedVector_Sycl *impl; 139 140 CeedCallBackend(CeedVectorGetData(vec, &impl)); 141 *has_valid_array = impl->h_array || impl->d_array; 142 return CEED_ERROR_SUCCESS; 143 } 144 145 //------------------------------------------------------------------------------ 146 // Check if has array of given type 147 //------------------------------------------------------------------------------ 148 static inline int CeedVectorHasArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 149 CeedVector_Sycl *impl; 150 151 CeedCallBackend(CeedVectorGetData(vec, &impl)); 152 switch (mem_type) { 153 case CEED_MEM_HOST: 154 *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 155 break; 156 case CEED_MEM_DEVICE: 157 *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 158 break; 159 } 160 return CEED_ERROR_SUCCESS; 161 } 162 163 //------------------------------------------------------------------------------ 164 // Check if has borrowed array of given type 165 //------------------------------------------------------------------------------ 166 static inline int CeedVectorHasBorrowedArrayOfType_Sycl(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 167 CeedVector_Sycl *impl; 168 169 CeedCallBackend(CeedVectorGetData(vec, &impl)); 170 switch (mem_type) { 171 case CEED_MEM_HOST: 172 *has_borrowed_array_of_type = impl->h_array_borrowed; 173 break; 174 case CEED_MEM_DEVICE: 175 *has_borrowed_array_of_type = impl->d_array_borrowed; 176 break; 177 } 178 return CEED_ERROR_SUCCESS; 179 } 180 181 //------------------------------------------------------------------------------ 182 // Set array from host 183 //------------------------------------------------------------------------------ 184 static int CeedVectorSetArrayHost_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 185 CeedSize length; 186 CeedVector_Sycl *impl; 187 188 CeedCallBackend(CeedVectorGetData(vec, &impl)); 189 CeedCallBackend(CeedVectorGetLength(vec, &length)); 190 191 switch (copy_mode) { 192 case CEED_COPY_VALUES: { 193 if (!impl->h_array_owned) CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 194 if (array) memcpy(impl->h_array_owned, array, length * sizeof(array[0])); 195 impl->h_array_borrowed = NULL; 196 impl->h_array = impl->h_array_owned; 197 } break; 198 case CEED_OWN_POINTER: 199 CeedCallBackend(CeedFree(&impl->h_array_owned)); 200 impl->h_array_owned = array; 201 impl->h_array_borrowed = NULL; 202 impl->h_array = impl->h_array_owned; 203 break; 204 case CEED_USE_POINTER: 205 CeedCallBackend(CeedFree(&impl->h_array_owned)); 206 impl->h_array_owned = NULL; 207 impl->h_array_borrowed = array; 208 impl->h_array = impl->h_array_borrowed; 209 break; 210 } 211 return CEED_ERROR_SUCCESS; 212 } 213 214 //------------------------------------------------------------------------------ 215 // Set array from device 216 //------------------------------------------------------------------------------ 217 static int CeedVectorSetArrayDevice_Sycl(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 218 CeedSize length; 219 Ceed ceed; 220 Ceed_Sycl *data; 221 CeedVector_Sycl *impl; 222 223 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 224 CeedCallBackend(CeedVectorGetData(vec, &impl)); 225 CeedCallBackend(CeedGetData(ceed, &data)); 226 CeedCallBackend(CeedVectorGetLength(vec, &length)); 227 228 // Order queue 229 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 230 231 switch (copy_mode) { 232 case CEED_COPY_VALUES: { 233 if (!impl->d_array_owned) 234 CeedCallSycl(ceed, impl->d_array_owned = sycl::malloc_device<CeedScalar>(length, data->sycl_device, data->sycl_context)); 235 if (array) { 236 sycl::event copy_event = data->sycl_queue.copy<CeedScalar>(array, impl->d_array, length, {e}); 237 // Wait for copy to finish and handle exceptions. 238 CeedCallSycl(ceed, copy_event.wait_and_throw()); 239 } 240 impl->d_array_borrowed = NULL; 241 impl->d_array = impl->d_array_owned; 242 } break; 243 case CEED_OWN_POINTER: 244 if (impl->d_array_owned) { 245 // Wait for all work to finish before freeing memory 246 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 247 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 248 } 249 impl->d_array_owned = array; 250 impl->d_array_borrowed = NULL; 251 impl->d_array = impl->d_array_owned; 252 break; 253 case CEED_USE_POINTER: 254 if (impl->d_array_owned) { 255 // Wait for all work to finish before freeing memory 256 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 257 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 258 } 259 impl->d_array_owned = NULL; 260 impl->d_array_borrowed = array; 261 impl->d_array = impl->d_array_borrowed; 262 break; 263 } 264 return CEED_ERROR_SUCCESS; 265 } 266 267 //------------------------------------------------------------------------------ 268 // Set the array used by a vector, 269 // freeing any previously allocated array if applicable 270 //------------------------------------------------------------------------------ 271 static int CeedVectorSetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 272 CeedVector_Sycl *impl; 273 274 CeedCallBackend(CeedVectorGetData(vec, &impl)); 275 276 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 277 switch (mem_type) { 278 case CEED_MEM_HOST: 279 return CeedVectorSetArrayHost_Sycl(vec, copy_mode, array); 280 case CEED_MEM_DEVICE: 281 return CeedVectorSetArrayDevice_Sycl(vec, copy_mode, array); 282 } 283 return CEED_ERROR_UNSUPPORTED; 284 } 285 286 //------------------------------------------------------------------------------ 287 // Set host array to value 288 //------------------------------------------------------------------------------ 289 static int CeedHostSetValue_Sycl(CeedScalar *h_array, CeedSize length, CeedScalar val) { 290 for (CeedSize i = 0; i < length; i++) h_array[i] = val; 291 return CEED_ERROR_SUCCESS; 292 } 293 294 //------------------------------------------------------------------------------ 295 // Set device array to value 296 //------------------------------------------------------------------------------ 297 static int CeedDeviceSetValue_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length, CeedScalar val) { 298 // Order queue 299 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 300 sycl_queue.fill(d_array, val, length, {e}); 301 return CEED_ERROR_SUCCESS; 302 } 303 304 //------------------------------------------------------------------------------ 305 // Set a vector to a value, 306 //------------------------------------------------------------------------------ 307 static int CeedVectorSetValue_Sycl(CeedVector vec, CeedScalar val) { 308 Ceed ceed; 309 Ceed_Sycl *data; 310 CeedSize length; 311 CeedVector_Sycl *impl; 312 313 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 314 CeedCallBackend(CeedVectorGetData(vec, &impl)); 315 CeedCallBackend(CeedVectorGetLength(vec, &length)); 316 CeedCallBackend(CeedGetData(ceed, &data)); 317 318 // Set value for synced device/host array 319 if (!impl->d_array && !impl->h_array) { 320 if (impl->d_array_borrowed) { 321 impl->d_array = impl->d_array_borrowed; 322 } else if (impl->h_array_borrowed) { 323 impl->h_array = impl->h_array_borrowed; 324 } else if (impl->d_array_owned) { 325 impl->d_array = impl->d_array_owned; 326 } else if (impl->h_array_owned) { 327 impl->h_array = impl->h_array_owned; 328 } else { 329 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 330 } 331 } 332 if (impl->d_array) { 333 CeedCallBackend(CeedDeviceSetValue_Sycl(data->sycl_queue, impl->d_array, length, val)); 334 impl->h_array = NULL; 335 } 336 if (impl->h_array) { 337 CeedCallBackend(CeedHostSetValue_Sycl(impl->h_array, length, val)); 338 impl->d_array = NULL; 339 } 340 return CEED_ERROR_SUCCESS; 341 } 342 343 //------------------------------------------------------------------------------ 344 // Vector Take Array 345 //------------------------------------------------------------------------------ 346 static int CeedVectorTakeArray_Sycl(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 347 Ceed ceed; 348 Ceed_Sycl *data; 349 CeedVector_Sycl *impl; 350 351 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 352 CeedCallBackend(CeedVectorGetData(vec, &impl)); 353 CeedCallBackend(CeedGetData(ceed, &data)); 354 355 // Order queue 356 data->sycl_queue.ext_oneapi_submit_barrier(); 357 358 // Sync array to requested mem_type 359 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 360 361 // Update pointer 362 switch (mem_type) { 363 case CEED_MEM_HOST: 364 (*array) = impl->h_array_borrowed; 365 impl->h_array_borrowed = NULL; 366 impl->h_array = NULL; 367 break; 368 case CEED_MEM_DEVICE: 369 (*array) = impl->d_array_borrowed; 370 impl->d_array_borrowed = NULL; 371 impl->d_array = NULL; 372 break; 373 } 374 return CEED_ERROR_SUCCESS; 375 } 376 377 //------------------------------------------------------------------------------ 378 // Core logic for array syncronization for GetArray. 379 // If a different memory type is most up to date, this will perform a copy 380 //------------------------------------------------------------------------------ 381 static int CeedVectorGetArrayCore_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 382 CeedVector_Sycl *impl; 383 384 CeedCallBackend(CeedVectorGetData(vec, &impl)); 385 386 // Sync array to requested mem_type 387 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 388 389 // Update pointer 390 switch (mem_type) { 391 case CEED_MEM_HOST: 392 *array = impl->h_array; 393 break; 394 case CEED_MEM_DEVICE: 395 *array = impl->d_array; 396 break; 397 } 398 return CEED_ERROR_SUCCESS; 399 } 400 401 //------------------------------------------------------------------------------ 402 // Get read-only access to a vector via the specified mem_type 403 //------------------------------------------------------------------------------ 404 static int CeedVectorGetArrayRead_Sycl(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 405 return CeedVectorGetArrayCore_Sycl(vec, mem_type, (CeedScalar **)array); 406 } 407 408 //------------------------------------------------------------------------------ 409 // Get read/write access to a vector via the specified mem_type 410 //------------------------------------------------------------------------------ 411 static int CeedVectorGetArray_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 412 CeedVector_Sycl *impl; 413 414 CeedCallBackend(CeedVectorGetData(vec, &impl)); 415 CeedCallBackend(CeedVectorGetArrayCore_Sycl(vec, mem_type, array)); 416 CeedCallBackend(CeedVectorSetAllInvalid_Sycl(vec)); 417 switch (mem_type) { 418 case CEED_MEM_HOST: 419 impl->h_array = *array; 420 break; 421 case CEED_MEM_DEVICE: 422 impl->d_array = *array; 423 break; 424 } 425 return CEED_ERROR_SUCCESS; 426 } 427 428 //------------------------------------------------------------------------------ 429 // Get write access to a vector via the specified mem_type 430 //------------------------------------------------------------------------------ 431 static int CeedVectorGetArrayWrite_Sycl(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 432 bool has_array_of_type = true; 433 CeedVector_Sycl *impl; 434 435 CeedCallBackend(CeedVectorGetData(vec, &impl)); 436 CeedCallBackend(CeedVectorHasArrayOfType_Sycl(vec, mem_type, &has_array_of_type)); 437 if (!has_array_of_type) { 438 // Allocate if array is not yet allocated 439 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 440 } else { 441 // Select dirty array 442 switch (mem_type) { 443 case CEED_MEM_HOST: 444 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 445 else impl->h_array = impl->h_array_owned; 446 break; 447 case CEED_MEM_DEVICE: 448 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 449 else impl->d_array = impl->d_array_owned; 450 } 451 } 452 return CeedVectorGetArray_Sycl(vec, mem_type, array); 453 } 454 455 //------------------------------------------------------------------------------ 456 // Get the norm of a CeedVector 457 //------------------------------------------------------------------------------ 458 static int CeedVectorNorm_Sycl(CeedVector vec, CeedNormType type, CeedScalar *norm) { 459 Ceed ceed; 460 Ceed_Sycl *data; 461 CeedSize length; 462 const CeedScalar *d_array; 463 CeedVector_Sycl *impl; 464 465 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 466 CeedCallBackend(CeedVectorGetData(vec, &impl)); 467 CeedCallBackend(CeedVectorGetLength(vec, &length)); 468 CeedCallBackend(CeedGetData(ceed, &data)); 469 470 // Compute norm 471 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 472 switch (type) { 473 case CEED_NORM_1: { 474 // Order queue 475 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 476 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 477 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += abs(d_array[i]); }).wait_and_throw(); 478 } break; 479 case CEED_NORM_2: { 480 // Order queue 481 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 482 auto sumReduction = sycl::reduction(impl->reduction_norm, sycl::plus<>(), {sycl::property::reduction::initialize_to_identity{}}); 483 data->sycl_queue.parallel_for(length, {e}, sumReduction, [=](sycl::id<1> i, auto &sum) { sum += (d_array[i] * d_array[i]); }).wait_and_throw(); 484 } break; 485 case CEED_NORM_MAX: { 486 // Order queue 487 sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier(); 488 auto maxReduction = sycl::reduction(impl->reduction_norm, sycl::maximum<>(), {sycl::property::reduction::initialize_to_identity{}}); 489 data->sycl_queue.parallel_for(length, {e}, maxReduction, [=](sycl::id<1> i, auto &max) { max.combine(abs(d_array[i])); }).wait_and_throw(); 490 } break; 491 } 492 // L2 norm - square root over reduced value 493 if (type == CEED_NORM_2) *norm = sqrt(*impl->reduction_norm); 494 else *norm = *impl->reduction_norm; 495 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 496 return CEED_ERROR_SUCCESS; 497 } 498 499 //------------------------------------------------------------------------------ 500 // Take reciprocal of a vector on host 501 //------------------------------------------------------------------------------ 502 static int CeedHostReciprocal_Sycl(CeedScalar *h_array, CeedSize length) { 503 for (CeedSize i = 0; i < length; i++) { 504 if (std::fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 505 } 506 return CEED_ERROR_SUCCESS; 507 } 508 509 //------------------------------------------------------------------------------ 510 // Take reciprocal of a vector on device 511 //------------------------------------------------------------------------------ 512 static int CeedDeviceReciprocal_Sycl(sycl::queue &sycl_queue, CeedScalar *d_array, CeedSize length) { 513 // Order queue 514 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 515 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { 516 if (std::fabs(d_array[i]) > CEED_EPSILON) d_array[i] = 1. / d_array[i]; 517 }); 518 return CEED_ERROR_SUCCESS; 519 } 520 521 //------------------------------------------------------------------------------ 522 // Take reciprocal of a vector 523 //------------------------------------------------------------------------------ 524 static int CeedVectorReciprocal_Sycl(CeedVector vec) { 525 Ceed ceed; 526 Ceed_Sycl *data; 527 CeedSize length; 528 CeedVector_Sycl *impl; 529 530 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 531 CeedCallBackend(CeedVectorGetData(vec, &impl)); 532 CeedCallBackend(CeedVectorGetLength(vec, &length)); 533 CeedCallBackend(CeedGetData(ceed, &data)); 534 535 // Set value for synced device/host array 536 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Sycl(data->sycl_queue, impl->d_array, length)); 537 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Sycl(impl->h_array, length)); 538 return CEED_ERROR_SUCCESS; 539 } 540 541 //------------------------------------------------------------------------------ 542 // Compute x = alpha x on the host 543 //------------------------------------------------------------------------------ 544 static int CeedHostScale_Sycl(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 545 for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 546 return CEED_ERROR_SUCCESS; 547 } 548 549 //------------------------------------------------------------------------------ 550 // Compute x = alpha x on device 551 //------------------------------------------------------------------------------ 552 static int CeedDeviceScale_Sycl(sycl::queue &sycl_queue, CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 553 // Order queue 554 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 555 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { x_array[i] *= alpha; }); 556 return CEED_ERROR_SUCCESS; 557 } 558 559 //------------------------------------------------------------------------------ 560 // Compute x = alpha x 561 //------------------------------------------------------------------------------ 562 static int CeedVectorScale_Sycl(CeedVector x, CeedScalar alpha) { 563 Ceed ceed; 564 Ceed_Sycl *data; 565 CeedSize length; 566 CeedVector_Sycl *x_impl; 567 568 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 569 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 570 CeedCallBackend(CeedVectorGetLength(x, &length)); 571 CeedCallBackend(CeedGetData(ceed, &data)); 572 573 // Set value for synced device/host array 574 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Sycl(data->sycl_queue, x_impl->d_array, alpha, length)); 575 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Sycl(x_impl->h_array, alpha, length)); 576 return CEED_ERROR_SUCCESS; 577 } 578 579 //------------------------------------------------------------------------------ 580 // Compute y = alpha x + y on the host 581 //------------------------------------------------------------------------------ 582 static int CeedHostAXPY_Sycl(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 583 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 584 return CEED_ERROR_SUCCESS; 585 } 586 587 //------------------------------------------------------------------------------ 588 // Compute y = alpha x + y on device 589 //------------------------------------------------------------------------------ 590 static int CeedDeviceAXPY_Sycl(sycl::queue &sycl_queue, CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 591 // Order queue 592 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 593 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { y_array[i] += alpha * x_array[i]; }); 594 return CEED_ERROR_SUCCESS; 595 } 596 597 //------------------------------------------------------------------------------ 598 // Compute y = alpha x + y 599 //------------------------------------------------------------------------------ 600 static int CeedVectorAXPY_Sycl(CeedVector y, CeedScalar alpha, CeedVector x) { 601 Ceed ceed; 602 Ceed_Sycl *data; 603 CeedSize length; 604 CeedVector_Sycl *y_impl, *x_impl; 605 606 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 607 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 608 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 609 CeedCallBackend(CeedVectorGetLength(y, &length)); 610 CeedCallBackend(CeedGetData(ceed, &data)); 611 612 // Set value for synced device/host array 613 if (y_impl->d_array) { 614 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 615 CeedCallBackend(CeedDeviceAXPY_Sycl(data->sycl_queue, y_impl->d_array, alpha, x_impl->d_array, length)); 616 } 617 if (y_impl->h_array) { 618 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 619 CeedCallBackend(CeedHostAXPY_Sycl(y_impl->h_array, alpha, x_impl->h_array, length)); 620 } 621 return CEED_ERROR_SUCCESS; 622 } 623 624 //------------------------------------------------------------------------------ 625 // Compute the pointwise multiplication w = x .* y on the host 626 //------------------------------------------------------------------------------ 627 static int CeedHostPointwiseMult_Sycl(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 628 for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 629 return CEED_ERROR_SUCCESS; 630 } 631 632 //------------------------------------------------------------------------------ 633 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 634 //------------------------------------------------------------------------------ 635 static int CeedDevicePointwiseMult_Sycl(sycl::queue &sycl_queue, CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 636 // Order queue 637 sycl::event e = sycl_queue.ext_oneapi_submit_barrier(); 638 sycl_queue.parallel_for(length, {e}, [=](sycl::id<1> i) { w_array[i] = x_array[i] * y_array[i]; }); 639 return CEED_ERROR_SUCCESS; 640 } 641 642 //------------------------------------------------------------------------------ 643 // Compute the pointwise multiplication w = x .* y 644 //------------------------------------------------------------------------------ 645 static int CeedVectorPointwiseMult_Sycl(CeedVector w, CeedVector x, CeedVector y) { 646 Ceed ceed; 647 Ceed_Sycl *data; 648 CeedSize length; 649 CeedVector_Sycl *w_impl, *x_impl, *y_impl; 650 651 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 652 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 653 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 654 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 655 CeedCallBackend(CeedVectorGetLength(w, &length)); 656 CeedCallBackend(CeedGetData(ceed, &data)); 657 658 // Set value for synced device/host array 659 if (!w_impl->d_array && !w_impl->h_array) { 660 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 661 } 662 if (w_impl->d_array) { 663 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 664 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 665 CeedCallBackend(CeedDevicePointwiseMult_Sycl(data->sycl_queue, w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 666 } 667 if (w_impl->h_array) { 668 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 669 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 670 CeedCallBackend(CeedHostPointwiseMult_Sycl(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 671 } 672 return CEED_ERROR_SUCCESS; 673 } 674 675 //------------------------------------------------------------------------------ 676 // Destroy the vector 677 //------------------------------------------------------------------------------ 678 static int CeedVectorDestroy_Sycl(const CeedVector vec) { 679 Ceed ceed; 680 Ceed_Sycl *data; 681 CeedVector_Sycl *impl; 682 683 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 684 CeedCallBackend(CeedVectorGetData(vec, &impl)); 685 CeedCallBackend(CeedGetData(ceed, &data)); 686 687 // Wait for all work to finish before freeing memory 688 CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 689 CeedCallSycl(ceed, sycl::free(impl->d_array_owned, data->sycl_context)); 690 CeedCallSycl(ceed, sycl::free(impl->reduction_norm, data->sycl_context)); 691 692 CeedCallBackend(CeedFree(&impl->h_array_owned)); 693 CeedCallBackend(CeedFree(&impl)); 694 return CEED_ERROR_SUCCESS; 695 } 696 697 //------------------------------------------------------------------------------ 698 // Create a vector of the specified length (does not allocate memory) 699 //------------------------------------------------------------------------------ 700 int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec) { 701 Ceed ceed; 702 Ceed_Sycl *data; 703 CeedVector_Sycl *impl; 704 705 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 706 CeedCallBackend(CeedGetData(ceed, &data)); 707 CeedCallBackend(CeedCalloc(1, &impl)); 708 CeedCallSycl(ceed, impl->reduction_norm = sycl::malloc_host<CeedScalar>(1, data->sycl_context)); 709 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Sycl)); 710 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Sycl)); 711 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Sycl)); 712 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Sycl)); 713 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Sycl)); 714 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Sycl)); 715 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Sycl)); 716 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Sycl)); 717 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Sycl)); 718 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Norm", CeedVectorNorm_Sycl)); 719 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Sycl)); 720 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Sycl)); 721 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Scale", CeedVectorScale_Sycl)); 722 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Sycl)); 723 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Sycl)); 724 CeedCallBackend(CeedVectorSetData(vec, impl)); 725 return CEED_ERROR_SUCCESS; 726 } 727