1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <hip/hip_runtime.h> 11 #include <math.h> 12 #include <string.h> 13 #include "ceed-hip-ref.h" 14 15 16 //------------------------------------------------------------------------------ 17 // Check if host/device sync is needed 18 //------------------------------------------------------------------------------ 19 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, 20 CeedMemType mem_type, bool *need_sync) { 21 int ierr; 22 CeedVector_Hip *impl; 23 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 24 25 bool has_valid_array = false; 26 ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); 27 switch (mem_type) { 28 case CEED_MEM_HOST: 29 *need_sync = has_valid_array && !impl->h_array; 30 break; 31 case CEED_MEM_DEVICE: 32 *need_sync = has_valid_array && !impl->d_array; 33 break; 34 } 35 36 return CEED_ERROR_SUCCESS; 37 } 38 39 //------------------------------------------------------------------------------ 40 // Sync host to device 41 //------------------------------------------------------------------------------ 42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) { 43 int ierr; 44 Ceed ceed; 45 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 46 CeedVector_Hip *impl; 47 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 48 49 CeedSize length; 50 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 51 size_t bytes = length * sizeof(CeedScalar); 52 53 if (!impl->h_array) 54 // LCOV_EXCL_START 55 return CeedError(ceed, CEED_ERROR_BACKEND, 56 "No valid host data to sync to device"); 57 // LCOV_EXCL_STOP 58 59 if (impl->d_array_borrowed) { 60 impl->d_array = impl->d_array_borrowed; 61 } else if (impl->d_array_owned) { 62 impl->d_array = impl->d_array_owned; 63 } else { 64 ierr = hipMalloc((void **)&impl->d_array_owned, bytes); 65 CeedChk_Hip(ceed, ierr); 66 impl->d_array = impl->d_array_owned; 67 } 68 69 ierr = hipMemcpy(impl->d_array, impl->h_array, bytes, 70 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 71 72 return CEED_ERROR_SUCCESS; 73 } 74 75 //------------------------------------------------------------------------------ 76 // Sync device to host 77 //------------------------------------------------------------------------------ 78 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { 79 int ierr; 80 Ceed ceed; 81 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 82 CeedVector_Hip *impl; 83 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 84 85 if (!impl->d_array) 86 // LCOV_EXCL_START 87 return CeedError(ceed, CEED_ERROR_BACKEND, 88 "No valid device data to sync to host"); 89 // LCOV_EXCL_STOP 90 91 if (impl->h_array_borrowed) { 92 impl->h_array = impl->h_array_borrowed; 93 } else if (impl->h_array_owned) { 94 impl->h_array = impl->h_array_owned; 95 } else { 96 CeedSize length; 97 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 98 ierr = CeedCalloc(length, &impl->h_array_owned); CeedChkBackend(ierr); 99 impl->h_array = impl->h_array_owned; 100 } 101 102 CeedSize length; 103 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 104 size_t bytes = length * sizeof(CeedScalar); 105 ierr = hipMemcpy(impl->h_array, impl->d_array, bytes, 106 hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr); 107 108 return CEED_ERROR_SUCCESS; 109 } 110 111 //------------------------------------------------------------------------------ 112 // Sync arrays 113 //------------------------------------------------------------------------------ 114 static int CeedVectorSyncArray_Hip(const CeedVector vec, 115 CeedMemType mem_type) { 116 int ierr; 117 // Check whether device/host sync is needed 118 bool need_sync = false; 119 ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); 120 CeedChkBackend(ierr); 121 if (!need_sync) 122 return CEED_ERROR_SUCCESS; 123 124 switch (mem_type) { 125 case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec); 126 case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec); 127 } 128 return CEED_ERROR_UNSUPPORTED; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Set all pointers as invalid 133 //------------------------------------------------------------------------------ 134 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) { 135 int ierr; 136 CeedVector_Hip *impl; 137 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 138 139 impl->h_array = NULL; 140 impl->d_array = NULL; 141 142 return CEED_ERROR_SUCCESS; 143 } 144 145 //------------------------------------------------------------------------------ 146 // Check if CeedVector has any valid pointers 147 //------------------------------------------------------------------------------ 148 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, 149 bool *has_valid_array) { 150 int ierr; 151 CeedVector_Hip *impl; 152 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 153 154 *has_valid_array = !!impl->h_array || !!impl->d_array; 155 156 return CEED_ERROR_SUCCESS; 157 } 158 159 //------------------------------------------------------------------------------ 160 // Check if has any array of given type 161 //------------------------------------------------------------------------------ 162 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, 163 CeedMemType mem_type, bool *has_array_of_type) { 164 int ierr; 165 CeedVector_Hip *impl; 166 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 167 168 switch (mem_type) { 169 case CEED_MEM_HOST: 170 *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 171 break; 172 case CEED_MEM_DEVICE: 173 *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 174 break; 175 } 176 177 return CEED_ERROR_SUCCESS; 178 } 179 180 //------------------------------------------------------------------------------ 181 // Check if has borrowed array of given type 182 //------------------------------------------------------------------------------ 183 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, 184 CeedMemType mem_type, bool *has_borrowed_array_of_type) { 185 int ierr; 186 CeedVector_Hip *impl; 187 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 188 189 switch (mem_type) { 190 case CEED_MEM_HOST: 191 *has_borrowed_array_of_type = !!impl->h_array_borrowed; 192 break; 193 case CEED_MEM_DEVICE: 194 *has_borrowed_array_of_type = !!impl->d_array_borrowed; 195 break; 196 } 197 198 return CEED_ERROR_SUCCESS; 199 } 200 201 //------------------------------------------------------------------------------ 202 // Set array from host 203 //------------------------------------------------------------------------------ 204 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, 205 const CeedCopyMode copy_mode, CeedScalar *array) { 206 int ierr; 207 CeedVector_Hip *impl; 208 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 209 210 switch (copy_mode) { 211 case CEED_COPY_VALUES: { 212 CeedSize length; 213 if (!impl->h_array_owned) { 214 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 215 ierr = CeedMalloc(length, &impl->h_array_owned); CeedChkBackend(ierr); 216 } 217 impl->h_array_borrowed = NULL; 218 impl->h_array = impl->h_array_owned; 219 if (array) { 220 CeedSize length; 221 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 222 size_t bytes = length * sizeof(CeedScalar); 223 memcpy(impl->h_array, array, bytes); 224 } 225 } break; 226 case CEED_OWN_POINTER: 227 ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr); 228 impl->h_array_owned = array; 229 impl->h_array_borrowed = NULL; 230 impl->h_array = array; 231 break; 232 case CEED_USE_POINTER: 233 ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr); 234 impl->h_array_borrowed = array; 235 impl->h_array = array; 236 break; 237 } 238 239 return CEED_ERROR_SUCCESS; 240 } 241 242 //------------------------------------------------------------------------------ 243 // Set array from device 244 //------------------------------------------------------------------------------ 245 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, 246 const CeedCopyMode copy_mode, CeedScalar *array) { 247 int ierr; 248 Ceed ceed; 249 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 250 CeedVector_Hip *impl; 251 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 252 253 switch (copy_mode) { 254 case CEED_COPY_VALUES: { 255 CeedSize length; 256 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 257 size_t bytes = length * sizeof(CeedScalar); 258 if (!impl->d_array_owned) { 259 ierr = hipMalloc((void **)&impl->d_array_owned, bytes); 260 CeedChk_Hip(ceed, ierr); 261 } 262 impl->d_array_borrowed = NULL; 263 impl->d_array = impl->d_array_owned; 264 if (array) { 265 ierr = hipMemcpy(impl->d_array, array, bytes, 266 hipMemcpyDeviceToDevice); CeedChk_Hip(ceed, ierr); 267 } 268 } break; 269 case CEED_OWN_POINTER: 270 ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr); 271 impl->d_array_owned = array; 272 impl->d_array_borrowed = NULL; 273 impl->d_array = array; 274 break; 275 case CEED_USE_POINTER: 276 ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr); 277 impl->d_array_owned = NULL; 278 impl->d_array_borrowed = array; 279 impl->d_array = array; 280 break; 281 } 282 283 return CEED_ERROR_SUCCESS; 284 } 285 286 //------------------------------------------------------------------------------ 287 // Set the array used by a vector, 288 // freeing any previously allocated array if applicable 289 //------------------------------------------------------------------------------ 290 static int CeedVectorSetArray_Hip(const CeedVector vec, 291 const CeedMemType mem_type, 292 const CeedCopyMode copy_mode, CeedScalar *array) { 293 int ierr; 294 Ceed ceed; 295 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 296 CeedVector_Hip *impl; 297 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 298 299 ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr); 300 switch (mem_type) { 301 case CEED_MEM_HOST: 302 return CeedVectorSetArrayHost_Hip(vec, copy_mode, array); 303 case CEED_MEM_DEVICE: 304 return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array); 305 } 306 307 return CEED_ERROR_UNSUPPORTED; 308 } 309 310 //------------------------------------------------------------------------------ 311 // Set host array to value 312 //------------------------------------------------------------------------------ 313 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length, 314 CeedScalar val) { 315 for (int i = 0; i < length; i++) 316 h_array[i] = val; 317 return CEED_ERROR_SUCCESS; 318 } 319 320 //------------------------------------------------------------------------------ 321 // Set device array to value (impl in .hip file) 322 //------------------------------------------------------------------------------ 323 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val); 324 325 //------------------------------------------------------------------------------ 326 // Set a vector to a value, 327 //------------------------------------------------------------------------------ 328 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) { 329 int ierr; 330 Ceed ceed; 331 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 332 CeedVector_Hip *impl; 333 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 334 CeedSize length; 335 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 336 337 // Set value for synced device/host array 338 if (!impl->d_array && !impl->h_array) { 339 if (impl->d_array_borrowed) { 340 impl->d_array = impl->d_array_borrowed; 341 } else if (impl->h_array_borrowed) { 342 impl->h_array = impl->h_array_borrowed; 343 } else if (impl->d_array_owned) { 344 impl->d_array = impl->d_array_owned; 345 } else if (impl->h_array_owned) { 346 impl->h_array = impl->h_array_owned; 347 } else { 348 ierr = CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL); 349 CeedChkBackend(ierr); 350 } 351 } 352 if (impl->d_array) { 353 ierr = CeedDeviceSetValue_Hip(impl->d_array, length, val); CeedChkBackend(ierr); 354 } 355 if (impl->h_array) { 356 ierr = CeedHostSetValue_Hip(impl->h_array, length, val); CeedChkBackend(ierr); 357 } 358 359 return CEED_ERROR_SUCCESS; 360 } 361 362 //------------------------------------------------------------------------------ 363 // Vector Take Array 364 //------------------------------------------------------------------------------ 365 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, 366 CeedScalar **array) { 367 int ierr; 368 Ceed ceed; 369 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 370 CeedVector_Hip *impl; 371 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 372 373 // Sync array to requested mem_type 374 ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); 375 376 // Update pointer 377 switch (mem_type) { 378 case CEED_MEM_HOST: 379 (*array) = impl->h_array_borrowed; 380 impl->h_array_borrowed = NULL; 381 impl->h_array = NULL; 382 break; 383 case CEED_MEM_DEVICE: 384 (*array) = impl->d_array_borrowed; 385 impl->d_array_borrowed = NULL; 386 impl->d_array = NULL; 387 break; 388 } 389 390 return CEED_ERROR_SUCCESS; 391 } 392 393 //------------------------------------------------------------------------------ 394 // Core logic for array syncronization for GetArray. 395 // If a different memory type is most up to date, this will perform a copy 396 //------------------------------------------------------------------------------ 397 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, 398 const CeedMemType mem_type, CeedScalar **array) { 399 int ierr; 400 Ceed ceed; 401 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 402 CeedVector_Hip *impl; 403 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 404 405 // Sync array to requested mem_type 406 ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); 407 408 // Update pointer 409 switch (mem_type) { 410 case CEED_MEM_HOST: 411 *array = impl->h_array; 412 break; 413 case CEED_MEM_DEVICE: 414 *array = impl->d_array; 415 break; 416 } 417 418 return CEED_ERROR_SUCCESS; 419 } 420 421 //------------------------------------------------------------------------------ 422 // Get read-only access to a vector via the specified mem_type 423 //------------------------------------------------------------------------------ 424 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, 425 const CeedMemType mem_type, const CeedScalar **array) { 426 return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array); 427 } 428 429 //------------------------------------------------------------------------------ 430 // Get read/write access to a vector via the specified mem_type 431 //------------------------------------------------------------------------------ 432 static int CeedVectorGetArray_Hip(const CeedVector vec, 433 const CeedMemType mem_type, 434 CeedScalar **array) { 435 int ierr; 436 CeedVector_Hip *impl; 437 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 438 439 ierr = CeedVectorGetArrayCore_Hip(vec, mem_type, array); CeedChkBackend(ierr); 440 441 ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr); 442 switch (mem_type) { 443 case CEED_MEM_HOST: 444 impl->h_array = *array; 445 break; 446 case CEED_MEM_DEVICE: 447 impl->d_array = *array; 448 break; 449 } 450 451 return CEED_ERROR_SUCCESS; 452 } 453 454 //------------------------------------------------------------------------------ 455 // Get write access to a vector via the specified mem_type 456 //------------------------------------------------------------------------------ 457 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, 458 const CeedMemType mem_type, CeedScalar **array) { 459 int ierr; 460 CeedVector_Hip *impl; 461 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 462 463 bool has_array_of_type = true; 464 ierr = CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type); 465 CeedChkBackend(ierr); 466 if (!has_array_of_type) { 467 // Allocate if array is not yet allocated 468 ierr = CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL); 469 CeedChkBackend(ierr); 470 } else { 471 // Select dirty array 472 switch (mem_type) { 473 case CEED_MEM_HOST: 474 if (impl->h_array_borrowed) 475 impl->h_array = impl->h_array_borrowed; 476 else 477 impl->h_array = impl->h_array_owned; 478 break; 479 case CEED_MEM_DEVICE: 480 if (impl->d_array_borrowed) 481 impl->d_array = impl->d_array_borrowed; 482 else 483 impl->d_array = impl->d_array_owned; 484 } 485 } 486 487 return CeedVectorGetArray_Hip(vec, mem_type, array); 488 } 489 490 //------------------------------------------------------------------------------ 491 // Get the norm of a CeedVector 492 //------------------------------------------------------------------------------ 493 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, 494 CeedScalar *norm) { 495 int ierr; 496 Ceed ceed; 497 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 498 CeedVector_Hip *impl; 499 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 500 CeedSize length; 501 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 502 hipblasHandle_t handle; 503 ierr = CeedHipGetHipblasHandle(ceed, &handle); CeedChkBackend(ierr); 504 505 // Compute norm 506 const CeedScalar *d_array; 507 ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array); 508 CeedChkBackend(ierr); 509 switch (type) { 510 case CEED_NORM_1: { 511 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 512 ierr = hipblasSasum(handle, length, (float *) d_array, 1, (float *) norm); 513 } else { 514 ierr = hipblasDasum(handle, length, (double *) d_array, 1, (double *) norm); 515 } 516 CeedChk_Hipblas(ceed, ierr); 517 break; 518 } 519 case CEED_NORM_2: { 520 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 521 ierr = hipblasSnrm2(handle, length, (float *) d_array, 1, (float *) norm); 522 } else { 523 ierr = hipblasDnrm2(handle, length, (double *) d_array, 1, (double *) norm); 524 } 525 CeedChk_Hipblas(ceed, ierr); 526 break; 527 } 528 case CEED_NORM_MAX: { 529 CeedInt indx; 530 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 531 ierr = hipblasIsamax(handle, length, (float *) d_array, 1, &indx); 532 } else { 533 ierr = hipblasIdamax(handle, length, (double *) d_array, 1, &indx); 534 } 535 CeedChk_Hipblas(ceed, ierr); 536 CeedScalar normNoAbs; 537 ierr = hipMemcpy(&normNoAbs, impl->d_array+indx-1, sizeof(CeedScalar), 538 hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr); 539 *norm = fabs(normNoAbs); 540 break; 541 } 542 } 543 ierr = CeedVectorRestoreArrayRead(vec, &d_array); CeedChkBackend(ierr); 544 545 return CEED_ERROR_SUCCESS; 546 } 547 548 //------------------------------------------------------------------------------ 549 // Take reciprocal of a vector on host 550 //------------------------------------------------------------------------------ 551 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) { 552 for (int i = 0; i < length; i++) 553 if (fabs(h_array[i]) > CEED_EPSILON) 554 h_array[i] = 1./h_array[i]; 555 return CEED_ERROR_SUCCESS; 556 } 557 558 //------------------------------------------------------------------------------ 559 // Take reciprocal of a vector on device (impl in .cu file) 560 //------------------------------------------------------------------------------ 561 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length); 562 563 //------------------------------------------------------------------------------ 564 // Take reciprocal of a vector 565 //------------------------------------------------------------------------------ 566 static int CeedVectorReciprocal_Hip(CeedVector vec) { 567 int ierr; 568 Ceed ceed; 569 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 570 CeedVector_Hip *impl; 571 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 572 CeedSize length; 573 ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr); 574 575 // Set value for synced device/host array 576 if (impl->d_array) { 577 ierr = CeedDeviceReciprocal_Hip(impl->d_array, length); CeedChkBackend(ierr); 578 } 579 if (impl->h_array) { 580 ierr = CeedHostReciprocal_Hip(impl->h_array, length); CeedChkBackend(ierr); 581 } 582 583 return CEED_ERROR_SUCCESS; 584 } 585 586 //------------------------------------------------------------------------------ 587 // Compute x = alpha x on the host 588 //------------------------------------------------------------------------------ 589 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, 590 CeedInt length) { 591 for (int i = 0; i < length; i++) 592 x_array[i] *= alpha; 593 return CEED_ERROR_SUCCESS; 594 } 595 596 //------------------------------------------------------------------------------ 597 // Compute x = alpha x on device (impl in .cu file) 598 //------------------------------------------------------------------------------ 599 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, 600 CeedInt length); 601 602 //------------------------------------------------------------------------------ 603 // Compute x = alpha x 604 //------------------------------------------------------------------------------ 605 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) { 606 int ierr; 607 Ceed ceed; 608 ierr = CeedVectorGetCeed(x, &ceed); CeedChkBackend(ierr); 609 CeedVector_Hip *x_impl; 610 ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr); 611 CeedSize length; 612 ierr = CeedVectorGetLength(x, &length); CeedChkBackend(ierr); 613 614 // Set value for synced device/host array 615 if (x_impl->d_array) { 616 ierr = CeedDeviceScale_Hip(x_impl->d_array, alpha, length); 617 CeedChkBackend(ierr); 618 } 619 if (x_impl->h_array) { 620 ierr = CeedHostScale_Hip(x_impl->h_array, alpha, length); CeedChkBackend(ierr); 621 } 622 623 return CEED_ERROR_SUCCESS; 624 } 625 626 //------------------------------------------------------------------------------ 627 // Compute y = alpha x + y on the host 628 //------------------------------------------------------------------------------ 629 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, 630 CeedScalar *x_array, CeedInt length) { 631 for (int i = 0; i < length; i++) 632 y_array[i] += alpha * x_array[i]; 633 return CEED_ERROR_SUCCESS; 634 } 635 636 //------------------------------------------------------------------------------ 637 // Compute y = alpha x + y on device (impl in .cu file) 638 //------------------------------------------------------------------------------ 639 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, 640 CeedScalar *x_array, CeedInt length); 641 642 //------------------------------------------------------------------------------ 643 // Compute y = alpha x + y 644 //------------------------------------------------------------------------------ 645 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) { 646 int ierr; 647 Ceed ceed; 648 ierr = CeedVectorGetCeed(y, &ceed); CeedChkBackend(ierr); 649 CeedVector_Hip *y_impl, *x_impl; 650 ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr); 651 ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr); 652 CeedSize length; 653 ierr = CeedVectorGetLength(y, &length); CeedChkBackend(ierr); 654 655 // Set value for synced device/host array 656 if (y_impl->d_array) { 657 ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr); 658 ierr = CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length); 659 CeedChkBackend(ierr); 660 } 661 if (y_impl->h_array) { 662 ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr); 663 ierr = CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length); 664 CeedChkBackend(ierr); 665 } 666 667 return CEED_ERROR_SUCCESS; 668 } 669 670 //------------------------------------------------------------------------------ 671 // Compute the pointwise multiplication w = x .* y on the host 672 //------------------------------------------------------------------------------ 673 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, 674 CeedScalar *y_array, CeedInt length) { 675 for (int i = 0; i < length; i++) 676 w_array[i] = x_array[i] * y_array[i]; 677 return CEED_ERROR_SUCCESS; 678 } 679 680 //------------------------------------------------------------------------------ 681 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 682 //------------------------------------------------------------------------------ 683 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, 684 CeedScalar *y_array, CeedInt length); 685 686 //------------------------------------------------------------------------------ 687 // Compute the pointwise multiplication w = x .* y 688 //------------------------------------------------------------------------------ 689 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, 690 CeedVector y) { 691 int ierr; 692 Ceed ceed; 693 ierr = CeedVectorGetCeed(w, &ceed); CeedChkBackend(ierr); 694 CeedVector_Hip *w_impl, *x_impl, *y_impl; 695 ierr = CeedVectorGetData(w, &w_impl); CeedChkBackend(ierr); 696 ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr); 697 ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr); 698 CeedSize length; 699 ierr = CeedVectorGetLength(w, &length); CeedChkBackend(ierr); 700 701 // Set value for synced device/host array 702 if (!w_impl->d_array && !w_impl->h_array) { 703 ierr = CeedVectorSetValue(w, 0.0); CeedChkBackend(ierr); 704 } 705 if (w_impl->d_array) { 706 ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr); 707 ierr = CeedVectorSyncArray(y, CEED_MEM_DEVICE); CeedChkBackend(ierr); 708 ierr = CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, 709 y_impl->d_array, length); 710 CeedChkBackend(ierr); 711 } 712 if (w_impl->h_array) { 713 ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr); 714 ierr = CeedVectorSyncArray(y, CEED_MEM_HOST); CeedChkBackend(ierr); 715 ierr = CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, 716 y_impl->h_array, length); 717 CeedChkBackend(ierr); 718 } 719 720 return CEED_ERROR_SUCCESS; 721 } 722 723 //------------------------------------------------------------------------------ 724 // Destroy the vector 725 //------------------------------------------------------------------------------ 726 static int CeedVectorDestroy_Hip(const CeedVector vec) { 727 int ierr; 728 Ceed ceed; 729 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 730 CeedVector_Hip *impl; 731 ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); 732 733 ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr); 734 ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr); 735 ierr = CeedFree(&impl); CeedChkBackend(ierr); 736 737 return CEED_ERROR_SUCCESS; 738 } 739 740 //------------------------------------------------------------------------------ 741 // Create a vector of the specified length (does not allocate memory) 742 //------------------------------------------------------------------------------ 743 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { 744 CeedVector_Hip *impl; 745 int ierr; 746 Ceed ceed; 747 ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr); 748 749 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", 750 CeedVectorHasValidArray_Hip); CeedChkBackend(ierr); 751 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", 752 CeedVectorHasBorrowedArrayOfType_Hip); 753 CeedChkBackend(ierr); 754 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", 755 CeedVectorSetArray_Hip); CeedChkBackend(ierr); 756 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", 757 CeedVectorTakeArray_Hip); CeedChkBackend(ierr); 758 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", 759 (int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr); 760 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", 761 CeedVectorSyncArray_Hip); CeedChkBackend(ierr); 762 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", 763 CeedVectorGetArray_Hip); CeedChkBackend(ierr); 764 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", 765 CeedVectorGetArrayRead_Hip); CeedChkBackend(ierr); 766 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", 767 CeedVectorGetArrayWrite_Hip); CeedChkBackend(ierr); 768 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Norm", 769 CeedVectorNorm_Hip); CeedChkBackend(ierr); 770 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", 771 CeedVectorReciprocal_Hip); CeedChkBackend(ierr); 772 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Scale", 773 (int (*)())(CeedVectorScale_Hip)); CeedChkBackend(ierr); 774 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", 775 (int (*)())(CeedVectorAXPY_Hip)); CeedChkBackend(ierr); 776 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", 777 CeedVectorPointwiseMult_Hip); CeedChkBackend(ierr); 778 ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", 779 CeedVectorDestroy_Hip); CeedChkBackend(ierr); 780 781 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 782 ierr = CeedVectorSetData(vec, impl); CeedChkBackend(ierr); 783 784 return CEED_ERROR_SUCCESS; 785 } 786