xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-vector.c (revision 9330daecb0fc008043eec1b94c46ef7aecbb00cd)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
849aac155SJeremy L Thompson #include <ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
100d0321e0SJeremy L Thompson #include <math.h>
1149aac155SJeremy L Thompson #include <stdbool.h>
120d0321e0SJeremy L Thompson #include <string.h>
13c85e8640SSebastian Grimberg #include <hip/hip_runtime.h>
140d0321e0SJeremy L Thompson 
1549aac155SJeremy L Thompson #include "../hip/ceed-hip-common.h"
162b730f8bSJeremy L Thompson #include "ceed-hip-ref.h"
17f48ed27dSnbeams 
18f48ed27dSnbeams //------------------------------------------------------------------------------
19f48ed27dSnbeams // Check if host/device sync is needed
20f48ed27dSnbeams //------------------------------------------------------------------------------
212b730f8bSJeremy L Thompson static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22f48ed27dSnbeams   CeedVector_Hip *impl;
232b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
24f48ed27dSnbeams 
25f48ed27dSnbeams   bool has_valid_array = false;
262b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27f48ed27dSnbeams   switch (mem_type) {
28f48ed27dSnbeams     case CEED_MEM_HOST:
29f48ed27dSnbeams       *need_sync = has_valid_array && !impl->h_array;
30f48ed27dSnbeams       break;
31f48ed27dSnbeams     case CEED_MEM_DEVICE:
32f48ed27dSnbeams       *need_sync = has_valid_array && !impl->d_array;
33f48ed27dSnbeams       break;
34f48ed27dSnbeams   }
35f48ed27dSnbeams 
36f48ed27dSnbeams   return CEED_ERROR_SUCCESS;
37f48ed27dSnbeams }
38f48ed27dSnbeams 
390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
400d0321e0SJeremy L Thompson // Sync host to device
410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
420d0321e0SJeremy L Thompson static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
430d0321e0SJeremy L Thompson   Ceed ceed;
442b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
450d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
462b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
470d0321e0SJeremy L Thompson 
48539ec17dSJeremy L Thompson   CeedSize length;
492b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
50539ec17dSJeremy L Thompson   size_t bytes = length * sizeof(CeedScalar);
51539ec17dSJeremy L Thompson 
526574a04fSJeremy L Thompson   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
530d0321e0SJeremy L Thompson 
540d0321e0SJeremy L Thompson   if (impl->d_array_borrowed) {
550d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_borrowed;
560d0321e0SJeremy L Thompson   } else if (impl->d_array_owned) {
570d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_owned;
580d0321e0SJeremy L Thompson   } else {
592b730f8bSJeremy L Thompson     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
600d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_owned;
610d0321e0SJeremy L Thompson   }
620d0321e0SJeremy L Thompson 
632b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
640d0321e0SJeremy L Thompson 
650d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
660d0321e0SJeremy L Thompson }
670d0321e0SJeremy L Thompson 
680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
690d0321e0SJeremy L Thompson // Sync device to host
700d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
710d0321e0SJeremy L Thompson static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
720d0321e0SJeremy L Thompson   Ceed ceed;
732b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
740d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
752b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
760d0321e0SJeremy L Thompson 
776574a04fSJeremy L Thompson   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
780d0321e0SJeremy L Thompson 
790d0321e0SJeremy L Thompson   if (impl->h_array_borrowed) {
800d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_borrowed;
810d0321e0SJeremy L Thompson   } else if (impl->h_array_owned) {
820d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_owned;
830d0321e0SJeremy L Thompson   } else {
841f9221feSJeremy L Thompson     CeedSize length;
852b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetLength(vec, &length));
862b730f8bSJeremy L Thompson     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
870d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_owned;
880d0321e0SJeremy L Thompson   }
890d0321e0SJeremy L Thompson 
90539ec17dSJeremy L Thompson   CeedSize length;
912b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
92539ec17dSJeremy L Thompson   size_t bytes = length * sizeof(CeedScalar);
932b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
940d0321e0SJeremy L Thompson 
950d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
960d0321e0SJeremy L Thompson }
970d0321e0SJeremy L Thompson 
980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
990d0321e0SJeremy L Thompson // Sync arrays
1000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1012b730f8bSJeremy L Thompson static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102f48ed27dSnbeams   // Check whether device/host sync is needed
103f48ed27dSnbeams   bool need_sync = false;
1042b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
1052b730f8bSJeremy L Thompson   if (!need_sync) return CEED_ERROR_SUCCESS;
106f48ed27dSnbeams 
10743c928f4SJeremy L Thompson   switch (mem_type) {
1082b730f8bSJeremy L Thompson     case CEED_MEM_HOST:
1092b730f8bSJeremy L Thompson       return CeedVectorSyncD2H_Hip(vec);
1102b730f8bSJeremy L Thompson     case CEED_MEM_DEVICE:
1112b730f8bSJeremy L Thompson       return CeedVectorSyncH2D_Hip(vec);
1120d0321e0SJeremy L Thompson   }
1130d0321e0SJeremy L Thompson   return CEED_ERROR_UNSUPPORTED;
1140d0321e0SJeremy L Thompson }
1150d0321e0SJeremy L Thompson 
1160d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1170d0321e0SJeremy L Thompson // Set all pointers as invalid
1180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1190d0321e0SJeremy L Thompson static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
1200d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1212b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1220d0321e0SJeremy L Thompson 
1230d0321e0SJeremy L Thompson   impl->h_array = NULL;
1240d0321e0SJeremy L Thompson   impl->d_array = NULL;
1250d0321e0SJeremy L Thompson 
1260d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1270d0321e0SJeremy L Thompson }
1280d0321e0SJeremy L Thompson 
1290d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1300d0321e0SJeremy L Thompson // Check if CeedVector has any valid pointers
1310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1322b730f8bSJeremy L Thompson static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
1330d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1342b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1350d0321e0SJeremy L Thompson 
1360d0321e0SJeremy L Thompson   *has_valid_array = !!impl->h_array || !!impl->d_array;
1370d0321e0SJeremy L Thompson 
1380d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1390d0321e0SJeremy L Thompson }
1400d0321e0SJeremy L Thompson 
1410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1420d0321e0SJeremy L Thompson // Check if has any array of given type
1430d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1442b730f8bSJeremy L Thompson static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
1450d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1462b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1470d0321e0SJeremy L Thompson 
14843c928f4SJeremy L Thompson   switch (mem_type) {
1490d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
1500d0321e0SJeremy L Thompson       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
1510d0321e0SJeremy L Thompson       break;
1520d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
1530d0321e0SJeremy L Thompson       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
1540d0321e0SJeremy L Thompson       break;
1550d0321e0SJeremy L Thompson   }
1560d0321e0SJeremy L Thompson 
1570d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1580d0321e0SJeremy L Thompson }
1590d0321e0SJeremy L Thompson 
1600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1610d0321e0SJeremy L Thompson // Check if has borrowed array of given type
1620d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1632b730f8bSJeremy L Thompson static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
1640d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1652b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1660d0321e0SJeremy L Thompson 
16743c928f4SJeremy L Thompson   switch (mem_type) {
1680d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
1690d0321e0SJeremy L Thompson       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
1700d0321e0SJeremy L Thompson       break;
1710d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
1720d0321e0SJeremy L Thompson       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
1730d0321e0SJeremy L Thompson       break;
1740d0321e0SJeremy L Thompson   }
1750d0321e0SJeremy L Thompson 
1760d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1770d0321e0SJeremy L Thompson }
1780d0321e0SJeremy L Thompson 
1790d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1800d0321e0SJeremy L Thompson // Set array from host
1810d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1822b730f8bSJeremy L Thompson static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
1830d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1842b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1850d0321e0SJeremy L Thompson 
18643c928f4SJeremy L Thompson   switch (copy_mode) {
1870d0321e0SJeremy L Thompson     case CEED_COPY_VALUES: {
1881f9221feSJeremy L Thompson       CeedSize length;
1890d0321e0SJeremy L Thompson       if (!impl->h_array_owned) {
1902b730f8bSJeremy L Thompson         CeedCallBackend(CeedVectorGetLength(vec, &length));
1912b730f8bSJeremy L Thompson         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
1920d0321e0SJeremy L Thompson       }
1930d0321e0SJeremy L Thompson       impl->h_array_borrowed = NULL;
1940d0321e0SJeremy L Thompson       impl->h_array          = impl->h_array_owned;
195539ec17dSJeremy L Thompson       if (array) {
196539ec17dSJeremy L Thompson         CeedSize length;
1972b730f8bSJeremy L Thompson         CeedCallBackend(CeedVectorGetLength(vec, &length));
198539ec17dSJeremy L Thompson         size_t bytes = length * sizeof(CeedScalar);
199539ec17dSJeremy L Thompson         memcpy(impl->h_array, array, bytes);
200539ec17dSJeremy L Thompson       }
2010d0321e0SJeremy L Thompson     } break;
2020d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
2032b730f8bSJeremy L Thompson       CeedCallBackend(CeedFree(&impl->h_array_owned));
2040d0321e0SJeremy L Thompson       impl->h_array_owned    = array;
2050d0321e0SJeremy L Thompson       impl->h_array_borrowed = NULL;
2060d0321e0SJeremy L Thompson       impl->h_array          = array;
2070d0321e0SJeremy L Thompson       break;
2080d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
2092b730f8bSJeremy L Thompson       CeedCallBackend(CeedFree(&impl->h_array_owned));
2100d0321e0SJeremy L Thompson       impl->h_array_borrowed = array;
2110d0321e0SJeremy L Thompson       impl->h_array          = array;
2120d0321e0SJeremy L Thompson       break;
2130d0321e0SJeremy L Thompson   }
2140d0321e0SJeremy L Thompson 
2150d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2160d0321e0SJeremy L Thompson }
2170d0321e0SJeremy L Thompson 
2180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2190d0321e0SJeremy L Thompson // Set array from device
2200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2212b730f8bSJeremy L Thompson static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
2220d0321e0SJeremy L Thompson   Ceed ceed;
2232b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
2240d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
2252b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
2260d0321e0SJeremy L Thompson 
22743c928f4SJeremy L Thompson   switch (copy_mode) {
228539ec17dSJeremy L Thompson     case CEED_COPY_VALUES: {
229539ec17dSJeremy L Thompson       CeedSize length;
2302b730f8bSJeremy L Thompson       CeedCallBackend(CeedVectorGetLength(vec, &length));
231539ec17dSJeremy L Thompson       size_t bytes = length * sizeof(CeedScalar);
2320d0321e0SJeremy L Thompson       if (!impl->d_array_owned) {
2332b730f8bSJeremy L Thompson         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
2340d0321e0SJeremy L Thompson       }
2350d0321e0SJeremy L Thompson       impl->d_array_borrowed = NULL;
2360d0321e0SJeremy L Thompson       impl->d_array          = impl->d_array_owned;
2370d0321e0SJeremy L Thompson       if (array) {
2382b730f8bSJeremy L Thompson         CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
2390d0321e0SJeremy L Thompson       }
240539ec17dSJeremy L Thompson     } break;
2410d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
2422b730f8bSJeremy L Thompson       CeedCallHip(ceed, hipFree(impl->d_array_owned));
2430d0321e0SJeremy L Thompson       impl->d_array_owned    = array;
2440d0321e0SJeremy L Thompson       impl->d_array_borrowed = NULL;
2450d0321e0SJeremy L Thompson       impl->d_array          = array;
2460d0321e0SJeremy L Thompson       break;
2470d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
2482b730f8bSJeremy L Thompson       CeedCallHip(ceed, hipFree(impl->d_array_owned));
2490d0321e0SJeremy L Thompson       impl->d_array_owned    = NULL;
2500d0321e0SJeremy L Thompson       impl->d_array_borrowed = array;
2510d0321e0SJeremy L Thompson       impl->d_array          = array;
2520d0321e0SJeremy L Thompson       break;
2530d0321e0SJeremy L Thompson   }
2540d0321e0SJeremy L Thompson 
2550d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2560d0321e0SJeremy L Thompson }
2570d0321e0SJeremy L Thompson 
2580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2590d0321e0SJeremy L Thompson // Set the array used by a vector,
2600d0321e0SJeremy L Thompson //   freeing any previously allocated array if applicable
2610d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2622b730f8bSJeremy L Thompson static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
2630d0321e0SJeremy L Thompson   Ceed ceed;
2642b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
2650d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
2662b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
2670d0321e0SJeremy L Thompson 
2682b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
26943c928f4SJeremy L Thompson   switch (mem_type) {
2700d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
27143c928f4SJeremy L Thompson       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
2720d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
27343c928f4SJeremy L Thompson       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
2740d0321e0SJeremy L Thompson   }
2750d0321e0SJeremy L Thompson 
2760d0321e0SJeremy L Thompson   return CEED_ERROR_UNSUPPORTED;
2770d0321e0SJeremy L Thompson }
2780d0321e0SJeremy L Thompson 
2790d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2800d0321e0SJeremy L Thompson // Set host array to value
2810d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
282*9330daecSnbeams static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
283*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
2840d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2850d0321e0SJeremy L Thompson }
2860d0321e0SJeremy L Thompson 
2870d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2880d0321e0SJeremy L Thompson // Set device array to value (impl in .hip file)
2890d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
290*9330daecSnbeams int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
2910d0321e0SJeremy L Thompson 
2920d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2930d0321e0SJeremy L Thompson // Set a vector to a value,
2940d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2950d0321e0SJeremy L Thompson static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
2960d0321e0SJeremy L Thompson   Ceed ceed;
2972b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
2980d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
2992b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3001f9221feSJeremy L Thompson   CeedSize length;
3012b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
3020d0321e0SJeremy L Thompson 
3030d0321e0SJeremy L Thompson   // Set value for synced device/host array
3040d0321e0SJeremy L Thompson   if (!impl->d_array && !impl->h_array) {
3050d0321e0SJeremy L Thompson     if (impl->d_array_borrowed) {
3060d0321e0SJeremy L Thompson       impl->d_array = impl->d_array_borrowed;
3070d0321e0SJeremy L Thompson     } else if (impl->h_array_borrowed) {
3080d0321e0SJeremy L Thompson       impl->h_array = impl->h_array_borrowed;
3090d0321e0SJeremy L Thompson     } else if (impl->d_array_owned) {
3100d0321e0SJeremy L Thompson       impl->d_array = impl->d_array_owned;
3110d0321e0SJeremy L Thompson     } else if (impl->h_array_owned) {
3120d0321e0SJeremy L Thompson       impl->h_array = impl->h_array_owned;
3130d0321e0SJeremy L Thompson     } else {
3142b730f8bSJeremy L Thompson       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
3150d0321e0SJeremy L Thompson     }
3160d0321e0SJeremy L Thompson   }
3170d0321e0SJeremy L Thompson   if (impl->d_array) {
3182b730f8bSJeremy L Thompson     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
3190d0321e0SJeremy L Thompson   }
3200d0321e0SJeremy L Thompson   if (impl->h_array) {
3212b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
3220d0321e0SJeremy L Thompson   }
3230d0321e0SJeremy L Thompson 
3240d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3250d0321e0SJeremy L Thompson }
3260d0321e0SJeremy L Thompson 
3270d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3280d0321e0SJeremy L Thompson // Vector Take Array
3290d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3302b730f8bSJeremy L Thompson static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
3310d0321e0SJeremy L Thompson   Ceed ceed;
3322b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
3330d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
3342b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3350d0321e0SJeremy L Thompson 
33643c928f4SJeremy L Thompson   // Sync array to requested mem_type
3372b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
3380d0321e0SJeremy L Thompson 
3390d0321e0SJeremy L Thompson   // Update pointer
34043c928f4SJeremy L Thompson   switch (mem_type) {
3410d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
3420d0321e0SJeremy L Thompson       (*array)               = impl->h_array_borrowed;
3430d0321e0SJeremy L Thompson       impl->h_array_borrowed = NULL;
3440d0321e0SJeremy L Thompson       impl->h_array          = NULL;
3450d0321e0SJeremy L Thompson       break;
3460d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
3470d0321e0SJeremy L Thompson       (*array)               = impl->d_array_borrowed;
3480d0321e0SJeremy L Thompson       impl->d_array_borrowed = NULL;
3490d0321e0SJeremy L Thompson       impl->d_array          = NULL;
3500d0321e0SJeremy L Thompson       break;
3510d0321e0SJeremy L Thompson   }
3520d0321e0SJeremy L Thompson 
3530d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3540d0321e0SJeremy L Thompson }
3550d0321e0SJeremy L Thompson 
3560d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3570d0321e0SJeremy L Thompson // Core logic for array syncronization for GetArray.
3580d0321e0SJeremy L Thompson //   If a different memory type is most up to date, this will perform a copy
3590d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3602b730f8bSJeremy L Thompson static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
3610d0321e0SJeremy L Thompson   Ceed ceed;
3622b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
3630d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
3642b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3650d0321e0SJeremy L Thompson 
36643c928f4SJeremy L Thompson   // Sync array to requested mem_type
3672b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
3680d0321e0SJeremy L Thompson 
3690d0321e0SJeremy L Thompson   // Update pointer
37043c928f4SJeremy L Thompson   switch (mem_type) {
3710d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
3720d0321e0SJeremy L Thompson       *array = impl->h_array;
3730d0321e0SJeremy L Thompson       break;
3740d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
3750d0321e0SJeremy L Thompson       *array = impl->d_array;
3760d0321e0SJeremy L Thompson       break;
3770d0321e0SJeremy L Thompson   }
3780d0321e0SJeremy L Thompson 
3790d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3800d0321e0SJeremy L Thompson }
3810d0321e0SJeremy L Thompson 
3820d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
38343c928f4SJeremy L Thompson // Get read-only access to a vector via the specified mem_type
3840d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3852b730f8bSJeremy L Thompson static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
38643c928f4SJeremy L Thompson   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
3870d0321e0SJeremy L Thompson }
3880d0321e0SJeremy L Thompson 
3890d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
39043c928f4SJeremy L Thompson // Get read/write access to a vector via the specified mem_type
3910d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3922b730f8bSJeremy L Thompson static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
3930d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
3942b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3950d0321e0SJeremy L Thompson 
3962b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
3970d0321e0SJeremy L Thompson 
3982b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
39943c928f4SJeremy L Thompson   switch (mem_type) {
4000d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
4010d0321e0SJeremy L Thompson       impl->h_array = *array;
4020d0321e0SJeremy L Thompson       break;
4030d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
4040d0321e0SJeremy L Thompson       impl->d_array = *array;
4050d0321e0SJeremy L Thompson       break;
4060d0321e0SJeremy L Thompson   }
4070d0321e0SJeremy L Thompson 
4080d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
4090d0321e0SJeremy L Thompson }
4100d0321e0SJeremy L Thompson 
4110d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
41243c928f4SJeremy L Thompson // Get write access to a vector via the specified mem_type
4130d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4142b730f8bSJeremy L Thompson static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
4150d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
4162b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
4170d0321e0SJeremy L Thompson 
4180d0321e0SJeremy L Thompson   bool has_array_of_type = true;
4192b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
4200d0321e0SJeremy L Thompson   if (!has_array_of_type) {
4210d0321e0SJeremy L Thompson     // Allocate if array is not yet allocated
4222b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
4230d0321e0SJeremy L Thompson   } else {
4240d0321e0SJeremy L Thompson     // Select dirty array
42543c928f4SJeremy L Thompson     switch (mem_type) {
4260d0321e0SJeremy L Thompson       case CEED_MEM_HOST:
4272b730f8bSJeremy L Thompson         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
4282b730f8bSJeremy L Thompson         else impl->h_array = impl->h_array_owned;
4290d0321e0SJeremy L Thompson         break;
4300d0321e0SJeremy L Thompson       case CEED_MEM_DEVICE:
4312b730f8bSJeremy L Thompson         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
4322b730f8bSJeremy L Thompson         else impl->d_array = impl->d_array_owned;
4330d0321e0SJeremy L Thompson     }
4340d0321e0SJeremy L Thompson   }
4350d0321e0SJeremy L Thompson 
43643c928f4SJeremy L Thompson   return CeedVectorGetArray_Hip(vec, mem_type, array);
4370d0321e0SJeremy L Thompson }
4380d0321e0SJeremy L Thompson 
4390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4400d0321e0SJeremy L Thompson // Get the norm of a CeedVector
4410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4422b730f8bSJeremy L Thompson static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
4430d0321e0SJeremy L Thompson   Ceed ceed;
4442b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
4450d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
4462b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
4471f9221feSJeremy L Thompson   CeedSize length;
4482b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
4490d0321e0SJeremy L Thompson   hipblasHandle_t handle;
450eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
4510d0321e0SJeremy L Thompson 
452*9330daecSnbeams   // Is the vector too long to handle with int32? If so, we will divide
453*9330daecSnbeams   // it up into "int32-sized" subsections and make repeated BLAS calls.
454*9330daecSnbeams   CeedSize num_calls = length / INT_MAX;
455*9330daecSnbeams   if (length % INT_MAX > 0) num_calls += 1;
456*9330daecSnbeams 
4570d0321e0SJeremy L Thompson   // Compute norm
4580d0321e0SJeremy L Thompson   const CeedScalar *d_array;
4592b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
4600d0321e0SJeremy L Thompson   switch (type) {
4610d0321e0SJeremy L Thompson     case CEED_NORM_1: {
4620d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
463*9330daecSnbeams         if (num_calls <= 1) CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)length, (float *)d_array, 1, (float *)norm));
464*9330daecSnbeams         else {
465*9330daecSnbeams           float  sub_norm = 0.0;
466*9330daecSnbeams           float *d_array_start;
467*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
468*9330daecSnbeams             d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
469*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
470*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
471*9330daecSnbeams             CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
472*9330daecSnbeams             *norm += sub_norm;
473*9330daecSnbeams           }
474*9330daecSnbeams         }
4750d0321e0SJeremy L Thompson       } else {
476*9330daecSnbeams         if (num_calls <= 1) CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)length, (double *)d_array, 1, (double *)norm));
477*9330daecSnbeams         else {
478*9330daecSnbeams           double  sub_norm = 0.0;
479*9330daecSnbeams           double *d_array_start;
480*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
481*9330daecSnbeams             d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
482*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
483*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
484*9330daecSnbeams             CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
485*9330daecSnbeams             *norm += sub_norm;
486*9330daecSnbeams           }
487*9330daecSnbeams         }
4880d0321e0SJeremy L Thompson       }
4890d0321e0SJeremy L Thompson       break;
4900d0321e0SJeremy L Thompson     }
4910d0321e0SJeremy L Thompson     case CEED_NORM_2: {
4920d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
493*9330daecSnbeams         if (num_calls <= 1) CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)length, (float *)d_array, 1, (float *)norm));
494*9330daecSnbeams         else {
495*9330daecSnbeams           float  sub_norm = 0.0, norm_sum = 0.0;
496*9330daecSnbeams           float *d_array_start;
497*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
498*9330daecSnbeams             d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
499*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
500*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
501*9330daecSnbeams             CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
502*9330daecSnbeams             norm_sum += sub_norm * sub_norm;
503*9330daecSnbeams           }
504*9330daecSnbeams           *norm = sqrt(norm_sum);
505*9330daecSnbeams         }
5060d0321e0SJeremy L Thompson       } else {
507*9330daecSnbeams         if (num_calls <= 1) CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)length, (double *)d_array, 1, (double *)norm));
508*9330daecSnbeams         else {
509*9330daecSnbeams           double  sub_norm = 0.0, norm_sum = 0.0;
510*9330daecSnbeams           double *d_array_start;
511*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
512*9330daecSnbeams             d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
513*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
514*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
515*9330daecSnbeams             CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
516*9330daecSnbeams             norm_sum += sub_norm * sub_norm;
517*9330daecSnbeams           }
518*9330daecSnbeams           *norm = sqrt(norm_sum);
519*9330daecSnbeams         }
5200d0321e0SJeremy L Thompson       }
5210d0321e0SJeremy L Thompson       break;
5220d0321e0SJeremy L Thompson     }
5230d0321e0SJeremy L Thompson     case CEED_NORM_MAX: {
5240d0321e0SJeremy L Thompson       CeedInt indx;
5250d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
526*9330daecSnbeams         if (num_calls <= 1) {
527*9330daecSnbeams           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)length, (float *)d_array, 1, &indx));
5280d0321e0SJeremy L Thompson           CeedScalar normNoAbs;
5292b730f8bSJeremy L Thompson           CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
5300d0321e0SJeremy L Thompson           *norm = fabs(normNoAbs);
531*9330daecSnbeams         } else {
532*9330daecSnbeams           float  sub_max = 0.0, current_max = 0.0;
533*9330daecSnbeams           float *d_array_start;
534*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
535*9330daecSnbeams             d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
536*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
537*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
538*9330daecSnbeams             CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
539*9330daecSnbeams             CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
540*9330daecSnbeams             if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
541*9330daecSnbeams           }
542*9330daecSnbeams           *norm = current_max;
543*9330daecSnbeams         }
544*9330daecSnbeams       } else {
545*9330daecSnbeams         if (num_calls <= 1) {
546*9330daecSnbeams           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)length, (double *)d_array, 1, &indx));
547*9330daecSnbeams           CeedScalar normNoAbs;
548*9330daecSnbeams           CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
549*9330daecSnbeams           *norm = fabs(normNoAbs);
550*9330daecSnbeams         } else {
551*9330daecSnbeams           double  sub_max = 0.0, current_max = 0.0;
552*9330daecSnbeams           double *d_array_start;
553*9330daecSnbeams           for (CeedInt i = 0; i < num_calls; i++) {
554*9330daecSnbeams             d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
555*9330daecSnbeams             CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
556*9330daecSnbeams             CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
557*9330daecSnbeams             CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
558*9330daecSnbeams             CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
559*9330daecSnbeams             if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
560*9330daecSnbeams           }
561*9330daecSnbeams           *norm = current_max;
562*9330daecSnbeams         }
563*9330daecSnbeams       }
5640d0321e0SJeremy L Thompson       break;
5650d0321e0SJeremy L Thompson     }
5660d0321e0SJeremy L Thompson   }
5672b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
5680d0321e0SJeremy L Thompson 
5690d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5700d0321e0SJeremy L Thompson }
5710d0321e0SJeremy L Thompson 
5720d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5730d0321e0SJeremy L Thompson // Take reciprocal of a vector on host
5740d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
575*9330daecSnbeams static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
576*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) {
5772b730f8bSJeremy L Thompson     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
5782b730f8bSJeremy L Thompson   }
5790d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5800d0321e0SJeremy L Thompson }
5810d0321e0SJeremy L Thompson 
5820d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5830d0321e0SJeremy L Thompson // Take reciprocal of a vector on device (impl in .cu file)
5840d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
585*9330daecSnbeams int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
5860d0321e0SJeremy L Thompson 
5870d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5880d0321e0SJeremy L Thompson // Take reciprocal of a vector
5890d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5900d0321e0SJeremy L Thompson static int CeedVectorReciprocal_Hip(CeedVector vec) {
5910d0321e0SJeremy L Thompson   Ceed ceed;
5922b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
5930d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
5942b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
5951f9221feSJeremy L Thompson   CeedSize length;
5962b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
5970d0321e0SJeremy L Thompson 
5980d0321e0SJeremy L Thompson   // Set value for synced device/host array
5992b730f8bSJeremy L Thompson   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
6002b730f8bSJeremy L Thompson   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
6010d0321e0SJeremy L Thompson 
6020d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6030d0321e0SJeremy L Thompson }
6040d0321e0SJeremy L Thompson 
6050d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6060d0321e0SJeremy L Thompson // Compute x = alpha x on the host
6070d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
608*9330daecSnbeams static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
609*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
6100d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6110d0321e0SJeremy L Thompson }
6120d0321e0SJeremy L Thompson 
6130d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6140d0321e0SJeremy L Thompson // Compute x = alpha x on device (impl in .cu file)
6150d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
616*9330daecSnbeams int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
6170d0321e0SJeremy L Thompson 
6180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6190d0321e0SJeremy L Thompson // Compute x = alpha x
6200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6210d0321e0SJeremy L Thompson static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
6220d0321e0SJeremy L Thompson   Ceed ceed;
6232b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
6240d0321e0SJeremy L Thompson   CeedVector_Hip *x_impl;
6252b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
6261f9221feSJeremy L Thompson   CeedSize length;
6272b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(x, &length));
6280d0321e0SJeremy L Thompson 
6290d0321e0SJeremy L Thompson   // Set value for synced device/host array
6302b730f8bSJeremy L Thompson   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
6312b730f8bSJeremy L Thompson   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
6320d0321e0SJeremy L Thompson 
6330d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6340d0321e0SJeremy L Thompson }
6350d0321e0SJeremy L Thompson 
6360d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6370d0321e0SJeremy L Thompson // Compute y = alpha x + y on the host
6380d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
639*9330daecSnbeams static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
640*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
6410d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6420d0321e0SJeremy L Thompson }
6430d0321e0SJeremy L Thompson 
6440d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6450d0321e0SJeremy L Thompson // Compute y = alpha x + y on device (impl in .cu file)
6460d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
647*9330daecSnbeams int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
6480d0321e0SJeremy L Thompson 
6490d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6500d0321e0SJeremy L Thompson // Compute y = alpha x + y
6510d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6520d0321e0SJeremy L Thompson static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
6530d0321e0SJeremy L Thompson   Ceed ceed;
6542b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
6550d0321e0SJeremy L Thompson   CeedVector_Hip *y_impl, *x_impl;
6562b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(y, &y_impl));
6572b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
6581f9221feSJeremy L Thompson   CeedSize length;
6592b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(y, &length));
6600d0321e0SJeremy L Thompson 
6610d0321e0SJeremy L Thompson   // Set value for synced device/host array
6620d0321e0SJeremy L Thompson   if (y_impl->d_array) {
6632b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
6642b730f8bSJeremy L Thompson     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
6650d0321e0SJeremy L Thompson   }
6660d0321e0SJeremy L Thompson   if (y_impl->h_array) {
6672b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
6682b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
6690d0321e0SJeremy L Thompson   }
6700d0321e0SJeremy L Thompson 
6710d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6720d0321e0SJeremy L Thompson }
673ff1e7120SSebastian Grimberg 
6745fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6755fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on the host
6765fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
677*9330daecSnbeams static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
678*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
6795fb68f37SKaren (Ren) Stengel   return CEED_ERROR_SUCCESS;
6805fb68f37SKaren (Ren) Stengel }
6815fb68f37SKaren (Ren) Stengel 
6825fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6835fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on device (impl in .cu file)
6845fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
685*9330daecSnbeams int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
6865fb68f37SKaren (Ren) Stengel 
6875fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6885fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y
6895fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6905fb68f37SKaren (Ren) Stengel static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
6915fb68f37SKaren (Ren) Stengel   Ceed ceed;
6925fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
6935fb68f37SKaren (Ren) Stengel   CeedVector_Hip *y_impl, *x_impl;
6945fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetData(y, &y_impl));
6955fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetData(x, &x_impl));
6965fb68f37SKaren (Ren) Stengel   CeedSize length;
6975fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetLength(y, &length));
6985fb68f37SKaren (Ren) Stengel 
6995fb68f37SKaren (Ren) Stengel   // Set value for synced device/host array
7005fb68f37SKaren (Ren) Stengel   if (y_impl->d_array) {
7015fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
7025fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
7035fb68f37SKaren (Ren) Stengel   }
7045fb68f37SKaren (Ren) Stengel   if (y_impl->h_array) {
7055fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
7065fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
7075fb68f37SKaren (Ren) Stengel   }
7085fb68f37SKaren (Ren) Stengel 
7095fb68f37SKaren (Ren) Stengel   return CEED_ERROR_SUCCESS;
7105fb68f37SKaren (Ren) Stengel }
7110d0321e0SJeremy L Thompson 
7120d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7130d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on the host
7140d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
715*9330daecSnbeams static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
716*9330daecSnbeams   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
7170d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
7180d0321e0SJeremy L Thompson }
7190d0321e0SJeremy L Thompson 
7200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7210d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
7220d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
723*9330daecSnbeams int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
7240d0321e0SJeremy L Thompson 
7250d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7260d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y
7270d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7282b730f8bSJeremy L Thompson static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
7290d0321e0SJeremy L Thompson   Ceed ceed;
7302b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
7310d0321e0SJeremy L Thompson   CeedVector_Hip *w_impl, *x_impl, *y_impl;
7322b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(w, &w_impl));
7332b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
7342b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(y, &y_impl));
7351f9221feSJeremy L Thompson   CeedSize length;
7362b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(w, &length));
7370d0321e0SJeremy L Thompson 
7380d0321e0SJeremy L Thompson   // Set value for synced device/host array
7390d0321e0SJeremy L Thompson   if (!w_impl->d_array && !w_impl->h_array) {
7402b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSetValue(w, 0.0));
7410d0321e0SJeremy L Thompson   }
7420d0321e0SJeremy L Thompson   if (w_impl->d_array) {
7432b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
7442b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
7452b730f8bSJeremy L Thompson     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
7460d0321e0SJeremy L Thompson   }
7470d0321e0SJeremy L Thompson   if (w_impl->h_array) {
7482b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
7492b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
7502b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
7510d0321e0SJeremy L Thompson   }
7520d0321e0SJeremy L Thompson 
7530d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
7540d0321e0SJeremy L Thompson }
7550d0321e0SJeremy L Thompson 
7560d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7570d0321e0SJeremy L Thompson // Destroy the vector
7580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7590d0321e0SJeremy L Thompson static int CeedVectorDestroy_Hip(const CeedVector vec) {
7600d0321e0SJeremy L Thompson   Ceed ceed;
7612b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
7620d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
7632b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
7640d0321e0SJeremy L Thompson 
7652b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipFree(impl->d_array_owned));
7662b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_array_owned));
7672b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl));
7680d0321e0SJeremy L Thompson 
7690d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
7700d0321e0SJeremy L Thompson }
7710d0321e0SJeremy L Thompson 
7720d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7730d0321e0SJeremy L Thompson // Create a vector of the specified length (does not allocate memory)
7740d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
7751f9221feSJeremy L Thompson int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
7760d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
7770d0321e0SJeremy L Thompson   Ceed            ceed;
7782b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
7790d0321e0SJeremy L Thompson 
7802b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
7812b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
7822b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
7832b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
784008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
7852b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
7862b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
7872b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
7882b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
7892b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
7902b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
791008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
792008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
793008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
7942b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
7952b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
7960d0321e0SJeremy L Thompson 
7972b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(1, &impl));
7982b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetData(vec, impl));
7990d0321e0SJeremy L Thompson 
8000d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
8010d0321e0SJeremy L Thompson }
8022a86cc9dSSebastian Grimberg 
8032a86cc9dSSebastian Grimberg //------------------------------------------------------------------------------
804