1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <hip/hip_runtime.h> 11 #include <math.h> 12 #include <string.h> 13 14 #include "ceed-hip-ref.h" 15 16 //------------------------------------------------------------------------------ 17 // Check if host/device sync is needed 18 //------------------------------------------------------------------------------ 19 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 20 CeedVector_Hip *impl; 21 CeedCallBackend(CeedVectorGetData(vec, &impl)); 22 23 bool has_valid_array = false; 24 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 25 switch (mem_type) { 26 case CEED_MEM_HOST: 27 *need_sync = has_valid_array && !impl->h_array; 28 break; 29 case CEED_MEM_DEVICE: 30 *need_sync = has_valid_array && !impl->d_array; 31 break; 32 } 33 34 return CEED_ERROR_SUCCESS; 35 } 36 37 //------------------------------------------------------------------------------ 38 // Sync host to device 39 //------------------------------------------------------------------------------ 40 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) { 41 Ceed ceed; 42 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 43 CeedVector_Hip *impl; 44 CeedCallBackend(CeedVectorGetData(vec, &impl)); 45 46 CeedSize length; 47 CeedCallBackend(CeedVectorGetLength(vec, &length)); 48 size_t bytes = length * sizeof(CeedScalar); 49 50 if (!impl->h_array) { 51 // LCOV_EXCL_START 52 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 53 // LCOV_EXCL_STOP 54 } 55 56 if (impl->d_array_borrowed) { 57 impl->d_array = impl->d_array_borrowed; 58 } else if (impl->d_array_owned) { 59 impl->d_array = impl->d_array_owned; 60 } else { 61 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 62 impl->d_array = impl->d_array_owned; 63 } 64 65 CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice)); 66 67 return CEED_ERROR_SUCCESS; 68 } 69 70 //------------------------------------------------------------------------------ 71 // Sync device to host 72 //------------------------------------------------------------------------------ 73 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { 74 Ceed ceed; 75 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 76 CeedVector_Hip *impl; 77 CeedCallBackend(CeedVectorGetData(vec, &impl)); 78 79 if (!impl->d_array) { 80 // LCOV_EXCL_START 81 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 82 // LCOV_EXCL_STOP 83 } 84 85 if (impl->h_array_borrowed) { 86 impl->h_array = impl->h_array_borrowed; 87 } else if (impl->h_array_owned) { 88 impl->h_array = impl->h_array_owned; 89 } else { 90 CeedSize length; 91 CeedCallBackend(CeedVectorGetLength(vec, &length)); 92 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 93 impl->h_array = impl->h_array_owned; 94 } 95 96 CeedSize length; 97 CeedCallBackend(CeedVectorGetLength(vec, &length)); 98 size_t bytes = length * sizeof(CeedScalar); 99 CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost)); 100 101 return CEED_ERROR_SUCCESS; 102 } 103 104 //------------------------------------------------------------------------------ 105 // Sync arrays 106 //------------------------------------------------------------------------------ 107 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) { 108 // Check whether device/host sync is needed 109 bool need_sync = false; 110 CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync)); 111 if (!need_sync) return CEED_ERROR_SUCCESS; 112 113 switch (mem_type) { 114 case CEED_MEM_HOST: 115 return CeedVectorSyncD2H_Hip(vec); 116 case CEED_MEM_DEVICE: 117 return CeedVectorSyncH2D_Hip(vec); 118 } 119 return CEED_ERROR_UNSUPPORTED; 120 } 121 122 //------------------------------------------------------------------------------ 123 // Set all pointers as invalid 124 //------------------------------------------------------------------------------ 125 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) { 126 CeedVector_Hip *impl; 127 CeedCallBackend(CeedVectorGetData(vec, &impl)); 128 129 impl->h_array = NULL; 130 impl->d_array = NULL; 131 132 return CEED_ERROR_SUCCESS; 133 } 134 135 //------------------------------------------------------------------------------ 136 // Check if CeedVector has any valid pointers 137 //------------------------------------------------------------------------------ 138 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) { 139 CeedVector_Hip *impl; 140 CeedCallBackend(CeedVectorGetData(vec, &impl)); 141 142 *has_valid_array = !!impl->h_array || !!impl->d_array; 143 144 return CEED_ERROR_SUCCESS; 145 } 146 147 //------------------------------------------------------------------------------ 148 // Check if has any array of given type 149 //------------------------------------------------------------------------------ 150 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 151 CeedVector_Hip *impl; 152 CeedCallBackend(CeedVectorGetData(vec, &impl)); 153 154 switch (mem_type) { 155 case CEED_MEM_HOST: 156 *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 157 break; 158 case CEED_MEM_DEVICE: 159 *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 160 break; 161 } 162 163 return CEED_ERROR_SUCCESS; 164 } 165 166 //------------------------------------------------------------------------------ 167 // Check if has borrowed array of given type 168 //------------------------------------------------------------------------------ 169 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 170 CeedVector_Hip *impl; 171 CeedCallBackend(CeedVectorGetData(vec, &impl)); 172 173 switch (mem_type) { 174 case CEED_MEM_HOST: 175 *has_borrowed_array_of_type = !!impl->h_array_borrowed; 176 break; 177 case CEED_MEM_DEVICE: 178 *has_borrowed_array_of_type = !!impl->d_array_borrowed; 179 break; 180 } 181 182 return CEED_ERROR_SUCCESS; 183 } 184 185 //------------------------------------------------------------------------------ 186 // Set array from host 187 //------------------------------------------------------------------------------ 188 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 189 CeedVector_Hip *impl; 190 CeedCallBackend(CeedVectorGetData(vec, &impl)); 191 192 switch (copy_mode) { 193 case CEED_COPY_VALUES: { 194 CeedSize length; 195 if (!impl->h_array_owned) { 196 CeedCallBackend(CeedVectorGetLength(vec, &length)); 197 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 198 } 199 impl->h_array_borrowed = NULL; 200 impl->h_array = impl->h_array_owned; 201 if (array) { 202 CeedSize length; 203 CeedCallBackend(CeedVectorGetLength(vec, &length)); 204 size_t bytes = length * sizeof(CeedScalar); 205 memcpy(impl->h_array, array, bytes); 206 } 207 } break; 208 case CEED_OWN_POINTER: 209 CeedCallBackend(CeedFree(&impl->h_array_owned)); 210 impl->h_array_owned = array; 211 impl->h_array_borrowed = NULL; 212 impl->h_array = array; 213 break; 214 case CEED_USE_POINTER: 215 CeedCallBackend(CeedFree(&impl->h_array_owned)); 216 impl->h_array_borrowed = array; 217 impl->h_array = array; 218 break; 219 } 220 221 return CEED_ERROR_SUCCESS; 222 } 223 224 //------------------------------------------------------------------------------ 225 // Set array from device 226 //------------------------------------------------------------------------------ 227 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 228 Ceed ceed; 229 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 230 CeedVector_Hip *impl; 231 CeedCallBackend(CeedVectorGetData(vec, &impl)); 232 233 switch (copy_mode) { 234 case CEED_COPY_VALUES: { 235 CeedSize length; 236 CeedCallBackend(CeedVectorGetLength(vec, &length)); 237 size_t bytes = length * sizeof(CeedScalar); 238 if (!impl->d_array_owned) { 239 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 240 } 241 impl->d_array_borrowed = NULL; 242 impl->d_array = impl->d_array_owned; 243 if (array) { 244 CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice)); 245 } 246 } break; 247 case CEED_OWN_POINTER: 248 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 249 impl->d_array_owned = array; 250 impl->d_array_borrowed = NULL; 251 impl->d_array = array; 252 break; 253 case CEED_USE_POINTER: 254 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 255 impl->d_array_owned = NULL; 256 impl->d_array_borrowed = array; 257 impl->d_array = array; 258 break; 259 } 260 261 return CEED_ERROR_SUCCESS; 262 } 263 264 //------------------------------------------------------------------------------ 265 // Set the array used by a vector, 266 // freeing any previously allocated array if applicable 267 //------------------------------------------------------------------------------ 268 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 269 Ceed ceed; 270 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 271 CeedVector_Hip *impl; 272 CeedCallBackend(CeedVectorGetData(vec, &impl)); 273 274 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 275 switch (mem_type) { 276 case CEED_MEM_HOST: 277 return CeedVectorSetArrayHost_Hip(vec, copy_mode, array); 278 case CEED_MEM_DEVICE: 279 return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array); 280 } 281 282 return CEED_ERROR_UNSUPPORTED; 283 } 284 285 //------------------------------------------------------------------------------ 286 // Set host array to value 287 //------------------------------------------------------------------------------ 288 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length, CeedScalar val) { 289 for (int i = 0; i < length; i++) h_array[i] = val; 290 return CEED_ERROR_SUCCESS; 291 } 292 293 //------------------------------------------------------------------------------ 294 // Set device array to value (impl in .hip file) 295 //------------------------------------------------------------------------------ 296 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val); 297 298 //------------------------------------------------------------------------------ 299 // Set a vector to a value, 300 //------------------------------------------------------------------------------ 301 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) { 302 Ceed ceed; 303 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 304 CeedVector_Hip *impl; 305 CeedCallBackend(CeedVectorGetData(vec, &impl)); 306 CeedSize length; 307 CeedCallBackend(CeedVectorGetLength(vec, &length)); 308 309 // Set value for synced device/host array 310 if (!impl->d_array && !impl->h_array) { 311 if (impl->d_array_borrowed) { 312 impl->d_array = impl->d_array_borrowed; 313 } else if (impl->h_array_borrowed) { 314 impl->h_array = impl->h_array_borrowed; 315 } else if (impl->d_array_owned) { 316 impl->d_array = impl->d_array_owned; 317 } else if (impl->h_array_owned) { 318 impl->h_array = impl->h_array_owned; 319 } else { 320 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 321 } 322 } 323 if (impl->d_array) { 324 CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val)); 325 } 326 if (impl->h_array) { 327 CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val)); 328 } 329 330 return CEED_ERROR_SUCCESS; 331 } 332 333 //------------------------------------------------------------------------------ 334 // Vector Take Array 335 //------------------------------------------------------------------------------ 336 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 337 Ceed ceed; 338 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 339 CeedVector_Hip *impl; 340 CeedCallBackend(CeedVectorGetData(vec, &impl)); 341 342 // Sync array to requested mem_type 343 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 344 345 // Update pointer 346 switch (mem_type) { 347 case CEED_MEM_HOST: 348 (*array) = impl->h_array_borrowed; 349 impl->h_array_borrowed = NULL; 350 impl->h_array = NULL; 351 break; 352 case CEED_MEM_DEVICE: 353 (*array) = impl->d_array_borrowed; 354 impl->d_array_borrowed = NULL; 355 impl->d_array = NULL; 356 break; 357 } 358 359 return CEED_ERROR_SUCCESS; 360 } 361 362 //------------------------------------------------------------------------------ 363 // Core logic for array syncronization for GetArray. 364 // If a different memory type is most up to date, this will perform a copy 365 //------------------------------------------------------------------------------ 366 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 367 Ceed ceed; 368 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 369 CeedVector_Hip *impl; 370 CeedCallBackend(CeedVectorGetData(vec, &impl)); 371 372 // Sync array to requested mem_type 373 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 374 375 // Update pointer 376 switch (mem_type) { 377 case CEED_MEM_HOST: 378 *array = impl->h_array; 379 break; 380 case CEED_MEM_DEVICE: 381 *array = impl->d_array; 382 break; 383 } 384 385 return CEED_ERROR_SUCCESS; 386 } 387 388 //------------------------------------------------------------------------------ 389 // Get read-only access to a vector via the specified mem_type 390 //------------------------------------------------------------------------------ 391 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 392 return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array); 393 } 394 395 //------------------------------------------------------------------------------ 396 // Get read/write access to a vector via the specified mem_type 397 //------------------------------------------------------------------------------ 398 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 399 CeedVector_Hip *impl; 400 CeedCallBackend(CeedVectorGetData(vec, &impl)); 401 402 CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array)); 403 404 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 405 switch (mem_type) { 406 case CEED_MEM_HOST: 407 impl->h_array = *array; 408 break; 409 case CEED_MEM_DEVICE: 410 impl->d_array = *array; 411 break; 412 } 413 414 return CEED_ERROR_SUCCESS; 415 } 416 417 //------------------------------------------------------------------------------ 418 // Get write access to a vector via the specified mem_type 419 //------------------------------------------------------------------------------ 420 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 421 CeedVector_Hip *impl; 422 CeedCallBackend(CeedVectorGetData(vec, &impl)); 423 424 bool has_array_of_type = true; 425 CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type)); 426 if (!has_array_of_type) { 427 // Allocate if array is not yet allocated 428 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 429 } else { 430 // Select dirty array 431 switch (mem_type) { 432 case CEED_MEM_HOST: 433 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 434 else impl->h_array = impl->h_array_owned; 435 break; 436 case CEED_MEM_DEVICE: 437 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 438 else impl->d_array = impl->d_array_owned; 439 } 440 } 441 442 return CeedVectorGetArray_Hip(vec, mem_type, array); 443 } 444 445 //------------------------------------------------------------------------------ 446 // Get the norm of a CeedVector 447 //------------------------------------------------------------------------------ 448 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) { 449 Ceed ceed; 450 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 451 CeedVector_Hip *impl; 452 CeedCallBackend(CeedVectorGetData(vec, &impl)); 453 CeedSize length; 454 CeedCallBackend(CeedVectorGetLength(vec, &length)); 455 hipblasHandle_t handle; 456 CeedCallBackend(CeedHipGetHipblasHandle(ceed, &handle)); 457 458 // Compute norm 459 const CeedScalar *d_array; 460 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 461 switch (type) { 462 case CEED_NORM_1: { 463 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 464 CeedCallHipblas(ceed, hipblasSasum(handle, length, (float *)d_array, 1, (float *)norm)); 465 } else { 466 CeedCallHipblas(ceed, hipblasDasum(handle, length, (double *)d_array, 1, (double *)norm)); 467 } 468 break; 469 } 470 case CEED_NORM_2: { 471 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 472 CeedCallHipblas(ceed, hipblasSnrm2(handle, length, (float *)d_array, 1, (float *)norm)); 473 } else { 474 CeedCallHipblas(ceed, hipblasDnrm2(handle, length, (double *)d_array, 1, (double *)norm)); 475 } 476 break; 477 } 478 case CEED_NORM_MAX: { 479 CeedInt indx; 480 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 481 CeedCallHipblas(ceed, hipblasIsamax(handle, length, (float *)d_array, 1, &indx)); 482 } else { 483 CeedCallHipblas(ceed, hipblasIdamax(handle, length, (double *)d_array, 1, &indx)); 484 } 485 CeedScalar normNoAbs; 486 CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 487 *norm = fabs(normNoAbs); 488 break; 489 } 490 } 491 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 492 493 return CEED_ERROR_SUCCESS; 494 } 495 496 //------------------------------------------------------------------------------ 497 // Take reciprocal of a vector on host 498 //------------------------------------------------------------------------------ 499 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) { 500 for (int i = 0; i < length; i++) { 501 if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 502 } 503 return CEED_ERROR_SUCCESS; 504 } 505 506 //------------------------------------------------------------------------------ 507 // Take reciprocal of a vector on device (impl in .cu file) 508 //------------------------------------------------------------------------------ 509 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length); 510 511 //------------------------------------------------------------------------------ 512 // Take reciprocal of a vector 513 //------------------------------------------------------------------------------ 514 static int CeedVectorReciprocal_Hip(CeedVector vec) { 515 Ceed ceed; 516 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 517 CeedVector_Hip *impl; 518 CeedCallBackend(CeedVectorGetData(vec, &impl)); 519 CeedSize length; 520 CeedCallBackend(CeedVectorGetLength(vec, &length)); 521 522 // Set value for synced device/host array 523 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length)); 524 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length)); 525 526 return CEED_ERROR_SUCCESS; 527 } 528 529 //------------------------------------------------------------------------------ 530 // Compute x = alpha x on the host 531 //------------------------------------------------------------------------------ 532 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 533 for (int i = 0; i < length; i++) x_array[i] *= alpha; 534 return CEED_ERROR_SUCCESS; 535 } 536 537 //------------------------------------------------------------------------------ 538 // Compute x = alpha x on device (impl in .cu file) 539 //------------------------------------------------------------------------------ 540 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length); 541 542 //------------------------------------------------------------------------------ 543 // Compute x = alpha x 544 //------------------------------------------------------------------------------ 545 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) { 546 Ceed ceed; 547 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 548 CeedVector_Hip *x_impl; 549 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 550 CeedSize length; 551 CeedCallBackend(CeedVectorGetLength(x, &length)); 552 553 // Set value for synced device/host array 554 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length)); 555 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length)); 556 557 return CEED_ERROR_SUCCESS; 558 } 559 560 //------------------------------------------------------------------------------ 561 // Compute y = alpha x + y on the host 562 //------------------------------------------------------------------------------ 563 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 564 for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 565 return CEED_ERROR_SUCCESS; 566 } 567 568 //------------------------------------------------------------------------------ 569 // Compute y = alpha x + y on device (impl in .cu file) 570 //------------------------------------------------------------------------------ 571 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length); 572 573 //------------------------------------------------------------------------------ 574 // Compute y = alpha x + y 575 //------------------------------------------------------------------------------ 576 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) { 577 Ceed ceed; 578 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 579 CeedVector_Hip *y_impl, *x_impl; 580 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 581 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 582 CeedSize length; 583 CeedCallBackend(CeedVectorGetLength(y, &length)); 584 585 // Set value for synced device/host array 586 if (y_impl->d_array) { 587 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 588 CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length)); 589 } 590 if (y_impl->h_array) { 591 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 592 CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length)); 593 } 594 595 return CEED_ERROR_SUCCESS; 596 } 597 598 //------------------------------------------------------------------------------ 599 // Compute the pointwise multiplication w = x .* y on the host 600 //------------------------------------------------------------------------------ 601 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 602 for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 603 return CEED_ERROR_SUCCESS; 604 } 605 606 //------------------------------------------------------------------------------ 607 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 608 //------------------------------------------------------------------------------ 609 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length); 610 611 //------------------------------------------------------------------------------ 612 // Compute the pointwise multiplication w = x .* y 613 //------------------------------------------------------------------------------ 614 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) { 615 Ceed ceed; 616 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 617 CeedVector_Hip *w_impl, *x_impl, *y_impl; 618 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 619 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 620 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 621 CeedSize length; 622 CeedCallBackend(CeedVectorGetLength(w, &length)); 623 624 // Set value for synced device/host array 625 if (!w_impl->d_array && !w_impl->h_array) { 626 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 627 } 628 if (w_impl->d_array) { 629 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 630 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 631 CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 632 } 633 if (w_impl->h_array) { 634 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 635 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 636 CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 637 } 638 639 return CEED_ERROR_SUCCESS; 640 } 641 642 //------------------------------------------------------------------------------ 643 // Destroy the vector 644 //------------------------------------------------------------------------------ 645 static int CeedVectorDestroy_Hip(const CeedVector vec) { 646 Ceed ceed; 647 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 648 CeedVector_Hip *impl; 649 CeedCallBackend(CeedVectorGetData(vec, &impl)); 650 651 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 652 CeedCallBackend(CeedFree(&impl->h_array_owned)); 653 CeedCallBackend(CeedFree(&impl)); 654 655 return CEED_ERROR_SUCCESS; 656 } 657 658 //------------------------------------------------------------------------------ 659 // Create a vector of the specified length (does not allocate memory) 660 //------------------------------------------------------------------------------ 661 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { 662 CeedVector_Hip *impl; 663 Ceed ceed; 664 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 665 666 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip)); 667 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip)); 668 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip)); 669 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip)); 670 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())(CeedVectorSetValue_Hip))); 671 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip)); 672 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip)); 673 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip)); 674 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip)); 675 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip)); 676 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip)); 677 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())(CeedVectorScale_Hip))); 678 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())(CeedVectorAXPY_Hip))); 679 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip)); 680 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip)); 681 682 CeedCallBackend(CeedCalloc(1, &impl)); 683 CeedCallBackend(CeedVectorSetData(vec, impl)); 684 685 return CEED_ERROR_SUCCESS; 686 } 687