1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <math.h> 11 #include <stdbool.h> 12 #include <string.h> 13 #include <hip/hip_runtime.h> 14 15 #include "../hip/ceed-hip-common.h" 16 #include "ceed-hip-ref.h" 17 18 //------------------------------------------------------------------------------ 19 // Check if host/device sync is needed 20 //------------------------------------------------------------------------------ 21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 22 CeedVector_Hip *impl; 23 bool has_valid_array = false; 24 25 CeedCallBackend(CeedVectorGetData(vec, &impl)); 26 CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 27 switch (mem_type) { 28 case CEED_MEM_HOST: 29 *need_sync = has_valid_array && !impl->h_array; 30 break; 31 case CEED_MEM_DEVICE: 32 *need_sync = has_valid_array && !impl->d_array; 33 break; 34 } 35 return CEED_ERROR_SUCCESS; 36 } 37 38 //------------------------------------------------------------------------------ 39 // Sync host to device 40 //------------------------------------------------------------------------------ 41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) { 42 Ceed ceed; 43 CeedSize length; 44 CeedVector_Hip *impl; 45 46 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 47 CeedCallBackend(CeedVectorGetData(vec, &impl)); 48 49 CeedCallBackend(CeedVectorGetLength(vec, &length)); 50 size_t bytes = length * sizeof(CeedScalar); 51 52 CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 53 54 if (impl->d_array_borrowed) { 55 impl->d_array = impl->d_array_borrowed; 56 } else if (impl->d_array_owned) { 57 impl->d_array = impl->d_array_owned; 58 } else { 59 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 60 impl->d_array = impl->d_array_owned; 61 } 62 CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice)); 63 return CEED_ERROR_SUCCESS; 64 } 65 66 //------------------------------------------------------------------------------ 67 // Sync device to host 68 //------------------------------------------------------------------------------ 69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { 70 Ceed ceed; 71 CeedSize length; 72 CeedVector_Hip *impl; 73 74 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 75 CeedCallBackend(CeedVectorGetData(vec, &impl)); 76 77 CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 78 79 if (impl->h_array_borrowed) { 80 impl->h_array = impl->h_array_borrowed; 81 } else if (impl->h_array_owned) { 82 impl->h_array = impl->h_array_owned; 83 } else { 84 CeedSize length; 85 CeedCallBackend(CeedVectorGetLength(vec, &length)); 86 CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 87 impl->h_array = impl->h_array_owned; 88 } 89 90 CeedCallBackend(CeedVectorGetLength(vec, &length)); 91 size_t bytes = length * sizeof(CeedScalar); 92 93 CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost)); 94 return CEED_ERROR_SUCCESS; 95 } 96 97 //------------------------------------------------------------------------------ 98 // Sync arrays 99 //------------------------------------------------------------------------------ 100 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) { 101 bool need_sync = false; 102 103 // Check whether device/host sync is needed 104 CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync)); 105 if (!need_sync) return CEED_ERROR_SUCCESS; 106 107 switch (mem_type) { 108 case CEED_MEM_HOST: 109 return CeedVectorSyncD2H_Hip(vec); 110 case CEED_MEM_DEVICE: 111 return CeedVectorSyncH2D_Hip(vec); 112 } 113 return CEED_ERROR_UNSUPPORTED; 114 } 115 116 //------------------------------------------------------------------------------ 117 // Set all pointers as invalid 118 //------------------------------------------------------------------------------ 119 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) { 120 CeedVector_Hip *impl; 121 122 CeedCallBackend(CeedVectorGetData(vec, &impl)); 123 impl->h_array = NULL; 124 impl->d_array = NULL; 125 return CEED_ERROR_SUCCESS; 126 } 127 128 //------------------------------------------------------------------------------ 129 // Check if CeedVector has any valid pointer 130 //------------------------------------------------------------------------------ 131 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) { 132 CeedVector_Hip *impl; 133 134 CeedCallBackend(CeedVectorGetData(vec, &impl)); 135 *has_valid_array = impl->h_array || impl->d_array; 136 return CEED_ERROR_SUCCESS; 137 } 138 139 //------------------------------------------------------------------------------ 140 // Check if has array of given type 141 //------------------------------------------------------------------------------ 142 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 143 CeedVector_Hip *impl; 144 145 CeedCallBackend(CeedVectorGetData(vec, &impl)); 146 switch (mem_type) { 147 case CEED_MEM_HOST: 148 *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 149 break; 150 case CEED_MEM_DEVICE: 151 *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 152 break; 153 } 154 return CEED_ERROR_SUCCESS; 155 } 156 157 //------------------------------------------------------------------------------ 158 // Check if has borrowed array of given type 159 //------------------------------------------------------------------------------ 160 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 161 CeedVector_Hip *impl; 162 163 CeedCallBackend(CeedVectorGetData(vec, &impl)); 164 switch (mem_type) { 165 case CEED_MEM_HOST: 166 *has_borrowed_array_of_type = impl->h_array_borrowed; 167 break; 168 case CEED_MEM_DEVICE: 169 *has_borrowed_array_of_type = impl->d_array_borrowed; 170 break; 171 } 172 return CEED_ERROR_SUCCESS; 173 } 174 175 //------------------------------------------------------------------------------ 176 // Set array from host 177 //------------------------------------------------------------------------------ 178 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 179 CeedVector_Hip *impl; 180 181 CeedCallBackend(CeedVectorGetData(vec, &impl)); 182 switch (copy_mode) { 183 case CEED_COPY_VALUES: { 184 CeedSize length; 185 186 if (!impl->h_array_owned) { 187 CeedCallBackend(CeedVectorGetLength(vec, &length)); 188 CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 189 } 190 impl->h_array_borrowed = NULL; 191 impl->h_array = impl->h_array_owned; 192 if (array) { 193 CeedSize length; 194 195 CeedCallBackend(CeedVectorGetLength(vec, &length)); 196 size_t bytes = length * sizeof(CeedScalar); 197 memcpy(impl->h_array, array, bytes); 198 } 199 } break; 200 case CEED_OWN_POINTER: 201 CeedCallBackend(CeedFree(&impl->h_array_owned)); 202 impl->h_array_owned = array; 203 impl->h_array_borrowed = NULL; 204 impl->h_array = array; 205 break; 206 case CEED_USE_POINTER: 207 CeedCallBackend(CeedFree(&impl->h_array_owned)); 208 impl->h_array_borrowed = array; 209 impl->h_array = array; 210 break; 211 } 212 return CEED_ERROR_SUCCESS; 213 } 214 215 //------------------------------------------------------------------------------ 216 // Set array from device 217 //------------------------------------------------------------------------------ 218 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 219 Ceed ceed; 220 CeedVector_Hip *impl; 221 222 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 223 CeedCallBackend(CeedVectorGetData(vec, &impl)); 224 switch (copy_mode) { 225 case CEED_COPY_VALUES: { 226 CeedSize length; 227 228 CeedCallBackend(CeedVectorGetLength(vec, &length)); 229 size_t bytes = length * sizeof(CeedScalar); 230 231 if (!impl->d_array_owned) { 232 CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 233 } 234 impl->d_array_borrowed = NULL; 235 impl->d_array = impl->d_array_owned; 236 if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice)); 237 } break; 238 case CEED_OWN_POINTER: 239 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 240 impl->d_array_owned = array; 241 impl->d_array_borrowed = NULL; 242 impl->d_array = array; 243 break; 244 case CEED_USE_POINTER: 245 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 246 impl->d_array_owned = NULL; 247 impl->d_array_borrowed = array; 248 impl->d_array = array; 249 break; 250 } 251 return CEED_ERROR_SUCCESS; 252 } 253 254 //------------------------------------------------------------------------------ 255 // Set the array used by a vector, 256 // freeing any previously allocated array if applicable 257 //------------------------------------------------------------------------------ 258 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 259 Ceed ceed; 260 CeedVector_Hip *impl; 261 262 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 263 CeedCallBackend(CeedVectorGetData(vec, &impl)); 264 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 265 switch (mem_type) { 266 case CEED_MEM_HOST: 267 return CeedVectorSetArrayHost_Hip(vec, copy_mode, array); 268 case CEED_MEM_DEVICE: 269 return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array); 270 } 271 return CEED_ERROR_UNSUPPORTED; 272 } 273 274 //------------------------------------------------------------------------------ 275 // Set host array to value 276 //------------------------------------------------------------------------------ 277 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) { 278 for (CeedSize i = 0; i < length; i++) h_array[i] = val; 279 return CEED_ERROR_SUCCESS; 280 } 281 282 //------------------------------------------------------------------------------ 283 // Set device array to value (impl in .hip file) 284 //------------------------------------------------------------------------------ 285 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val); 286 287 //------------------------------------------------------------------------------ 288 // Set a vector to a value 289 //------------------------------------------------------------------------------ 290 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) { 291 Ceed ceed; 292 CeedSize length; 293 CeedVector_Hip *impl; 294 295 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 296 CeedCallBackend(CeedVectorGetData(vec, &impl)); 297 CeedCallBackend(CeedVectorGetLength(vec, &length)); 298 // Set value for synced device/host array 299 if (!impl->d_array && !impl->h_array) { 300 if (impl->d_array_borrowed) { 301 impl->d_array = impl->d_array_borrowed; 302 } else if (impl->h_array_borrowed) { 303 impl->h_array = impl->h_array_borrowed; 304 } else if (impl->d_array_owned) { 305 impl->d_array = impl->d_array_owned; 306 } else if (impl->h_array_owned) { 307 impl->h_array = impl->h_array_owned; 308 } else { 309 CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 310 } 311 } 312 if (impl->d_array) { 313 CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val)); 314 impl->h_array = NULL; 315 } 316 if (impl->h_array) { 317 CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val)); 318 impl->d_array = NULL; 319 } 320 return CEED_ERROR_SUCCESS; 321 } 322 323 //------------------------------------------------------------------------------ 324 // Vector Take Array 325 //------------------------------------------------------------------------------ 326 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 327 Ceed ceed; 328 CeedVector_Hip *impl; 329 330 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 331 CeedCallBackend(CeedVectorGetData(vec, &impl)); 332 333 // Sync array to requested mem_type 334 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 335 336 // Update pointer 337 switch (mem_type) { 338 case CEED_MEM_HOST: 339 (*array) = impl->h_array_borrowed; 340 impl->h_array_borrowed = NULL; 341 impl->h_array = NULL; 342 break; 343 case CEED_MEM_DEVICE: 344 (*array) = impl->d_array_borrowed; 345 impl->d_array_borrowed = NULL; 346 impl->d_array = NULL; 347 break; 348 } 349 return CEED_ERROR_SUCCESS; 350 } 351 352 //------------------------------------------------------------------------------ 353 // Core logic for array syncronization for GetArray. 354 // If a different memory type is most up to date, this will perform a copy 355 //------------------------------------------------------------------------------ 356 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 357 Ceed ceed; 358 CeedVector_Hip *impl; 359 360 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 361 CeedCallBackend(CeedVectorGetData(vec, &impl)); 362 363 // Sync array to requested mem_type 364 CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 365 366 // Update pointer 367 switch (mem_type) { 368 case CEED_MEM_HOST: 369 *array = impl->h_array; 370 break; 371 case CEED_MEM_DEVICE: 372 *array = impl->d_array; 373 break; 374 } 375 return CEED_ERROR_SUCCESS; 376 } 377 378 //------------------------------------------------------------------------------ 379 // Get read-only access to a vector via the specified mem_type 380 //------------------------------------------------------------------------------ 381 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 382 return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array); 383 } 384 385 //------------------------------------------------------------------------------ 386 // Get read/write access to a vector via the specified mem_type 387 //------------------------------------------------------------------------------ 388 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 389 CeedVector_Hip *impl; 390 391 CeedCallBackend(CeedVectorGetData(vec, &impl)); 392 CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array)); 393 CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 394 switch (mem_type) { 395 case CEED_MEM_HOST: 396 impl->h_array = *array; 397 break; 398 case CEED_MEM_DEVICE: 399 impl->d_array = *array; 400 break; 401 } 402 return CEED_ERROR_SUCCESS; 403 } 404 405 //------------------------------------------------------------------------------ 406 // Get write access to a vector via the specified mem_type 407 //------------------------------------------------------------------------------ 408 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 409 bool has_array_of_type = true; 410 CeedVector_Hip *impl; 411 412 CeedCallBackend(CeedVectorGetData(vec, &impl)); 413 CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type)); 414 if (!has_array_of_type) { 415 // Allocate if array is not yet allocated 416 CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 417 } else { 418 // Select dirty array 419 switch (mem_type) { 420 case CEED_MEM_HOST: 421 if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 422 else impl->h_array = impl->h_array_owned; 423 break; 424 case CEED_MEM_DEVICE: 425 if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 426 else impl->d_array = impl->d_array_owned; 427 } 428 } 429 return CeedVectorGetArray_Hip(vec, mem_type, array); 430 } 431 432 //------------------------------------------------------------------------------ 433 // Get the norm of a CeedVector 434 //------------------------------------------------------------------------------ 435 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) { 436 Ceed ceed; 437 CeedSize length; 438 const CeedScalar *d_array; 439 CeedVector_Hip *impl; 440 hipblasHandle_t handle; 441 442 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 443 CeedCallBackend(CeedVectorGetData(vec, &impl)); 444 CeedCallBackend(CeedVectorGetLength(vec, &length)); 445 CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle)); 446 447 // Is the vector too long to handle with int32? If so, we will divide 448 // it up into "int32-sized" subsections and make repeated BLAS calls. 449 CeedSize num_calls = length / INT_MAX; 450 451 if (length % INT_MAX > 0) num_calls += 1; 452 453 // Compute norm 454 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 455 switch (type) { 456 case CEED_NORM_1: { 457 *norm = 0.0; 458 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 459 float sub_norm = 0.0; 460 float *d_array_start; 461 462 for (CeedInt i = 0; i < num_calls; i++) { 463 d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 464 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 465 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 466 467 CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm)); 468 *norm += sub_norm; 469 } 470 } else { 471 double sub_norm = 0.0; 472 double *d_array_start; 473 474 for (CeedInt i = 0; i < num_calls; i++) { 475 d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 476 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 477 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 478 479 CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm)); 480 *norm += sub_norm; 481 } 482 } 483 break; 484 } 485 case CEED_NORM_2: { 486 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 487 float sub_norm = 0.0, norm_sum = 0.0; 488 float *d_array_start; 489 490 for (CeedInt i = 0; i < num_calls; i++) { 491 d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 492 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 493 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 494 495 CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm)); 496 norm_sum += sub_norm * sub_norm; 497 } 498 *norm = sqrt(norm_sum); 499 } else { 500 double sub_norm = 0.0, norm_sum = 0.0; 501 double *d_array_start; 502 503 for (CeedInt i = 0; i < num_calls; i++) { 504 d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 505 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 506 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 507 508 CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm)); 509 norm_sum += sub_norm * sub_norm; 510 } 511 *norm = sqrt(norm_sum); 512 } 513 break; 514 } 515 case CEED_NORM_MAX: { 516 CeedInt index; 517 518 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 519 float sub_max = 0.0, current_max = 0.0; 520 float *d_array_start; 521 for (CeedInt i = 0; i < num_calls; i++) { 522 d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 523 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 524 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 525 526 CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index)); 527 CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 528 if (fabs(sub_max) > current_max) current_max = fabs(sub_max); 529 } 530 *norm = current_max; 531 } else { 532 double sub_max = 0.0, current_max = 0.0; 533 double *d_array_start; 534 535 for (CeedInt i = 0; i < num_calls; i++) { 536 d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 537 CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 538 CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 539 540 CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index)); 541 CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 542 if (fabs(sub_max) > current_max) current_max = fabs(sub_max); 543 } 544 *norm = current_max; 545 } 546 break; 547 } 548 } 549 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 550 return CEED_ERROR_SUCCESS; 551 } 552 553 //------------------------------------------------------------------------------ 554 // Take reciprocal of a vector on host 555 //------------------------------------------------------------------------------ 556 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) { 557 for (CeedSize i = 0; i < length; i++) { 558 if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 559 } 560 return CEED_ERROR_SUCCESS; 561 } 562 563 //------------------------------------------------------------------------------ 564 // Take reciprocal of a vector on device (impl in .cu file) 565 //------------------------------------------------------------------------------ 566 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length); 567 568 //------------------------------------------------------------------------------ 569 // Take reciprocal of a vector 570 //------------------------------------------------------------------------------ 571 static int CeedVectorReciprocal_Hip(CeedVector vec) { 572 Ceed ceed; 573 CeedSize length; 574 CeedVector_Hip *impl; 575 576 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 577 CeedCallBackend(CeedVectorGetData(vec, &impl)); 578 CeedCallBackend(CeedVectorGetLength(vec, &length)); 579 // Set value for synced device/host array 580 if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length)); 581 if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length)); 582 return CEED_ERROR_SUCCESS; 583 } 584 585 //------------------------------------------------------------------------------ 586 // Compute x = alpha x on the host 587 //------------------------------------------------------------------------------ 588 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 589 for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 590 return CEED_ERROR_SUCCESS; 591 } 592 593 //------------------------------------------------------------------------------ 594 // Compute x = alpha x on device (impl in .cu file) 595 //------------------------------------------------------------------------------ 596 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length); 597 598 //------------------------------------------------------------------------------ 599 // Compute x = alpha x 600 //------------------------------------------------------------------------------ 601 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) { 602 Ceed ceed; 603 CeedSize length; 604 CeedVector_Hip *x_impl; 605 606 CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 607 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 608 CeedCallBackend(CeedVectorGetLength(x, &length)); 609 // Set value for synced device/host array 610 if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length)); 611 if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length)); 612 return CEED_ERROR_SUCCESS; 613 } 614 615 //------------------------------------------------------------------------------ 616 // Compute y = alpha x + y on the host 617 //------------------------------------------------------------------------------ 618 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 619 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 620 return CEED_ERROR_SUCCESS; 621 } 622 623 //------------------------------------------------------------------------------ 624 // Compute y = alpha x + y on device (impl in .cu file) 625 //------------------------------------------------------------------------------ 626 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length); 627 628 //------------------------------------------------------------------------------ 629 // Compute y = alpha x + y 630 //------------------------------------------------------------------------------ 631 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) { 632 Ceed ceed; 633 CeedSize length; 634 CeedVector_Hip *y_impl, *x_impl; 635 636 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 637 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 638 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 639 CeedCallBackend(CeedVectorGetLength(y, &length)); 640 // Set value for synced device/host array 641 if (y_impl->d_array) { 642 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 643 CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length)); 644 } 645 if (y_impl->h_array) { 646 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 647 CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length)); 648 } 649 return CEED_ERROR_SUCCESS; 650 } 651 652 //------------------------------------------------------------------------------ 653 // Compute y = alpha x + beta y on the host 654 //------------------------------------------------------------------------------ 655 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) { 656 for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i]; 657 return CEED_ERROR_SUCCESS; 658 } 659 660 //------------------------------------------------------------------------------ 661 // Compute y = alpha x + beta y on device (impl in .cu file) 662 //------------------------------------------------------------------------------ 663 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length); 664 665 //------------------------------------------------------------------------------ 666 // Compute y = alpha x + beta y 667 //------------------------------------------------------------------------------ 668 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) { 669 Ceed ceed; 670 CeedSize length; 671 CeedVector_Hip *y_impl, *x_impl; 672 673 CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 674 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 675 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 676 CeedCallBackend(CeedVectorGetLength(y, &length)); 677 // Set value for synced device/host array 678 if (y_impl->d_array) { 679 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 680 CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length)); 681 } 682 if (y_impl->h_array) { 683 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 684 CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length)); 685 } 686 return CEED_ERROR_SUCCESS; 687 } 688 689 //------------------------------------------------------------------------------ 690 // Compute the pointwise multiplication w = x .* y on the host 691 //------------------------------------------------------------------------------ 692 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 693 for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 694 return CEED_ERROR_SUCCESS; 695 } 696 697 //------------------------------------------------------------------------------ 698 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 699 //------------------------------------------------------------------------------ 700 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length); 701 702 //------------------------------------------------------------------------------ 703 // Compute the pointwise multiplication w = x .* y 704 //------------------------------------------------------------------------------ 705 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) { 706 Ceed ceed; 707 CeedSize length; 708 CeedVector_Hip *w_impl, *x_impl, *y_impl; 709 710 CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 711 CeedCallBackend(CeedVectorGetData(w, &w_impl)); 712 CeedCallBackend(CeedVectorGetData(x, &x_impl)); 713 CeedCallBackend(CeedVectorGetData(y, &y_impl)); 714 CeedCallBackend(CeedVectorGetLength(w, &length)); 715 716 // Set value for synced device/host array 717 if (!w_impl->d_array && !w_impl->h_array) { 718 CeedCallBackend(CeedVectorSetValue(w, 0.0)); 719 } 720 if (w_impl->d_array) { 721 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 722 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 723 CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 724 } 725 if (w_impl->h_array) { 726 CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 727 CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 728 CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 729 } 730 return CEED_ERROR_SUCCESS; 731 } 732 733 //------------------------------------------------------------------------------ 734 // Destroy the vector 735 //------------------------------------------------------------------------------ 736 static int CeedVectorDestroy_Hip(const CeedVector vec) { 737 Ceed ceed; 738 CeedVector_Hip *impl; 739 740 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 741 CeedCallBackend(CeedVectorGetData(vec, &impl)); 742 CeedCallHip(ceed, hipFree(impl->d_array_owned)); 743 CeedCallBackend(CeedFree(&impl->h_array_owned)); 744 CeedCallBackend(CeedFree(&impl)); 745 return CEED_ERROR_SUCCESS; 746 } 747 748 //------------------------------------------------------------------------------ 749 // Create a vector of the specified length (does not allocate memory) 750 //------------------------------------------------------------------------------ 751 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { 752 CeedVector_Hip *impl; 753 Ceed ceed; 754 755 CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 756 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip)); 757 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip)); 758 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip)); 759 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip)); 760 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip)); 761 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip)); 762 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip)); 763 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip)); 764 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip)); 765 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip)); 766 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip)); 767 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip)); 768 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip)); 769 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip)); 770 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip)); 771 CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip)); 772 CeedCallBackend(CeedCalloc(1, &impl)); 773 CeedCallBackend(CeedVectorSetData(vec, impl)); 774 return CEED_ERROR_SUCCESS; 775 } 776 777 //------------------------------------------------------------------------------ 778