1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <math.h> 11 #include <stdbool.h> 12 #include <string.h> 13 #include <hip/hip_runtime.h> 14 15 #include "../hip/ceed-hip-common.h" 16 #include "ceed-hip-ref.h" 17 18 //------------------------------------------------------------------------------ 19 // Check if host/device sync is needed 20 //------------------------------------------------------------------------------ 21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 22 CeedVector_Hip *impl; 23 CeedCallBackend(CeedVectorGetData(vec, &impl)); 24 25 bool has_valid_array = false; 26 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 27 switch (mem_type) { 28 case CEED_MEM_HOST: 29 *need_sync = has_valid_array && !impl->h_array; 30 break; 31 case CEED_MEM_DEVICE: 32 *need_sync = has_valid_array && !impl->d_array; 33 break; 34 } 35 36 return CEED_ERROR_SUCCESS; 37 } 38 39 //------------------------------------------------------------------------------ 40 // Sync host to device 41 //------------------------------------------------------------------------------ 42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) { 43 Ceed ceed; 44 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 45 CeedVector_Hip *impl; 46 CeedCallBackend(CeedVectorGetData(vec, &impl)); 47 48 CeedSize length; 49 CeedCallBackend(CeedVectorGetLength(vec, &length)); 50 size_t bytes = length * sizeof(CeedScalar); 51 52 if (!impl->h_array) { 53 // LCOV_EXCL_START 54 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 55 // LCOV_EXCL_STOP 56 } 57 58 if (impl->d_array_borrowed) { 59 impl->d_array = impl->d_array_borrowed; 60 } else if (impl->d_array_owned) { 61 impl->d_array = impl->d_array_owned; 62 } else { 63 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 64 impl->d_array = impl->d_array_owned; 65 } 66 67 CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice)); 68 69 return CEED_ERROR_SUCCESS; 70 } 71 72 //------------------------------------------------------------------------------ 73 // Sync device to host 74 //------------------------------------------------------------------------------ 75 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { 76 Ceed ceed; 77 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 78 CeedVector_Hip *impl; 79 CeedCallBackend(CeedVectorGetData(vec, &impl)); 80 81 if (!impl->d_array) { 82 // LCOV_EXCL_START 83 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 84 // LCOV_EXCL_STOP 85 } 86 87 if (impl->h_array_borrowed) { 88 impl->h_array = impl->h_array_borrowed; 89 } else if (impl->h_array_owned) { 90 impl->h_array = impl->h_array_owned; 91 } else { 92 CeedSize length; 93 CeedCallBackend(CeedVectorGetLength(vec, &length)); 94 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 95 impl->h_array = impl->h_array_owned; 96 } 97 98 CeedSize length; 99 CeedCallBackend(CeedVectorGetLength(vec, &length)); 100 size_t bytes = length * sizeof(CeedScalar); 101 CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost)); 102 103 return CEED_ERROR_SUCCESS; 104 } 105 106 //------------------------------------------------------------------------------ 107 // Sync arrays 108 //------------------------------------------------------------------------------ 109 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) { 110 // Check whether device/host sync is needed 111 bool need_sync = false; 112 CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync)); 113 if (!need_sync) return CEED_ERROR_SUCCESS; 114 115 switch (mem_type) { 116 case CEED_MEM_HOST: 117 return CeedVectorSyncD2H_Hip(vec); 118 case CEED_MEM_DEVICE: 119 return CeedVectorSyncH2D_Hip(vec); 120 } 121 return CEED_ERROR_UNSUPPORTED; 122 } 123 124 //------------------------------------------------------------------------------ 125 // Set all pointers as invalid 126 //------------------------------------------------------------------------------ 127 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) { 128 CeedVector_Hip *impl; 129 CeedCallBackend(CeedVectorGetData(vec, &impl)); 130 131 impl->h_array = NULL; 132 impl->d_array = NULL; 133 134 return CEED_ERROR_SUCCESS; 135 } 136 137 //------------------------------------------------------------------------------ 138 // Check if CeedVector has any valid pointers 139 //------------------------------------------------------------------------------ 140 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) { 141 CeedVector_Hip *impl; 142 CeedCallBackend(CeedVectorGetData(vec, &impl)); 143 144 *has_valid_array = !!impl->h_array || !!impl->d_array; 145 146 return CEED_ERROR_SUCCESS; 147 } 148 149 //------------------------------------------------------------------------------ 150 // Check if has any array of given type 151 //------------------------------------------------------------------------------ 152 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 153 CeedVector_Hip *impl; 154 CeedCallBackend(CeedVectorGetData(vec, &impl)); 155 156 switch (mem_type) { 157 case CEED_MEM_HOST: 158 *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 159 break; 160 case CEED_MEM_DEVICE: 161 *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 162 break; 163 } 164 165 return CEED_ERROR_SUCCESS; 166 } 167 168 //------------------------------------------------------------------------------ 169 // Check if has borrowed array of given type 170 //------------------------------------------------------------------------------ 171 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 172 CeedVector_Hip *impl; 173 CeedCallBackend(CeedVectorGetData(vec, &impl)); 174 175 switch (mem_type) { 176 case CEED_MEM_HOST: 177 *has_borrowed_array_of_type = !!impl->h_array_borrowed; 178 break; 179 case CEED_MEM_DEVICE: 180 *has_borrowed_array_of_type = !!impl->d_array_borrowed; 181 break; 182 } 183 184 return CEED_ERROR_SUCCESS; 185 } 186 187 //------------------------------------------------------------------------------ 188 // Set array from host 189 //------------------------------------------------------------------------------ 190 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 191 CeedVector_Hip *impl; 192 CeedCallBackend(CeedVectorGetData(vec, &impl)); 193 194 switch (copy_mode) { 195 case CEED_COPY_VALUES: { 196 CeedSize length; 197 if (!impl->h_array_owned) { 198 CeedCallBackend(CeedVectorGetLength(vec, &length)); 199 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 200 } 201 impl->h_array_borrowed = NULL; 202 impl->h_array = impl->h_array_owned; 203 if (array) { 204 CeedSize length; 205 CeedCallBackend(CeedVectorGetLength(vec, &length)); 206 size_t bytes = length * sizeof(CeedScalar); 207 memcpy(impl->h_array, array, bytes); 208 } 209 } break; 210 case CEED_OWN_POINTER: 211 CeedCallBackend(CeedFree(&impl->h_array_owned)); 212 impl->h_array_owned = array; 213 impl->h_array_borrowed = NULL; 214 impl->h_array = array; 215 break; 216 case CEED_USE_POINTER: 217 CeedCallBackend(CeedFree(&impl->h_array_owned)); 218 impl->h_array_borrowed = array; 219 impl->h_array = array; 220 break; 221 } 222 223 return CEED_ERROR_SUCCESS; 224 } 225 226 //------------------------------------------------------------------------------ 227 // Set array from device 228 //------------------------------------------------------------------------------ 229 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 230 Ceed ceed; 231 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 232 CeedVector_Hip *impl; 233 CeedCallBackend(CeedVectorGetData(vec, &impl)); 234 235 switch (copy_mode) { 236 case CEED_COPY_VALUES: { 237 CeedSize length; 238 CeedCallBackend(CeedVectorGetLength(vec, &length)); 239 size_t bytes = length * sizeof(CeedScalar); 240 if (!impl->d_array_owned) { 241 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 242 } 243 impl->d_array_borrowed = NULL; 244 impl->d_array = impl->d_array_owned; 245 if (array) { 246 CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice)); 247 } 248 } break; 249 case CEED_OWN_POINTER: 250 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 251 impl->d_array_owned = array; 252 impl->d_array_borrowed = NULL; 253 impl->d_array = array; 254 break; 255 case CEED_USE_POINTER: 256 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 257 impl->d_array_owned = NULL; 258 impl->d_array_borrowed = array; 259 impl->d_array = array; 260 break; 261 } 262 263 return CEED_ERROR_SUCCESS; 264 } 265 266 //------------------------------------------------------------------------------ 267 // Set the array used by a vector, 268 // freeing any previously allocated array if applicable 269 //------------------------------------------------------------------------------ 270 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 271 Ceed ceed; 272 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 273 CeedVector_Hip *impl; 274 CeedCallBackend(CeedVectorGetData(vec, &impl)); 275 276 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 277 switch (mem_type) { 278 case CEED_MEM_HOST: 279 return CeedVectorSetArrayHost_Hip(vec, copy_mode, array); 280 case CEED_MEM_DEVICE: 281 return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array); 282 } 283 284 return CEED_ERROR_UNSUPPORTED; 285 } 286 287 //------------------------------------------------------------------------------ 288 // Set host array to value 289 //------------------------------------------------------------------------------ 290 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length, CeedScalar val) { 291 for (int i = 0; i < length; i++) h_array[i] = val; 292 return CEED_ERROR_SUCCESS; 293 } 294 295 //------------------------------------------------------------------------------ 296 // Set device array to value (impl in .hip file) 297 //------------------------------------------------------------------------------ 298 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val); 299 300 //------------------------------------------------------------------------------ 301 // Set a vector to a value, 302 //------------------------------------------------------------------------------ 303 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) { 304 Ceed ceed; 305 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 306 CeedVector_Hip *impl; 307 CeedCallBackend(CeedVectorGetData(vec, &impl)); 308 CeedSize length; 309 CeedCallBackend(CeedVectorGetLength(vec, &length)); 310 311 // Set value for synced device/host array 312 if (!impl->d_array && !impl->h_array) { 313 if (impl->d_array_borrowed) { 314 impl->d_array = impl->d_array_borrowed; 315 } else if (impl->h_array_borrowed) { 316 impl->h_array = impl->h_array_borrowed; 317 } else if (impl->d_array_owned) { 318 impl->d_array = impl->d_array_owned; 319 } else if (impl->h_array_owned) { 320 impl->h_array = impl->h_array_owned; 321 } else { 322 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 323 } 324 } 325 if (impl->d_array) { 326 CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val)); 327 } 328 if (impl->h_array) { 329 CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val)); 330 } 331 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Vector Take Array 337 //------------------------------------------------------------------------------ 338 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 339 Ceed ceed; 340 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 341 CeedVector_Hip *impl; 342 CeedCallBackend(CeedVectorGetData(vec, &impl)); 343 344 // Sync array to requested mem_type 345 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 346 347 // Update pointer 348 switch (mem_type) { 349 case CEED_MEM_HOST: 350 (*array) = impl->h_array_borrowed; 351 impl->h_array_borrowed = NULL; 352 impl->h_array = NULL; 353 break; 354 case CEED_MEM_DEVICE: 355 (*array) = impl->d_array_borrowed; 356 impl->d_array_borrowed = NULL; 357 impl->d_array = NULL; 358 break; 359 } 360 361 return CEED_ERROR_SUCCESS; 362 } 363 364 //------------------------------------------------------------------------------ 365 // Core logic for array syncronization for GetArray. 366 // If a different memory type is most up to date, this will perform a copy 367 //------------------------------------------------------------------------------ 368 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 369 Ceed ceed; 370 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 371 CeedVector_Hip *impl; 372 CeedCallBackend(CeedVectorGetData(vec, &impl)); 373 374 // Sync array to requested mem_type 375 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 376 377 // Update pointer 378 switch (mem_type) { 379 case CEED_MEM_HOST: 380 *array = impl->h_array; 381 break; 382 case CEED_MEM_DEVICE: 383 *array = impl->d_array; 384 break; 385 } 386 387 return CEED_ERROR_SUCCESS; 388 } 389 390 //------------------------------------------------------------------------------ 391 // Get read-only access to a vector via the specified mem_type 392 //------------------------------------------------------------------------------ 393 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 394 return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array); 395 } 396 397 //------------------------------------------------------------------------------ 398 // Get read/write access to a vector via the specified mem_type 399 //------------------------------------------------------------------------------ 400 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 401 CeedVector_Hip *impl; 402 CeedCallBackend(CeedVectorGetData(vec, &impl)); 403 404 CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array)); 405 406 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 407 switch (mem_type) { 408 case CEED_MEM_HOST: 409 impl->h_array = *array; 410 break; 411 case CEED_MEM_DEVICE: 412 impl->d_array = *array; 413 break; 414 } 415 416 return CEED_ERROR_SUCCESS; 417 } 418 419 //------------------------------------------------------------------------------ 420 // Get write access to a vector via the specified mem_type 421 //------------------------------------------------------------------------------ 422 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 423 CeedVector_Hip *impl; 424 CeedCallBackend(CeedVectorGetData(vec, &impl)); 425 426 bool has_array_of_type = true; 427 CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type)); 428 if (!has_array_of_type) { 429 // Allocate if array is not yet allocated 430 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 431 } else { 432 // Select dirty array 433 switch (mem_type) { 434 case CEED_MEM_HOST: 435 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 436 else impl->h_array = impl->h_array_owned; 437 break; 438 case CEED_MEM_DEVICE: 439 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 440 else impl->d_array = impl->d_array_owned; 441 } 442 } 443 444 return CeedVectorGetArray_Hip(vec, mem_type, array); 445 } 446 447 //------------------------------------------------------------------------------ 448 // Get the norm of a CeedVector 449 //------------------------------------------------------------------------------ 450 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) { 451 Ceed ceed; 452 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 453 CeedVector_Hip *impl; 454 CeedCallBackend(CeedVectorGetData(vec, &impl)); 455 CeedSize length; 456 CeedCallBackend(CeedVectorGetLength(vec, &length)); 457 hipblasHandle_t handle; 458 CeedCallBackend(CeedHipGetHipblasHandle(ceed, &handle)); 459 460 // Compute norm 461 const CeedScalar *d_array; 462 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 463 switch (type) { 464 case CEED_NORM_1: { 465 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 466 CeedCallHipblas(ceed, hipblasSasum(handle, length, (float *)d_array, 1, (float *)norm)); 467 } else { 468 CeedCallHipblas(ceed, hipblasDasum(handle, length, (double *)d_array, 1, (double *)norm)); 469 } 470 break; 471 } 472 case CEED_NORM_2: { 473 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 474 CeedCallHipblas(ceed, hipblasSnrm2(handle, length, (float *)d_array, 1, (float *)norm)); 475 } else { 476 CeedCallHipblas(ceed, hipblasDnrm2(handle, length, (double *)d_array, 1, (double *)norm)); 477 } 478 break; 479 } 480 case CEED_NORM_MAX: { 481 CeedInt indx; 482 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 483 CeedCallHipblas(ceed, hipblasIsamax(handle, length, (float *)d_array, 1, &indx)); 484 } else { 485 CeedCallHipblas(ceed, hipblasIdamax(handle, length, (double *)d_array, 1, &indx)); 486 } 487 CeedScalar normNoAbs; 488 CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 489 *norm = fabs(normNoAbs); 490 break; 491 } 492 } 493 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 494 495 return CEED_ERROR_SUCCESS; 496 } 497 498 //------------------------------------------------------------------------------ 499 // Take reciprocal of a vector on host 500 //------------------------------------------------------------------------------ 501 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) { 502 for (int i = 0; i < length; i++) { 503 if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 504 } 505 return CEED_ERROR_SUCCESS; 506 } 507 508 //------------------------------------------------------------------------------ 509 // Take reciprocal of a vector on device (impl in .cu file) 510 //------------------------------------------------------------------------------ 511 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length); 512 513 //------------------------------------------------------------------------------ 514 // Take reciprocal of a vector 515 //------------------------------------------------------------------------------ 516 static int CeedVectorReciprocal_Hip(CeedVector vec) { 517 Ceed ceed; 518 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 519 CeedVector_Hip *impl; 520 CeedCallBackend(CeedVectorGetData(vec, &impl)); 521 CeedSize length; 522 CeedCallBackend(CeedVectorGetLength(vec, &length)); 523 524 // Set value for synced device/host array 525 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length)); 526 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length)); 527 528 return CEED_ERROR_SUCCESS; 529 } 530 531 //------------------------------------------------------------------------------ 532 // Compute x = alpha x on the host 533 //------------------------------------------------------------------------------ 534 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 535 for (int i = 0; i < length; i++) x_array[i] *= alpha; 536 return CEED_ERROR_SUCCESS; 537 } 538 539 //------------------------------------------------------------------------------ 540 // Compute x = alpha x on device (impl in .cu file) 541 //------------------------------------------------------------------------------ 542 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length); 543 544 //------------------------------------------------------------------------------ 545 // Compute x = alpha x 546 //------------------------------------------------------------------------------ 547 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) { 548 Ceed ceed; 549 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 550 CeedVector_Hip *x_impl; 551 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 552 CeedSize length; 553 CeedCallBackend(CeedVectorGetLength(x, &length)); 554 555 // Set value for synced device/host array 556 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length)); 557 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length)); 558 559 return CEED_ERROR_SUCCESS; 560 } 561 562 //------------------------------------------------------------------------------ 563 // Compute y = alpha x + y on the host 564 //------------------------------------------------------------------------------ 565 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 566 for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 567 return CEED_ERROR_SUCCESS; 568 } 569 570 //------------------------------------------------------------------------------ 571 // Compute y = alpha x + y on device (impl in .cu file) 572 //------------------------------------------------------------------------------ 573 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length); 574 575 //------------------------------------------------------------------------------ 576 // Compute y = alpha x + y 577 //------------------------------------------------------------------------------ 578 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) { 579 Ceed ceed; 580 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 581 CeedVector_Hip *y_impl, *x_impl; 582 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 583 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 584 CeedSize length; 585 CeedCallBackend(CeedVectorGetLength(y, &length)); 586 587 // Set value for synced device/host array 588 if (y_impl->d_array) { 589 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 590 CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length)); 591 } 592 if (y_impl->h_array) { 593 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 594 CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length)); 595 } 596 597 return CEED_ERROR_SUCCESS; 598 } 599 //------------------------------------------------------------------------------ 600 // Compute y = alpha x + beta y on the host 601 //------------------------------------------------------------------------------ 602 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) { 603 for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i]; 604 return CEED_ERROR_SUCCESS; 605 } 606 607 //------------------------------------------------------------------------------ 608 // Compute y = alpha x + beta y on device (impl in .cu file) 609 //------------------------------------------------------------------------------ 610 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length); 611 612 //------------------------------------------------------------------------------ 613 // Compute y = alpha x + beta y 614 //------------------------------------------------------------------------------ 615 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) { 616 Ceed ceed; 617 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 618 CeedVector_Hip *y_impl, *x_impl; 619 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 620 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 621 CeedSize length; 622 CeedCallBackend(CeedVectorGetLength(y, &length)); 623 624 // Set value for synced device/host array 625 if (y_impl->d_array) { 626 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 627 CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length)); 628 } 629 if (y_impl->h_array) { 630 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 631 CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length)); 632 } 633 634 return CEED_ERROR_SUCCESS; 635 } 636 637 //------------------------------------------------------------------------------ 638 // Compute the pointwise multiplication w = x .* y on the host 639 //------------------------------------------------------------------------------ 640 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 641 for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 642 return CEED_ERROR_SUCCESS; 643 } 644 645 //------------------------------------------------------------------------------ 646 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 647 //------------------------------------------------------------------------------ 648 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length); 649 650 //------------------------------------------------------------------------------ 651 // Compute the pointwise multiplication w = x .* y 652 //------------------------------------------------------------------------------ 653 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) { 654 Ceed ceed; 655 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 656 CeedVector_Hip *w_impl, *x_impl, *y_impl; 657 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 658 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 659 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 660 CeedSize length; 661 CeedCallBackend(CeedVectorGetLength(w, &length)); 662 663 // Set value for synced device/host array 664 if (!w_impl->d_array && !w_impl->h_array) { 665 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 666 } 667 if (w_impl->d_array) { 668 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 669 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 670 CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 671 } 672 if (w_impl->h_array) { 673 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 674 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 675 CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 676 } 677 678 return CEED_ERROR_SUCCESS; 679 } 680 681 //------------------------------------------------------------------------------ 682 // Destroy the vector 683 //------------------------------------------------------------------------------ 684 static int CeedVectorDestroy_Hip(const CeedVector vec) { 685 Ceed ceed; 686 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 687 CeedVector_Hip *impl; 688 CeedCallBackend(CeedVectorGetData(vec, &impl)); 689 690 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 691 CeedCallBackend(CeedFree(&impl->h_array_owned)); 692 CeedCallBackend(CeedFree(&impl)); 693 694 return CEED_ERROR_SUCCESS; 695 } 696 697 //------------------------------------------------------------------------------ 698 // Create a vector of the specified length (does not allocate memory) 699 //------------------------------------------------------------------------------ 700 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { 701 CeedVector_Hip *impl; 702 Ceed ceed; 703 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 704 705 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip)); 706 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip)); 707 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip)); 708 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip)); 709 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip)); 710 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip)); 711 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip)); 712 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip)); 713 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip)); 714 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip)); 715 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip)); 716 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip)); 717 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip)); 718 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip)); 719 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip)); 720 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip)); 721 722 CeedCallBackend(CeedCalloc(1, &impl)); 723 CeedCallBackend(CeedVectorSetData(vec, impl)); 724 725 return CEED_ERROR_SUCCESS; 726 } 727 728 //------------------------------------------------------------------------------ 729