xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
#include <ceed.h>
#include <ceed/backend.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <string.h>
#include <hip/hip_runtime.h>

#include "../hip/ceed-hip-common.h"
#include "ceed-hip-ref.h"
17f48ed27dSnbeams 
18f48ed27dSnbeams //------------------------------------------------------------------------------
19f48ed27dSnbeams // Check if host/device sync is needed
20f48ed27dSnbeams //------------------------------------------------------------------------------
212b730f8bSJeremy L Thompson static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22f48ed27dSnbeams   CeedVector_Hip *impl;
23f48ed27dSnbeams   bool            has_valid_array = false;
24b7453713SJeremy L Thompson 
25b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
262b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27f48ed27dSnbeams   switch (mem_type) {
28f48ed27dSnbeams     case CEED_MEM_HOST:
29f48ed27dSnbeams       *need_sync = has_valid_array && !impl->h_array;
30f48ed27dSnbeams       break;
31f48ed27dSnbeams     case CEED_MEM_DEVICE:
32f48ed27dSnbeams       *need_sync = has_valid_array && !impl->d_array;
33f48ed27dSnbeams       break;
34f48ed27dSnbeams   }
35f48ed27dSnbeams   return CEED_ERROR_SUCCESS;
36f48ed27dSnbeams }
37f48ed27dSnbeams 
380d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
390d0321e0SJeremy L Thompson // Sync host to device
400d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
410d0321e0SJeremy L Thompson static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
420d0321e0SJeremy L Thompson   Ceed            ceed;
43b7453713SJeremy L Thompson   CeedSize        length;
44672b0f2aSSebastian Grimberg   size_t          bytes;
450d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
46b7453713SJeremy L Thompson 
47b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
482b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
490d0321e0SJeremy L Thompson 
506574a04fSJeremy L Thompson   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
510d0321e0SJeremy L Thompson 
52672b0f2aSSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
53672b0f2aSSebastian Grimberg   bytes = length * sizeof(CeedScalar);
540d0321e0SJeremy L Thompson   if (impl->d_array_borrowed) {
550d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_borrowed;
560d0321e0SJeremy L Thompson   } else if (impl->d_array_owned) {
570d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_owned;
580d0321e0SJeremy L Thompson   } else {
592b730f8bSJeremy L Thompson     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
600d0321e0SJeremy L Thompson     impl->d_array = impl->d_array_owned;
610d0321e0SJeremy L Thompson   }
622b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
630d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
640d0321e0SJeremy L Thompson }
650d0321e0SJeremy L Thompson 
660d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
670d0321e0SJeremy L Thompson // Sync device to host
680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
690d0321e0SJeremy L Thompson static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
700d0321e0SJeremy L Thompson   Ceed            ceed;
71b7453713SJeremy L Thompson   CeedSize        length;
72672b0f2aSSebastian Grimberg   size_t          bytes;
730d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
74b7453713SJeremy L Thompson 
75b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
762b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
770d0321e0SJeremy L Thompson 
786574a04fSJeremy L Thompson   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
790d0321e0SJeremy L Thompson 
800d0321e0SJeremy L Thompson   if (impl->h_array_borrowed) {
810d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_borrowed;
820d0321e0SJeremy L Thompson   } else if (impl->h_array_owned) {
830d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_owned;
840d0321e0SJeremy L Thompson   } else {
851f9221feSJeremy L Thompson     CeedSize length;
86672b0f2aSSebastian Grimberg 
872b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetLength(vec, &length));
882b730f8bSJeremy L Thompson     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
890d0321e0SJeremy L Thompson     impl->h_array = impl->h_array_owned;
900d0321e0SJeremy L Thompson   }
910d0321e0SJeremy L Thompson 
922b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
93672b0f2aSSebastian Grimberg   bytes = length * sizeof(CeedScalar);
94b7453713SJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
950d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
960d0321e0SJeremy L Thompson }
970d0321e0SJeremy L Thompson 
980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
990d0321e0SJeremy L Thompson // Sync arrays
1000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1012b730f8bSJeremy L Thompson static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102f48ed27dSnbeams   bool need_sync = false;
103b7453713SJeremy L Thompson 
104b7453713SJeremy L Thompson   // Check whether device/host sync is needed
1052b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
1062b730f8bSJeremy L Thompson   if (!need_sync) return CEED_ERROR_SUCCESS;
107f48ed27dSnbeams 
10843c928f4SJeremy L Thompson   switch (mem_type) {
1092b730f8bSJeremy L Thompson     case CEED_MEM_HOST:
1102b730f8bSJeremy L Thompson       return CeedVectorSyncD2H_Hip(vec);
1112b730f8bSJeremy L Thompson     case CEED_MEM_DEVICE:
1122b730f8bSJeremy L Thompson       return CeedVectorSyncH2D_Hip(vec);
1130d0321e0SJeremy L Thompson   }
1140d0321e0SJeremy L Thompson   return CEED_ERROR_UNSUPPORTED;
1150d0321e0SJeremy L Thompson }
1160d0321e0SJeremy L Thompson 
1170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1180d0321e0SJeremy L Thompson // Set all pointers as invalid
1190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1200d0321e0SJeremy L Thompson static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
1210d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1220d0321e0SJeremy L Thompson 
123b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1240d0321e0SJeremy L Thompson   impl->h_array = NULL;
1250d0321e0SJeremy L Thompson   impl->d_array = NULL;
1260d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1270d0321e0SJeremy L Thompson }
1280d0321e0SJeremy L Thompson 
1290d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
130b2165e7aSSebastian Grimberg // Check if CeedVector has any valid pointer
1310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1322b730f8bSJeremy L Thompson static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
1330d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
134b7453713SJeremy L Thompson 
1352b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
1361c66c397SJeremy L Thompson   *has_valid_array = impl->h_array || impl->d_array;
1370d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1380d0321e0SJeremy L Thompson }
1390d0321e0SJeremy L Thompson 
1400d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
141b2165e7aSSebastian Grimberg // Check if has array of given type
1420d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1432b730f8bSJeremy L Thompson static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
1440d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1450d0321e0SJeremy L Thompson 
146b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
14743c928f4SJeremy L Thompson   switch (mem_type) {
1480d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
1491c66c397SJeremy L Thompson       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
1500d0321e0SJeremy L Thompson       break;
1510d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
1521c66c397SJeremy L Thompson       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
1530d0321e0SJeremy L Thompson       break;
1540d0321e0SJeremy L Thompson   }
1550d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1560d0321e0SJeremy L Thompson }
1570d0321e0SJeremy L Thompson 
1580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1590d0321e0SJeremy L Thompson // Check if has borrowed array of given type
1600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1612b730f8bSJeremy L Thompson static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
1620d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1630d0321e0SJeremy L Thompson 
164b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
16543c928f4SJeremy L Thompson   switch (mem_type) {
1660d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
1671c66c397SJeremy L Thompson       *has_borrowed_array_of_type = impl->h_array_borrowed;
1680d0321e0SJeremy L Thompson       break;
1690d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
1701c66c397SJeremy L Thompson       *has_borrowed_array_of_type = impl->d_array_borrowed;
1710d0321e0SJeremy L Thompson       break;
1720d0321e0SJeremy L Thompson   }
1730d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1740d0321e0SJeremy L Thompson }
1750d0321e0SJeremy L Thompson 
1760d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1770d0321e0SJeremy L Thompson // Set array from host
1780d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1792b730f8bSJeremy L Thompson static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180a267acd1SJeremy L Thompson   CeedSize        length;
1810d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1820d0321e0SJeremy L Thompson 
183b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
184a267acd1SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
185a267acd1SJeremy L Thompson 
186f5d1e504SJeremy L Thompson   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
187f5d1e504SJeremy L Thompson                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
1880d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1890d0321e0SJeremy L Thompson }
1900d0321e0SJeremy L Thompson 
1910d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1920d0321e0SJeremy L Thompson // Set array from device
1930d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1942b730f8bSJeremy L Thompson static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
195a267acd1SJeremy L Thompson   CeedSize        length;
1960d0321e0SJeremy L Thompson   Ceed            ceed;
1970d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
1980d0321e0SJeremy L Thompson 
199b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
200b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
201a267acd1SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
202f5d1e504SJeremy L Thompson 
203f5d1e504SJeremy L Thompson   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
204f5d1e504SJeremy L Thompson                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
2050d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2060d0321e0SJeremy L Thompson }
2070d0321e0SJeremy L Thompson 
2080d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2090d0321e0SJeremy L Thompson // Set the array used by a vector,
2100d0321e0SJeremy L Thompson //   freeing any previously allocated array if applicable
2110d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2122b730f8bSJeremy L Thompson static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
2130d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
2140d0321e0SJeremy L Thompson 
215b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
2162b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
21743c928f4SJeremy L Thompson   switch (mem_type) {
2180d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
21943c928f4SJeremy L Thompson       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
2200d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
22143c928f4SJeremy L Thompson       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
2220d0321e0SJeremy L Thompson   }
2230d0321e0SJeremy L Thompson   return CEED_ERROR_UNSUPPORTED;
2240d0321e0SJeremy L Thompson }
2250d0321e0SJeremy L Thompson 
2260d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2270d0321e0SJeremy L Thompson // Set host array to value
2280d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2299330daecSnbeams static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
2309330daecSnbeams   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
2310d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2320d0321e0SJeremy L Thompson }
2330d0321e0SJeremy L Thompson 
2340d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2350d0321e0SJeremy L Thompson // Set device array to value (impl in .hip file)
2360d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2379330daecSnbeams int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
2380d0321e0SJeremy L Thompson 
2390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
240b2165e7aSSebastian Grimberg // Set a vector to a value
2410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2420d0321e0SJeremy L Thompson static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
2431f9221feSJeremy L Thompson   CeedSize        length;
244b7453713SJeremy L Thompson   CeedVector_Hip *impl;
2450d0321e0SJeremy L Thompson 
246b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
247b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
2480d0321e0SJeremy L Thompson   // Set value for synced device/host array
2490d0321e0SJeremy L Thompson   if (!impl->d_array && !impl->h_array) {
2500d0321e0SJeremy L Thompson     if (impl->d_array_borrowed) {
2510d0321e0SJeremy L Thompson       impl->d_array = impl->d_array_borrowed;
2520d0321e0SJeremy L Thompson     } else if (impl->h_array_borrowed) {
2530d0321e0SJeremy L Thompson       impl->h_array = impl->h_array_borrowed;
2540d0321e0SJeremy L Thompson     } else if (impl->d_array_owned) {
2550d0321e0SJeremy L Thompson       impl->d_array = impl->d_array_owned;
2560d0321e0SJeremy L Thompson     } else if (impl->h_array_owned) {
2570d0321e0SJeremy L Thompson       impl->h_array = impl->h_array_owned;
2580d0321e0SJeremy L Thompson     } else {
2592b730f8bSJeremy L Thompson       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
2600d0321e0SJeremy L Thompson     }
2610d0321e0SJeremy L Thompson   }
2620d0321e0SJeremy L Thompson   if (impl->d_array) {
2632b730f8bSJeremy L Thompson     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
264b2165e7aSSebastian Grimberg     impl->h_array = NULL;
2650d0321e0SJeremy L Thompson   }
2660d0321e0SJeremy L Thompson   if (impl->h_array) {
2672b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
268b2165e7aSSebastian Grimberg     impl->d_array = NULL;
2690d0321e0SJeremy L Thompson   }
2700d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2710d0321e0SJeremy L Thompson }
2720d0321e0SJeremy L Thompson 
2730d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2740d0321e0SJeremy L Thompson // Vector Take Array
2750d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2762b730f8bSJeremy L Thompson static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
2770d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
278b7453713SJeremy L Thompson 
2792b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
2800d0321e0SJeremy L Thompson 
28143c928f4SJeremy L Thompson   // Sync array to requested mem_type
2822b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
2830d0321e0SJeremy L Thompson 
2840d0321e0SJeremy L Thompson   // Update pointer
28543c928f4SJeremy L Thompson   switch (mem_type) {
2860d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
2870d0321e0SJeremy L Thompson       (*array)               = impl->h_array_borrowed;
2880d0321e0SJeremy L Thompson       impl->h_array_borrowed = NULL;
2890d0321e0SJeremy L Thompson       impl->h_array          = NULL;
2900d0321e0SJeremy L Thompson       break;
2910d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
2920d0321e0SJeremy L Thompson       (*array)               = impl->d_array_borrowed;
2930d0321e0SJeremy L Thompson       impl->d_array_borrowed = NULL;
2940d0321e0SJeremy L Thompson       impl->d_array          = NULL;
2950d0321e0SJeremy L Thompson       break;
2960d0321e0SJeremy L Thompson   }
2970d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2980d0321e0SJeremy L Thompson }
2990d0321e0SJeremy L Thompson 
3000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3010d0321e0SJeremy L Thompson // Core logic for array syncronization for GetArray.
3020d0321e0SJeremy L Thompson //   If a different memory type is most up to date, this will perform a copy
3030d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3042b730f8bSJeremy L Thompson static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
3050d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
306b7453713SJeremy L Thompson 
3072b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3080d0321e0SJeremy L Thompson 
30943c928f4SJeremy L Thompson   // Sync array to requested mem_type
3102b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
3110d0321e0SJeremy L Thompson 
3120d0321e0SJeremy L Thompson   // Update pointer
31343c928f4SJeremy L Thompson   switch (mem_type) {
3140d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
3150d0321e0SJeremy L Thompson       *array = impl->h_array;
3160d0321e0SJeremy L Thompson       break;
3170d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
3180d0321e0SJeremy L Thompson       *array = impl->d_array;
3190d0321e0SJeremy L Thompson       break;
3200d0321e0SJeremy L Thompson   }
3210d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3220d0321e0SJeremy L Thompson }
3230d0321e0SJeremy L Thompson 
3240d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
32543c928f4SJeremy L Thompson // Get read-only access to a vector via the specified mem_type
3260d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3272b730f8bSJeremy L Thompson static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
32843c928f4SJeremy L Thompson   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
3290d0321e0SJeremy L Thompson }
3300d0321e0SJeremy L Thompson 
3310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
33243c928f4SJeremy L Thompson // Get read/write access to a vector via the specified mem_type
3330d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3342b730f8bSJeremy L Thompson static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
3350d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
336b7453713SJeremy L Thompson 
3372b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3382b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
3392b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
34043c928f4SJeremy L Thompson   switch (mem_type) {
3410d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
3420d0321e0SJeremy L Thompson       impl->h_array = *array;
3430d0321e0SJeremy L Thompson       break;
3440d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
3450d0321e0SJeremy L Thompson       impl->d_array = *array;
3460d0321e0SJeremy L Thompson       break;
3470d0321e0SJeremy L Thompson   }
3480d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3490d0321e0SJeremy L Thompson }
3500d0321e0SJeremy L Thompson 
3510d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
35243c928f4SJeremy L Thompson // Get write access to a vector via the specified mem_type
3530d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3542b730f8bSJeremy L Thompson static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
3550d0321e0SJeremy L Thompson   bool            has_array_of_type = true;
356b7453713SJeremy L Thompson   CeedVector_Hip *impl;
357b7453713SJeremy L Thompson 
358b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
3592b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
3600d0321e0SJeremy L Thompson   if (!has_array_of_type) {
3610d0321e0SJeremy L Thompson     // Allocate if array is not yet allocated
3622b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
3630d0321e0SJeremy L Thompson   } else {
3640d0321e0SJeremy L Thompson     // Select dirty array
36543c928f4SJeremy L Thompson     switch (mem_type) {
3660d0321e0SJeremy L Thompson       case CEED_MEM_HOST:
3672b730f8bSJeremy L Thompson         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
3682b730f8bSJeremy L Thompson         else impl->h_array = impl->h_array_owned;
3690d0321e0SJeremy L Thompson         break;
3700d0321e0SJeremy L Thompson       case CEED_MEM_DEVICE:
3712b730f8bSJeremy L Thompson         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
3722b730f8bSJeremy L Thompson         else impl->d_array = impl->d_array_owned;
3730d0321e0SJeremy L Thompson     }
3740d0321e0SJeremy L Thompson   }
37543c928f4SJeremy L Thompson   return CeedVectorGetArray_Hip(vec, mem_type, array);
3760d0321e0SJeremy L Thompson }
3770d0321e0SJeremy L Thompson 
3780d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3790d0321e0SJeremy L Thompson // Get the norm of a CeedVector
3800d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3812b730f8bSJeremy L Thompson static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
3820d0321e0SJeremy L Thompson   Ceed              ceed;
383672b0f2aSSebastian Grimberg   CeedSize          length, num_calls;
384b7453713SJeremy L Thompson   const CeedScalar *d_array;
385b7453713SJeremy L Thompson   CeedVector_Hip   *impl;
3860d0321e0SJeremy L Thompson   hipblasHandle_t   handle;
387b7453713SJeremy L Thompson 
388b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
389b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
390b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
391eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
3920d0321e0SJeremy L Thompson 
3939330daecSnbeams   // Is the vector too long to handle with int32? If so, we will divide
3949330daecSnbeams   // it up into "int32-sized" subsections and make repeated BLAS calls.
395672b0f2aSSebastian Grimberg   num_calls = length / INT_MAX;
3969330daecSnbeams   if (length % INT_MAX > 0) num_calls += 1;
3979330daecSnbeams 
3980d0321e0SJeremy L Thompson   // Compute norm
3992b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
4000d0321e0SJeremy L Thompson   switch (type) {
4010d0321e0SJeremy L Thompson     case CEED_NORM_1: {
402f6f49adbSnbeams       *norm = 0.0;
4030d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
4049330daecSnbeams         float  sub_norm = 0.0;
4059330daecSnbeams         float *d_array_start;
406b7453713SJeremy L Thompson 
4079330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4089330daecSnbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
4099330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4109330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
411b7453713SJeremy L Thompson 
4129330daecSnbeams           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
4139330daecSnbeams           *norm += sub_norm;
4149330daecSnbeams         }
4150d0321e0SJeremy L Thompson       } else {
4169330daecSnbeams         double  sub_norm = 0.0;
4179330daecSnbeams         double *d_array_start;
418b7453713SJeremy L Thompson 
4199330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4209330daecSnbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
4219330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4229330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
423b7453713SJeremy L Thompson 
4249330daecSnbeams           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
4259330daecSnbeams           *norm += sub_norm;
4269330daecSnbeams         }
4279330daecSnbeams       }
4280d0321e0SJeremy L Thompson       break;
4290d0321e0SJeremy L Thompson     }
4300d0321e0SJeremy L Thompson     case CEED_NORM_2: {
4310d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
4329330daecSnbeams         float  sub_norm = 0.0, norm_sum = 0.0;
4339330daecSnbeams         float *d_array_start;
434b7453713SJeremy L Thompson 
4359330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4369330daecSnbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
4379330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4389330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
439b7453713SJeremy L Thompson 
4409330daecSnbeams           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
4419330daecSnbeams           norm_sum += sub_norm * sub_norm;
4429330daecSnbeams         }
4439330daecSnbeams         *norm = sqrt(norm_sum);
4440d0321e0SJeremy L Thompson       } else {
4459330daecSnbeams         double  sub_norm = 0.0, norm_sum = 0.0;
4469330daecSnbeams         double *d_array_start;
447b7453713SJeremy L Thompson 
4489330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4499330daecSnbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
4509330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4519330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
452b7453713SJeremy L Thompson 
4539330daecSnbeams           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
4549330daecSnbeams           norm_sum += sub_norm * sub_norm;
4559330daecSnbeams         }
4569330daecSnbeams         *norm = sqrt(norm_sum);
4579330daecSnbeams       }
4580d0321e0SJeremy L Thompson       break;
4590d0321e0SJeremy L Thompson     }
4600d0321e0SJeremy L Thompson     case CEED_NORM_MAX: {
461b7453713SJeremy L Thompson       CeedInt index;
462b7453713SJeremy L Thompson 
4630d0321e0SJeremy L Thompson       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
4649330daecSnbeams         float  sub_max = 0.0, current_max = 0.0;
4659330daecSnbeams         float *d_array_start;
4669330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4679330daecSnbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
4689330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4699330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
470b7453713SJeremy L Thompson 
471b7453713SJeremy L Thompson           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
472b7453713SJeremy L Thompson           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
4739330daecSnbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
4749330daecSnbeams         }
4759330daecSnbeams         *norm = current_max;
4769330daecSnbeams       } else {
4779330daecSnbeams         double  sub_max = 0.0, current_max = 0.0;
4789330daecSnbeams         double *d_array_start;
479b7453713SJeremy L Thompson 
4809330daecSnbeams         for (CeedInt i = 0; i < num_calls; i++) {
4819330daecSnbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
4829330daecSnbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
4839330daecSnbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
484b7453713SJeremy L Thompson 
485b7453713SJeremy L Thompson           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
486b7453713SJeremy L Thompson           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
4879330daecSnbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
4889330daecSnbeams         }
4899330daecSnbeams         *norm = current_max;
4909330daecSnbeams       }
4910d0321e0SJeremy L Thompson       break;
4920d0321e0SJeremy L Thompson     }
4930d0321e0SJeremy L Thompson   }
4942b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
4950d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
4960d0321e0SJeremy L Thompson }
4970d0321e0SJeremy L Thompson 
4980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4990d0321e0SJeremy L Thompson // Take reciprocal of a vector on host
5000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5019330daecSnbeams static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
5029330daecSnbeams   for (CeedSize i = 0; i < length; i++) {
5032b730f8bSJeremy L Thompson     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
5042b730f8bSJeremy L Thompson   }
5050d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5060d0321e0SJeremy L Thompson }
5070d0321e0SJeremy L Thompson 
5080d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5090d0321e0SJeremy L Thompson // Take reciprocal of a vector on device (impl in .cu file)
5100d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5119330daecSnbeams int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
5120d0321e0SJeremy L Thompson 
5130d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5140d0321e0SJeremy L Thompson // Take reciprocal of a vector
5150d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5160d0321e0SJeremy L Thompson static int CeedVectorReciprocal_Hip(CeedVector vec) {
5171f9221feSJeremy L Thompson   CeedSize        length;
518b7453713SJeremy L Thompson   CeedVector_Hip *impl;
5190d0321e0SJeremy L Thompson 
520b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
521b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(vec, &length));
5220d0321e0SJeremy L Thompson   // Set value for synced device/host array
5232b730f8bSJeremy L Thompson   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
5242b730f8bSJeremy L Thompson   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
5250d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5260d0321e0SJeremy L Thompson }
5270d0321e0SJeremy L Thompson 
5280d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5290d0321e0SJeremy L Thompson // Compute x = alpha x on the host
5300d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5319330daecSnbeams static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
5329330daecSnbeams   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
5330d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5340d0321e0SJeremy L Thompson }
5350d0321e0SJeremy L Thompson 
5360d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5370d0321e0SJeremy L Thompson // Compute x = alpha x on device (impl in .cu file)
5380d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5399330daecSnbeams int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
5400d0321e0SJeremy L Thompson 
5410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5420d0321e0SJeremy L Thompson // Compute x = alpha x
5430d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5440d0321e0SJeremy L Thompson static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
5451f9221feSJeremy L Thompson   CeedSize        length;
546b7453713SJeremy L Thompson   CeedVector_Hip *x_impl;
5470d0321e0SJeremy L Thompson 
548b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
549b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(x, &length));
5500d0321e0SJeremy L Thompson   // Set value for synced device/host array
5512b730f8bSJeremy L Thompson   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
5522b730f8bSJeremy L Thompson   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
5530d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5540d0321e0SJeremy L Thompson }
5550d0321e0SJeremy L Thompson 
5560d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5570d0321e0SJeremy L Thompson // Compute y = alpha x + y on the host
5580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5599330daecSnbeams static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
5609330daecSnbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
5610d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5620d0321e0SJeremy L Thompson }
5630d0321e0SJeremy L Thompson 
5640d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5650d0321e0SJeremy L Thompson // Compute y = alpha x + y on device (impl in .cu file)
5660d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5679330daecSnbeams int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
5680d0321e0SJeremy L Thompson 
5690d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5700d0321e0SJeremy L Thompson // Compute y = alpha x + y
5710d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5720d0321e0SJeremy L Thompson static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
573b7453713SJeremy L Thompson   CeedSize        length;
5740d0321e0SJeremy L Thompson   CeedVector_Hip *y_impl, *x_impl;
575b7453713SJeremy L Thompson 
5762b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(y, &y_impl));
5772b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
5782b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(y, &length));
5790d0321e0SJeremy L Thompson   // Set value for synced device/host array
5800d0321e0SJeremy L Thompson   if (y_impl->d_array) {
5812b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
5822b730f8bSJeremy L Thompson     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
5830d0321e0SJeremy L Thompson   }
5840d0321e0SJeremy L Thompson   if (y_impl->h_array) {
5852b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
5862b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
5870d0321e0SJeremy L Thompson   }
5880d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
5890d0321e0SJeremy L Thompson }
590ff1e7120SSebastian Grimberg 
5915fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
5925fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on the host
5935fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
5949330daecSnbeams static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
595aa67b842SZach Atkins   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
5965fb68f37SKaren (Ren) Stengel   return CEED_ERROR_SUCCESS;
5975fb68f37SKaren (Ren) Stengel }
5985fb68f37SKaren (Ren) Stengel 
5995fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6005fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on device (impl in .cu file)
6015fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6029330daecSnbeams int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
6035fb68f37SKaren (Ren) Stengel 
6045fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6055fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y
6065fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------
6075fb68f37SKaren (Ren) Stengel static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
608b7453713SJeremy L Thompson   CeedSize        length;
6095fb68f37SKaren (Ren) Stengel   CeedVector_Hip *y_impl, *x_impl;
610b7453713SJeremy L Thompson 
6115fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetData(y, &y_impl));
6125fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetData(x, &x_impl));
6135fb68f37SKaren (Ren) Stengel   CeedCallBackend(CeedVectorGetLength(y, &length));
6145fb68f37SKaren (Ren) Stengel   // Set value for synced device/host array
6155fb68f37SKaren (Ren) Stengel   if (y_impl->d_array) {
6165fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
6175fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
6185fb68f37SKaren (Ren) Stengel   }
6195fb68f37SKaren (Ren) Stengel   if (y_impl->h_array) {
6205fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
6215fb68f37SKaren (Ren) Stengel     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
6225fb68f37SKaren (Ren) Stengel   }
6235fb68f37SKaren (Ren) Stengel   return CEED_ERROR_SUCCESS;
6245fb68f37SKaren (Ren) Stengel }
6250d0321e0SJeremy L Thompson 
6260d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6270d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on the host
6280d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6299330daecSnbeams static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
6309330daecSnbeams   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
6310d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6320d0321e0SJeremy L Thompson }
6330d0321e0SJeremy L Thompson 
6340d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6350d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
6360d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6379330daecSnbeams int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
6380d0321e0SJeremy L Thompson 
6390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6400d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y
6410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6422b730f8bSJeremy L Thompson static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
643b7453713SJeremy L Thompson   CeedSize        length;
6440d0321e0SJeremy L Thompson   CeedVector_Hip *w_impl, *x_impl, *y_impl;
645b7453713SJeremy L Thompson 
6462b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(w, &w_impl));
6472b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(x, &x_impl));
6482b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetData(y, &y_impl));
6492b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetLength(w, &length));
6500d0321e0SJeremy L Thompson 
6510d0321e0SJeremy L Thompson   // Set value for synced device/host array
6520d0321e0SJeremy L Thompson   if (!w_impl->d_array && !w_impl->h_array) {
6532b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSetValue(w, 0.0));
6540d0321e0SJeremy L Thompson   }
6550d0321e0SJeremy L Thompson   if (w_impl->d_array) {
6562b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
6572b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
6582b730f8bSJeremy L Thompson     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
6590d0321e0SJeremy L Thompson   }
6600d0321e0SJeremy L Thompson   if (w_impl->h_array) {
6612b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
6622b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
6632b730f8bSJeremy L Thompson     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
6640d0321e0SJeremy L Thompson   }
6650d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6660d0321e0SJeremy L Thompson }
6670d0321e0SJeremy L Thompson 
6680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6690d0321e0SJeremy L Thompson // Destroy the vector
6700d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6710d0321e0SJeremy L Thompson static int CeedVectorDestroy_Hip(const CeedVector vec) {
6720d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
6730d0321e0SJeremy L Thompson 
674b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetData(vec, &impl));
6756e536b99SJeremy L Thompson   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
6762b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_array_owned));
6772b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl));
6780d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6790d0321e0SJeremy L Thompson }
6800d0321e0SJeremy L Thompson 
6810d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6820d0321e0SJeremy L Thompson // Create a vector of the specified length (does not allocate memory)
6830d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
6841f9221feSJeremy L Thompson int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
6850d0321e0SJeremy L Thompson   CeedVector_Hip *impl;
6860d0321e0SJeremy L Thompson   Ceed            ceed;
6870d0321e0SJeremy L Thompson 
688b7453713SJeremy L Thompson   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
6892b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
6902b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
6912b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
6922b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
693008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
6942b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
6952b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
6962b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
6972b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
6982b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
6992b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
700008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
701008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
702008736bdSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
7032b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
7042b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
7052b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(1, &impl));
7062b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorSetData(vec, impl));
7070d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
7080d0321e0SJeremy L Thompson }
7092a86cc9dSSebastian Grimberg 
7102a86cc9dSSebastian Grimberg //------------------------------------------------------------------------------
711