13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 30d0321e0SJeremy L Thompson // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 50d0321e0SJeremy L Thompson // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 70d0321e0SJeremy L Thompson 849aac155SJeremy L Thompson #include <ceed.h> 90d0321e0SJeremy L Thompson #include <ceed/backend.h> 100d0321e0SJeremy L Thompson #include <math.h> 1149aac155SJeremy L Thompson #include <stdbool.h> 120d0321e0SJeremy L Thompson #include <string.h> 13c85e8640SSebastian Grimberg #include <hip/hip_runtime.h> 140d0321e0SJeremy L Thompson 1549aac155SJeremy L Thompson #include "../hip/ceed-hip-common.h" 162b730f8bSJeremy L Thompson #include "ceed-hip-ref.h" 17f48ed27dSnbeams 18f48ed27dSnbeams //------------------------------------------------------------------------------ 19f48ed27dSnbeams // Check if host/device sync is needed 20f48ed27dSnbeams //------------------------------------------------------------------------------ 212b730f8bSJeremy L Thompson static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 22f48ed27dSnbeams CeedVector_Hip *impl; 23f48ed27dSnbeams bool has_valid_array = false; 24b7453713SJeremy L Thompson 25b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 262b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 27f48ed27dSnbeams switch (mem_type) { 28f48ed27dSnbeams case CEED_MEM_HOST: 29f48ed27dSnbeams *need_sync = has_valid_array && !impl->h_array; 30f48ed27dSnbeams break; 31f48ed27dSnbeams case CEED_MEM_DEVICE: 32f48ed27dSnbeams *need_sync = has_valid_array && !impl->d_array; 33f48ed27dSnbeams break; 34f48ed27dSnbeams } 35f48ed27dSnbeams return CEED_ERROR_SUCCESS; 36f48ed27dSnbeams } 37f48ed27dSnbeams 380d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 390d0321e0SJeremy L Thompson // Sync host to device 400d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 410d0321e0SJeremy L Thompson static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) { 420d0321e0SJeremy L Thompson Ceed ceed; 43b7453713SJeremy L Thompson CeedSize length; 44672b0f2aSSebastian Grimberg size_t bytes; 450d0321e0SJeremy L Thompson CeedVector_Hip *impl; 46b7453713SJeremy L Thompson 47b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 482b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 490d0321e0SJeremy L Thompson 506574a04fSJeremy L Thompson CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 510d0321e0SJeremy L Thompson 52672b0f2aSSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 53672b0f2aSSebastian Grimberg bytes = length * sizeof(CeedScalar); 540d0321e0SJeremy L Thompson if (impl->d_array_borrowed) { 550d0321e0SJeremy L Thompson impl->d_array = impl->d_array_borrowed; 560d0321e0SJeremy L Thompson } else if (impl->d_array_owned) { 570d0321e0SJeremy L Thompson impl->d_array = impl->d_array_owned; 580d0321e0SJeremy L Thompson } else { 592b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes)); 600d0321e0SJeremy L Thompson impl->d_array = impl->d_array_owned; 610d0321e0SJeremy L Thompson } 622b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice)); 630d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 640d0321e0SJeremy L Thompson } 650d0321e0SJeremy L Thompson 660d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 670d0321e0SJeremy L Thompson // Sync device to host 680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 690d0321e0SJeremy L Thompson static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { 700d0321e0SJeremy L Thompson Ceed ceed; 71b7453713SJeremy L Thompson CeedSize length; 72672b0f2aSSebastian Grimberg size_t bytes; 730d0321e0SJeremy L Thompson CeedVector_Hip *impl; 74b7453713SJeremy L Thompson 75b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 762b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 770d0321e0SJeremy L Thompson 786574a04fSJeremy L Thompson CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 790d0321e0SJeremy L Thompson 800d0321e0SJeremy L Thompson if (impl->h_array_borrowed) { 810d0321e0SJeremy L Thompson impl->h_array = impl->h_array_borrowed; 820d0321e0SJeremy L Thompson } else if (impl->h_array_owned) { 830d0321e0SJeremy L Thompson impl->h_array = impl->h_array_owned; 840d0321e0SJeremy L Thompson } else { 851f9221feSJeremy L Thompson CeedSize length; 86672b0f2aSSebastian Grimberg 872b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 882b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 890d0321e0SJeremy L Thompson impl->h_array = impl->h_array_owned; 900d0321e0SJeremy L Thompson } 910d0321e0SJeremy L Thompson 922b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 93672b0f2aSSebastian Grimberg bytes = length * sizeof(CeedScalar); 94b7453713SJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost)); 950d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 960d0321e0SJeremy L Thompson } 970d0321e0SJeremy L Thompson 980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 990d0321e0SJeremy L Thompson // Sync arrays 1000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1012b730f8bSJeremy L Thompson static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) { 102f48ed27dSnbeams bool need_sync = false; 103b7453713SJeremy L Thompson 104b7453713SJeremy L Thompson // Check whether device/host sync is needed 1052b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync)); 1062b730f8bSJeremy L Thompson if (!need_sync) return CEED_ERROR_SUCCESS; 107f48ed27dSnbeams 10843c928f4SJeremy L Thompson switch (mem_type) { 1092b730f8bSJeremy L Thompson case CEED_MEM_HOST: 1102b730f8bSJeremy L Thompson return CeedVectorSyncD2H_Hip(vec); 1112b730f8bSJeremy L Thompson case CEED_MEM_DEVICE: 1122b730f8bSJeremy L Thompson return CeedVectorSyncH2D_Hip(vec); 1130d0321e0SJeremy L Thompson } 1140d0321e0SJeremy L Thompson return CEED_ERROR_UNSUPPORTED; 1150d0321e0SJeremy L Thompson } 1160d0321e0SJeremy L Thompson 1170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1180d0321e0SJeremy L Thompson // Set all pointers as invalid 1190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1200d0321e0SJeremy L Thompson static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) { 1210d0321e0SJeremy L Thompson CeedVector_Hip *impl; 1220d0321e0SJeremy L Thompson 123b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 1240d0321e0SJeremy L Thompson impl->h_array = NULL; 1250d0321e0SJeremy L Thompson impl->d_array = NULL; 1260d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1270d0321e0SJeremy L Thompson } 1280d0321e0SJeremy L Thompson 1290d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 130b2165e7aSSebastian Grimberg // Check if CeedVector has any valid pointer 1310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1322b730f8bSJeremy L Thompson static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) { 1330d0321e0SJeremy L Thompson CeedVector_Hip *impl; 134b7453713SJeremy L Thompson 1352b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 1361c66c397SJeremy L Thompson *has_valid_array = impl->h_array || impl->d_array; 1370d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1380d0321e0SJeremy L Thompson } 1390d0321e0SJeremy L Thompson 1400d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 141b2165e7aSSebastian Grimberg // Check if has array of given type 1420d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1432b730f8bSJeremy L Thompson static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 1440d0321e0SJeremy L Thompson CeedVector_Hip *impl; 1450d0321e0SJeremy L Thompson 146b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 14743c928f4SJeremy L Thompson switch (mem_type) { 1480d0321e0SJeremy L Thompson case CEED_MEM_HOST: 1491c66c397SJeremy L Thompson *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned; 1500d0321e0SJeremy L Thompson break; 1510d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 1521c66c397SJeremy L Thompson *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned; 1530d0321e0SJeremy L Thompson break; 1540d0321e0SJeremy L Thompson } 1550d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1560d0321e0SJeremy L Thompson } 1570d0321e0SJeremy L Thompson 1580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1590d0321e0SJeremy L Thompson // Check if has borrowed array of given type 1600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1612b730f8bSJeremy L Thompson static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 1620d0321e0SJeremy L Thompson CeedVector_Hip *impl; 1630d0321e0SJeremy L Thompson 164b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 16543c928f4SJeremy L Thompson switch (mem_type) { 1660d0321e0SJeremy L Thompson case CEED_MEM_HOST: 1671c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->h_array_borrowed; 1680d0321e0SJeremy L Thompson break; 1690d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 1701c66c397SJeremy L Thompson *has_borrowed_array_of_type = impl->d_array_borrowed; 1710d0321e0SJeremy L Thompson break; 1720d0321e0SJeremy L Thompson } 1730d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1740d0321e0SJeremy L Thompson } 1750d0321e0SJeremy L Thompson 1760d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1770d0321e0SJeremy L Thompson // Set array from host 1780d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1792b730f8bSJeremy L Thompson static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 180*a267acd1SJeremy L Thompson CeedSize length; 1810d0321e0SJeremy L Thompson CeedVector_Hip *impl; 1820d0321e0SJeremy L Thompson 183b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 184*a267acd1SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 185*a267acd1SJeremy L Thompson 18643c928f4SJeremy L Thompson switch (copy_mode) { 1870d0321e0SJeremy L Thompson case CEED_COPY_VALUES: { 188*a267acd1SJeremy L Thompson if (!impl->h_array_owned) CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 189*a267acd1SJeremy L Thompson if (array) memcpy(impl->h_array_owned, array, length * sizeof(array[0])); 1900d0321e0SJeremy L Thompson impl->h_array_borrowed = NULL; 1910d0321e0SJeremy L Thompson impl->h_array = impl->h_array_owned; 1920d0321e0SJeremy L Thompson } break; 1930d0321e0SJeremy L Thompson case CEED_OWN_POINTER: 1942b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_array_owned)); 1950d0321e0SJeremy L Thompson impl->h_array_owned = array; 1960d0321e0SJeremy L Thompson impl->h_array_borrowed = NULL; 197*a267acd1SJeremy L Thompson impl->h_array = impl->h_array_owned; 1980d0321e0SJeremy L Thompson break; 1990d0321e0SJeremy L Thompson case CEED_USE_POINTER: 2002b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_array_owned)); 2010d0321e0SJeremy L Thompson impl->h_array_borrowed = array; 202*a267acd1SJeremy L Thompson impl->h_array = impl->h_array_borrowed; 2030d0321e0SJeremy L Thompson break; 2040d0321e0SJeremy L Thompson } 2050d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2060d0321e0SJeremy L Thompson } 2070d0321e0SJeremy L Thompson 2080d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2090d0321e0SJeremy L Thompson // Set array from device 2100d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2112b730f8bSJeremy L Thompson static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 212*a267acd1SJeremy L Thompson CeedSize length; 2130d0321e0SJeremy L Thompson Ceed ceed; 2140d0321e0SJeremy L Thompson CeedVector_Hip *impl; 2150d0321e0SJeremy L Thompson 216b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 217b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 218*a267acd1SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 21943c928f4SJeremy L Thompson switch (copy_mode) { 220539ec17dSJeremy L Thompson case CEED_COPY_VALUES: { 221*a267acd1SJeremy L Thompson if (!impl->d_array_owned) CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, length * sizeof(array[0]))); 222*a267acd1SJeremy L Thompson if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array_owned, array, length * sizeof(array[0]), hipMemcpyDeviceToDevice)); 2230d0321e0SJeremy L Thompson impl->d_array_borrowed = NULL; 2243ce2313bSJeremy L Thompson impl->d_array = impl->d_array_owned; 225539ec17dSJeremy L Thompson } break; 2260d0321e0SJeremy L Thompson case CEED_OWN_POINTER: 2272b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_array_owned)); 2280d0321e0SJeremy L Thompson impl->d_array_owned = array; 2290d0321e0SJeremy L Thompson impl->d_array_borrowed = NULL; 230*a267acd1SJeremy L Thompson impl->d_array = impl->d_array_owned; 2310d0321e0SJeremy L Thompson break; 2320d0321e0SJeremy L Thompson case CEED_USE_POINTER: 2332b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_array_owned)); 2340d0321e0SJeremy L Thompson impl->d_array_owned = NULL; 2350d0321e0SJeremy L Thompson impl->d_array_borrowed = array; 236*a267acd1SJeremy L Thompson impl->d_array = impl->d_array_borrowed; 2370d0321e0SJeremy L Thompson break; 2380d0321e0SJeremy L Thompson } 2390d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2400d0321e0SJeremy L Thompson } 2410d0321e0SJeremy L Thompson 2420d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2430d0321e0SJeremy L Thompson // Set the array used by a vector, 2440d0321e0SJeremy L Thompson // freeing any previously allocated array if applicable 2450d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2462b730f8bSJeremy L Thompson static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 2470d0321e0SJeremy L Thompson CeedVector_Hip *impl; 2480d0321e0SJeremy L Thompson 249b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 2502b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 25143c928f4SJeremy L Thompson switch (mem_type) { 2520d0321e0SJeremy L Thompson case CEED_MEM_HOST: 25343c928f4SJeremy L Thompson return CeedVectorSetArrayHost_Hip(vec, copy_mode, array); 2540d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 25543c928f4SJeremy L Thompson return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array); 2560d0321e0SJeremy L Thompson } 2570d0321e0SJeremy L Thompson return CEED_ERROR_UNSUPPORTED; 2580d0321e0SJeremy L Thompson } 2590d0321e0SJeremy L Thompson 2600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2610d0321e0SJeremy L Thompson // Set host array to value 2620d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2639330daecSnbeams static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) { 2649330daecSnbeams for (CeedSize i = 0; i < length; i++) h_array[i] = val; 2650d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2660d0321e0SJeremy L Thompson } 2670d0321e0SJeremy L Thompson 2680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2690d0321e0SJeremy L Thompson // Set device array to value (impl in .hip file) 2700d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2719330daecSnbeams int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val); 2720d0321e0SJeremy L Thompson 2730d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 274b2165e7aSSebastian Grimberg // Set a vector to a value 2750d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2760d0321e0SJeremy L Thompson static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) { 2771f9221feSJeremy L Thompson CeedSize length; 278b7453713SJeremy L Thompson CeedVector_Hip *impl; 2790d0321e0SJeremy L Thompson 280b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 281b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 2820d0321e0SJeremy L Thompson // Set value for synced device/host array 2830d0321e0SJeremy L Thompson if (!impl->d_array && !impl->h_array) { 2840d0321e0SJeremy L Thompson if (impl->d_array_borrowed) { 2850d0321e0SJeremy L Thompson impl->d_array = impl->d_array_borrowed; 2860d0321e0SJeremy L Thompson } else if (impl->h_array_borrowed) { 2870d0321e0SJeremy L Thompson impl->h_array = impl->h_array_borrowed; 2880d0321e0SJeremy L Thompson } else if (impl->d_array_owned) { 2890d0321e0SJeremy L Thompson impl->d_array = impl->d_array_owned; 2900d0321e0SJeremy L Thompson } else if (impl->h_array_owned) { 2910d0321e0SJeremy L Thompson impl->h_array = impl->h_array_owned; 2920d0321e0SJeremy L Thompson } else { 2932b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 2940d0321e0SJeremy L Thompson } 2950d0321e0SJeremy L Thompson } 2960d0321e0SJeremy L Thompson if (impl->d_array) { 2972b730f8bSJeremy L Thompson CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val)); 298b2165e7aSSebastian Grimberg impl->h_array = NULL; 2990d0321e0SJeremy L Thompson } 3000d0321e0SJeremy L Thompson if (impl->h_array) { 3012b730f8bSJeremy L Thompson CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val)); 302b2165e7aSSebastian Grimberg impl->d_array = NULL; 3030d0321e0SJeremy L Thompson } 3040d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 3050d0321e0SJeremy L Thompson } 3060d0321e0SJeremy L Thompson 3070d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3080d0321e0SJeremy L Thompson // Vector Take Array 3090d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3102b730f8bSJeremy L Thompson static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 3110d0321e0SJeremy L Thompson CeedVector_Hip *impl; 312b7453713SJeremy L Thompson 3132b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 3140d0321e0SJeremy L Thompson 31543c928f4SJeremy L Thompson // Sync array to requested mem_type 3162b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 3170d0321e0SJeremy L Thompson 3180d0321e0SJeremy L Thompson // Update pointer 31943c928f4SJeremy L Thompson switch (mem_type) { 3200d0321e0SJeremy L Thompson case CEED_MEM_HOST: 3210d0321e0SJeremy L Thompson (*array) = impl->h_array_borrowed; 3220d0321e0SJeremy L Thompson impl->h_array_borrowed = NULL; 3230d0321e0SJeremy L Thompson impl->h_array = NULL; 3240d0321e0SJeremy L Thompson break; 3250d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 3260d0321e0SJeremy L Thompson (*array) = impl->d_array_borrowed; 3270d0321e0SJeremy L Thompson impl->d_array_borrowed = NULL; 3280d0321e0SJeremy L Thompson impl->d_array = NULL; 3290d0321e0SJeremy L Thompson break; 3300d0321e0SJeremy L Thompson } 3310d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 3320d0321e0SJeremy L Thompson } 3330d0321e0SJeremy L Thompson 3340d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3350d0321e0SJeremy L Thompson // Core logic for array syncronization for GetArray. 3360d0321e0SJeremy L Thompson // If a different memory type is most up to date, this will perform a copy 3370d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3382b730f8bSJeremy L Thompson static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 3390d0321e0SJeremy L Thompson CeedVector_Hip *impl; 340b7453713SJeremy L Thompson 3412b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 3420d0321e0SJeremy L Thompson 34343c928f4SJeremy L Thompson // Sync array to requested mem_type 3442b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 3450d0321e0SJeremy L Thompson 3460d0321e0SJeremy L Thompson // Update pointer 34743c928f4SJeremy L Thompson switch (mem_type) { 3480d0321e0SJeremy L Thompson case CEED_MEM_HOST: 3490d0321e0SJeremy L Thompson *array = impl->h_array; 3500d0321e0SJeremy L Thompson break; 3510d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 3520d0321e0SJeremy L Thompson *array = impl->d_array; 3530d0321e0SJeremy L Thompson break; 3540d0321e0SJeremy L Thompson } 3550d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 3560d0321e0SJeremy L Thompson } 3570d0321e0SJeremy L Thompson 3580d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 35943c928f4SJeremy L Thompson // Get read-only access to a vector via the specified mem_type 3600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3612b730f8bSJeremy L Thompson static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 36243c928f4SJeremy L Thompson return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array); 3630d0321e0SJeremy L Thompson } 3640d0321e0SJeremy L Thompson 3650d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 36643c928f4SJeremy L Thompson // Get read/write access to a vector via the specified mem_type 3670d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3682b730f8bSJeremy L Thompson static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 3690d0321e0SJeremy L Thompson CeedVector_Hip *impl; 370b7453713SJeremy L Thompson 3712b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 3722b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array)); 3732b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec)); 37443c928f4SJeremy L Thompson switch (mem_type) { 3750d0321e0SJeremy L Thompson case CEED_MEM_HOST: 3760d0321e0SJeremy L Thompson impl->h_array = *array; 3770d0321e0SJeremy L Thompson break; 3780d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 3790d0321e0SJeremy L Thompson impl->d_array = *array; 3800d0321e0SJeremy L Thompson break; 3810d0321e0SJeremy L Thompson } 3820d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 3830d0321e0SJeremy L Thompson } 3840d0321e0SJeremy L Thompson 3850d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 38643c928f4SJeremy L Thompson // Get write access to a vector via the specified mem_type 3870d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 3882b730f8bSJeremy L Thompson static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 3890d0321e0SJeremy L Thompson bool has_array_of_type = true; 390b7453713SJeremy L Thompson CeedVector_Hip *impl; 391b7453713SJeremy L Thompson 392b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 3932b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type)); 3940d0321e0SJeremy L Thompson if (!has_array_of_type) { 3950d0321e0SJeremy L Thompson // Allocate if array is not yet allocated 3962b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 3970d0321e0SJeremy L Thompson } else { 3980d0321e0SJeremy L Thompson // Select dirty array 39943c928f4SJeremy L Thompson switch (mem_type) { 4000d0321e0SJeremy L Thompson case CEED_MEM_HOST: 4012b730f8bSJeremy L Thompson if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 4022b730f8bSJeremy L Thompson else impl->h_array = impl->h_array_owned; 4030d0321e0SJeremy L Thompson break; 4040d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 4052b730f8bSJeremy L Thompson if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 4062b730f8bSJeremy L Thompson else impl->d_array = impl->d_array_owned; 4070d0321e0SJeremy L Thompson } 4080d0321e0SJeremy L Thompson } 40943c928f4SJeremy L Thompson return CeedVectorGetArray_Hip(vec, mem_type, array); 4100d0321e0SJeremy L Thompson } 4110d0321e0SJeremy L Thompson 4120d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 4130d0321e0SJeremy L Thompson // Get the norm of a CeedVector 4140d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 4152b730f8bSJeremy L Thompson static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) { 4160d0321e0SJeremy L Thompson Ceed ceed; 417672b0f2aSSebastian Grimberg CeedSize length, num_calls; 418b7453713SJeremy L Thompson const CeedScalar *d_array; 419b7453713SJeremy L Thompson CeedVector_Hip *impl; 4200d0321e0SJeremy L Thompson hipblasHandle_t handle; 421b7453713SJeremy L Thompson 422b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 423b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 424b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 425eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle)); 4260d0321e0SJeremy L Thompson 4279330daecSnbeams // Is the vector too long to handle with int32? If so, we will divide 4289330daecSnbeams // it up into "int32-sized" subsections and make repeated BLAS calls. 429672b0f2aSSebastian Grimberg num_calls = length / INT_MAX; 4309330daecSnbeams if (length % INT_MAX > 0) num_calls += 1; 4319330daecSnbeams 4320d0321e0SJeremy L Thompson // Compute norm 4332b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 4340d0321e0SJeremy L Thompson switch (type) { 4350d0321e0SJeremy L Thompson case CEED_NORM_1: { 436f6f49adbSnbeams *norm = 0.0; 4370d0321e0SJeremy L Thompson if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 4389330daecSnbeams float sub_norm = 0.0; 4399330daecSnbeams float *d_array_start; 440b7453713SJeremy L Thompson 4419330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 4429330daecSnbeams d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 4439330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 4449330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 445b7453713SJeremy L Thompson 4469330daecSnbeams CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm)); 4479330daecSnbeams *norm += sub_norm; 4489330daecSnbeams } 4490d0321e0SJeremy L Thompson } else { 4509330daecSnbeams double sub_norm = 0.0; 4519330daecSnbeams double *d_array_start; 452b7453713SJeremy L Thompson 4539330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 4549330daecSnbeams d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 4559330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 4569330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 457b7453713SJeremy L Thompson 4589330daecSnbeams CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm)); 4599330daecSnbeams *norm += sub_norm; 4609330daecSnbeams } 4619330daecSnbeams } 4620d0321e0SJeremy L Thompson break; 4630d0321e0SJeremy L Thompson } 4640d0321e0SJeremy L Thompson case CEED_NORM_2: { 4650d0321e0SJeremy L Thompson if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 4669330daecSnbeams float sub_norm = 0.0, norm_sum = 0.0; 4679330daecSnbeams float *d_array_start; 468b7453713SJeremy L Thompson 4699330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 4709330daecSnbeams d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 4719330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 4729330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 473b7453713SJeremy L Thompson 4749330daecSnbeams CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm)); 4759330daecSnbeams norm_sum += sub_norm * sub_norm; 4769330daecSnbeams } 4779330daecSnbeams *norm = sqrt(norm_sum); 4780d0321e0SJeremy L Thompson } else { 4799330daecSnbeams double sub_norm = 0.0, norm_sum = 0.0; 4809330daecSnbeams double *d_array_start; 481b7453713SJeremy L Thompson 4829330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 4839330daecSnbeams d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 4849330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 4859330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 486b7453713SJeremy L Thompson 4879330daecSnbeams CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm)); 4889330daecSnbeams norm_sum += sub_norm * sub_norm; 4899330daecSnbeams } 4909330daecSnbeams *norm = sqrt(norm_sum); 4919330daecSnbeams } 4920d0321e0SJeremy L Thompson break; 4930d0321e0SJeremy L Thompson } 4940d0321e0SJeremy L Thompson case CEED_NORM_MAX: { 495b7453713SJeremy L Thompson CeedInt index; 496b7453713SJeremy L Thompson 4970d0321e0SJeremy L Thompson if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 4989330daecSnbeams float sub_max = 0.0, current_max = 0.0; 4999330daecSnbeams float *d_array_start; 5009330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 5019330daecSnbeams d_array_start = (float *)d_array + (CeedSize)(i)*INT_MAX; 5029330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 5039330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 504b7453713SJeremy L Thompson 505b7453713SJeremy L Thompson CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index)); 506b7453713SJeremy L Thompson CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 5079330daecSnbeams if (fabs(sub_max) > current_max) current_max = fabs(sub_max); 5089330daecSnbeams } 5099330daecSnbeams *norm = current_max; 5109330daecSnbeams } else { 5119330daecSnbeams double sub_max = 0.0, current_max = 0.0; 5129330daecSnbeams double *d_array_start; 513b7453713SJeremy L Thompson 5149330daecSnbeams for (CeedInt i = 0; i < num_calls; i++) { 5159330daecSnbeams d_array_start = (double *)d_array + (CeedSize)(i)*INT_MAX; 5169330daecSnbeams CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX; 5179330daecSnbeams CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX; 518b7453713SJeremy L Thompson 519b7453713SJeremy L Thompson CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index)); 520b7453713SJeremy L Thompson CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost)); 5219330daecSnbeams if (fabs(sub_max) > current_max) current_max = fabs(sub_max); 5229330daecSnbeams } 5239330daecSnbeams *norm = current_max; 5249330daecSnbeams } 5250d0321e0SJeremy L Thompson break; 5260d0321e0SJeremy L Thompson } 5270d0321e0SJeremy L Thompson } 5282b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 5290d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5300d0321e0SJeremy L Thompson } 5310d0321e0SJeremy L Thompson 5320d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5330d0321e0SJeremy L Thompson // Take reciprocal of a vector on host 5340d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5359330daecSnbeams static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) { 5369330daecSnbeams for (CeedSize i = 0; i < length; i++) { 5372b730f8bSJeremy L Thompson if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 5382b730f8bSJeremy L Thompson } 5390d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5400d0321e0SJeremy L Thompson } 5410d0321e0SJeremy L Thompson 5420d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5430d0321e0SJeremy L Thompson // Take reciprocal of a vector on device (impl in .cu file) 5440d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5459330daecSnbeams int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length); 5460d0321e0SJeremy L Thompson 5470d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5480d0321e0SJeremy L Thompson // Take reciprocal of a vector 5490d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5500d0321e0SJeremy L Thompson static int CeedVectorReciprocal_Hip(CeedVector vec) { 5511f9221feSJeremy L Thompson CeedSize length; 552b7453713SJeremy L Thompson CeedVector_Hip *impl; 5530d0321e0SJeremy L Thompson 554b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 555b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(vec, &length)); 5560d0321e0SJeremy L Thompson // Set value for synced device/host array 5572b730f8bSJeremy L Thompson if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length)); 5582b730f8bSJeremy L Thompson if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length)); 5590d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5600d0321e0SJeremy L Thompson } 5610d0321e0SJeremy L Thompson 5620d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5630d0321e0SJeremy L Thompson // Compute x = alpha x on the host 5640d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5659330daecSnbeams static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) { 5669330daecSnbeams for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha; 5670d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5680d0321e0SJeremy L Thompson } 5690d0321e0SJeremy L Thompson 5700d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5710d0321e0SJeremy L Thompson // Compute x = alpha x on device (impl in .cu file) 5720d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5739330daecSnbeams int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length); 5740d0321e0SJeremy L Thompson 5750d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5760d0321e0SJeremy L Thompson // Compute x = alpha x 5770d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5780d0321e0SJeremy L Thompson static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) { 5791f9221feSJeremy L Thompson CeedSize length; 580b7453713SJeremy L Thompson CeedVector_Hip *x_impl; 5810d0321e0SJeremy L Thompson 582b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(x, &x_impl)); 583b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetLength(x, &length)); 5840d0321e0SJeremy L Thompson // Set value for synced device/host array 5852b730f8bSJeremy L Thompson if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length)); 5862b730f8bSJeremy L Thompson if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length)); 5870d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5880d0321e0SJeremy L Thompson } 5890d0321e0SJeremy L Thompson 5900d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5910d0321e0SJeremy L Thompson // Compute y = alpha x + y on the host 5920d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5939330daecSnbeams static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) { 5949330daecSnbeams for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 5950d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 5960d0321e0SJeremy L Thompson } 5970d0321e0SJeremy L Thompson 5980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 5990d0321e0SJeremy L Thompson // Compute y = alpha x + y on device (impl in .cu file) 6000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6019330daecSnbeams int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length); 6020d0321e0SJeremy L Thompson 6030d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6040d0321e0SJeremy L Thompson // Compute y = alpha x + y 6050d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6060d0321e0SJeremy L Thompson static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) { 607b7453713SJeremy L Thompson CeedSize length; 6080d0321e0SJeremy L Thompson CeedVector_Hip *y_impl, *x_impl; 609b7453713SJeremy L Thompson 6102b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(y, &y_impl)); 6112b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(x, &x_impl)); 6122b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(y, &length)); 6130d0321e0SJeremy L Thompson // Set value for synced device/host array 6140d0321e0SJeremy L Thompson if (y_impl->d_array) { 6152b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 6162b730f8bSJeremy L Thompson CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length)); 6170d0321e0SJeremy L Thompson } 6180d0321e0SJeremy L Thompson if (y_impl->h_array) { 6192b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 6202b730f8bSJeremy L Thompson CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length)); 6210d0321e0SJeremy L Thompson } 6220d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 6230d0321e0SJeremy L Thompson } 624ff1e7120SSebastian Grimberg 6255fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6265fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on the host 6275fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6289330daecSnbeams static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) { 629aa67b842SZach Atkins for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i]; 6305fb68f37SKaren (Ren) Stengel return CEED_ERROR_SUCCESS; 6315fb68f37SKaren (Ren) Stengel } 6325fb68f37SKaren (Ren) Stengel 6335fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6345fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y on device (impl in .cu file) 6355fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6369330daecSnbeams int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length); 6375fb68f37SKaren (Ren) Stengel 6385fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6395fb68f37SKaren (Ren) Stengel // Compute y = alpha x + beta y 6405fb68f37SKaren (Ren) Stengel //------------------------------------------------------------------------------ 6415fb68f37SKaren (Ren) Stengel static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) { 642b7453713SJeremy L Thompson CeedSize length; 6435fb68f37SKaren (Ren) Stengel CeedVector_Hip *y_impl, *x_impl; 644b7453713SJeremy L Thompson 6455fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedVectorGetData(y, &y_impl)); 6465fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedVectorGetData(x, &x_impl)); 6475fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedVectorGetLength(y, &length)); 6485fb68f37SKaren (Ren) Stengel // Set value for synced device/host array 6495fb68f37SKaren (Ren) Stengel if (y_impl->d_array) { 6505fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 6515fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length)); 6525fb68f37SKaren (Ren) Stengel } 6535fb68f37SKaren (Ren) Stengel if (y_impl->h_array) { 6545fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 6555fb68f37SKaren (Ren) Stengel CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length)); 6565fb68f37SKaren (Ren) Stengel } 6575fb68f37SKaren (Ren) Stengel return CEED_ERROR_SUCCESS; 6585fb68f37SKaren (Ren) Stengel } 6590d0321e0SJeremy L Thompson 6600d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6610d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on the host 6620d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6639330daecSnbeams static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) { 6649330daecSnbeams for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 6650d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 6660d0321e0SJeremy L Thompson } 6670d0321e0SJeremy L Thompson 6680d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6690d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 6700d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6719330daecSnbeams int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length); 6720d0321e0SJeremy L Thompson 6730d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6740d0321e0SJeremy L Thompson // Compute the pointwise multiplication w = x .* y 6750d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 6762b730f8bSJeremy L Thompson static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) { 677b7453713SJeremy L Thompson CeedSize length; 6780d0321e0SJeremy L Thompson CeedVector_Hip *w_impl, *x_impl, *y_impl; 679b7453713SJeremy L Thompson 6802b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(w, &w_impl)); 6812b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(x, &x_impl)); 6822b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetData(y, &y_impl)); 6832b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetLength(w, &length)); 6840d0321e0SJeremy L Thompson 6850d0321e0SJeremy L Thompson // Set value for synced device/host array 6860d0321e0SJeremy L Thompson if (!w_impl->d_array && !w_impl->h_array) { 6872b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetValue(w, 0.0)); 6880d0321e0SJeremy L Thompson } 6890d0321e0SJeremy L Thompson if (w_impl->d_array) { 6902b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 6912b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 6922b730f8bSJeremy L Thompson CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 6930d0321e0SJeremy L Thompson } 6940d0321e0SJeremy L Thompson if (w_impl->h_array) { 6952b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 6962b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 6972b730f8bSJeremy L Thompson CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 6980d0321e0SJeremy L Thompson } 6990d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 7000d0321e0SJeremy L Thompson } 7010d0321e0SJeremy L Thompson 7020d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 7030d0321e0SJeremy L Thompson // Destroy the vector 7040d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 7050d0321e0SJeremy L Thompson static int CeedVectorDestroy_Hip(const CeedVector vec) { 7060d0321e0SJeremy L Thompson CeedVector_Hip *impl; 7070d0321e0SJeremy L Thompson 708b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetData(vec, &impl)); 7096e536b99SJeremy L Thompson CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned)); 7102b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_array_owned)); 7112b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 7120d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 7130d0321e0SJeremy L Thompson } 7140d0321e0SJeremy L Thompson 7150d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 7160d0321e0SJeremy L Thompson // Create a vector of the specified length (does not allocate memory) 7170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 7181f9221feSJeremy L Thompson int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { 7190d0321e0SJeremy L Thompson CeedVector_Hip *impl; 7200d0321e0SJeremy L Thompson Ceed ceed; 7210d0321e0SJeremy L Thompson 722b7453713SJeremy L Thompson CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 7232b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip)); 7242b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip)); 7252b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip)); 7262b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip)); 727008736bdSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip)); 7282b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip)); 7292b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip)); 7302b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip)); 7312b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip)); 7322b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip)); 7332b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip)); 734008736bdSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip)); 735008736bdSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip)); 736008736bdSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip)); 7372b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip)); 7382b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip)); 7392b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 7402b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorSetData(vec, impl)); 7410d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 7420d0321e0SJeremy L Thompson } 7432a86cc9dSSebastian Grimberg 7442a86cc9dSSebastian Grimberg //------------------------------------------------------------------------------ 745