1*ff1e7120SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*ff1e7120SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*ff1e7120SSebastian Grimberg // 4*ff1e7120SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*ff1e7120SSebastian Grimberg // 6*ff1e7120SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*ff1e7120SSebastian Grimberg 8*ff1e7120SSebastian Grimberg #include <ceed.h> 9*ff1e7120SSebastian Grimberg #include <ceed/backend.h> 10*ff1e7120SSebastian Grimberg #include <cublas_v2.h> 11*ff1e7120SSebastian Grimberg #include <cuda_runtime.h> 12*ff1e7120SSebastian Grimberg #include <math.h> 13*ff1e7120SSebastian Grimberg #include <stdbool.h> 14*ff1e7120SSebastian Grimberg #include <string.h> 15*ff1e7120SSebastian Grimberg 16*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h" 17*ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h" 18*ff1e7120SSebastian Grimberg 19*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 20*ff1e7120SSebastian Grimberg // Check if host/device sync is needed 21*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 22*ff1e7120SSebastian Grimberg static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) { 23*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 24*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 25*ff1e7120SSebastian Grimberg 26*ff1e7120SSebastian Grimberg bool has_valid_array = false; 27*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array)); 28*ff1e7120SSebastian Grimberg switch (mem_type) { 29*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 30*ff1e7120SSebastian Grimberg *need_sync = has_valid_array && !impl->h_array; 31*ff1e7120SSebastian Grimberg break; 32*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 33*ff1e7120SSebastian Grimberg *need_sync = has_valid_array && !impl->d_array; 34*ff1e7120SSebastian Grimberg break; 35*ff1e7120SSebastian Grimberg } 36*ff1e7120SSebastian Grimberg 37*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 38*ff1e7120SSebastian Grimberg } 39*ff1e7120SSebastian Grimberg 40*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 41*ff1e7120SSebastian Grimberg // Sync host to device 42*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 43*ff1e7120SSebastian Grimberg static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) { 44*ff1e7120SSebastian Grimberg Ceed ceed; 45*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 46*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 47*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 48*ff1e7120SSebastian Grimberg 49*ff1e7120SSebastian Grimberg CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 50*ff1e7120SSebastian Grimberg 51*ff1e7120SSebastian Grimberg CeedSize length; 52*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 53*ff1e7120SSebastian Grimberg size_t bytes = length * sizeof(CeedScalar); 54*ff1e7120SSebastian Grimberg 55*ff1e7120SSebastian Grimberg if (impl->d_array_borrowed) { 56*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_borrowed; 57*ff1e7120SSebastian Grimberg } else if (impl->d_array_owned) { 58*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_owned; 59*ff1e7120SSebastian Grimberg } else { 60*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes)); 61*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_owned; 62*ff1e7120SSebastian Grimberg } 63*ff1e7120SSebastian Grimberg 64*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice)); 65*ff1e7120SSebastian Grimberg 66*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 67*ff1e7120SSebastian Grimberg } 68*ff1e7120SSebastian Grimberg 69*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 70*ff1e7120SSebastian Grimberg // Sync device to host 71*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 72*ff1e7120SSebastian Grimberg static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) { 73*ff1e7120SSebastian Grimberg Ceed ceed; 74*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 75*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 76*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 77*ff1e7120SSebastian Grimberg 78*ff1e7120SSebastian Grimberg CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 79*ff1e7120SSebastian Grimberg 80*ff1e7120SSebastian Grimberg if (impl->h_array_borrowed) { 81*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_borrowed; 82*ff1e7120SSebastian Grimberg } else if (impl->h_array_owned) { 83*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_owned; 84*ff1e7120SSebastian Grimberg } else { 85*ff1e7120SSebastian Grimberg CeedSize length; 86*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 87*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(length, &impl->h_array_owned)); 88*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_owned; 89*ff1e7120SSebastian Grimberg } 90*ff1e7120SSebastian Grimberg 91*ff1e7120SSebastian Grimberg CeedSize length; 92*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 93*ff1e7120SSebastian Grimberg size_t bytes = length * sizeof(CeedScalar); 94*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost)); 95*ff1e7120SSebastian Grimberg 96*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 97*ff1e7120SSebastian Grimberg } 98*ff1e7120SSebastian Grimberg 99*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 100*ff1e7120SSebastian Grimberg // Sync arrays 101*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 102*ff1e7120SSebastian Grimberg static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) { 103*ff1e7120SSebastian Grimberg // Check whether device/host sync is needed 104*ff1e7120SSebastian Grimberg bool need_sync = false; 105*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync)); 106*ff1e7120SSebastian Grimberg if (!need_sync) return CEED_ERROR_SUCCESS; 107*ff1e7120SSebastian Grimberg 108*ff1e7120SSebastian Grimberg switch (mem_type) { 109*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 110*ff1e7120SSebastian Grimberg return CeedVectorSyncD2H_Cuda(vec); 111*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 112*ff1e7120SSebastian Grimberg return CeedVectorSyncH2D_Cuda(vec); 113*ff1e7120SSebastian Grimberg } 114*ff1e7120SSebastian Grimberg return CEED_ERROR_UNSUPPORTED; 115*ff1e7120SSebastian Grimberg } 116*ff1e7120SSebastian Grimberg 117*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 118*ff1e7120SSebastian Grimberg // Set all pointers as invalid 119*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 120*ff1e7120SSebastian Grimberg static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) { 121*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 122*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 123*ff1e7120SSebastian Grimberg 124*ff1e7120SSebastian Grimberg impl->h_array = NULL; 125*ff1e7120SSebastian Grimberg impl->d_array = NULL; 126*ff1e7120SSebastian Grimberg 127*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 128*ff1e7120SSebastian Grimberg } 129*ff1e7120SSebastian Grimberg 130*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 131*ff1e7120SSebastian Grimberg // Check if CeedVector has any valid pointer 132*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 133*ff1e7120SSebastian Grimberg static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) { 134*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 135*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 136*ff1e7120SSebastian Grimberg 137*ff1e7120SSebastian Grimberg *has_valid_array = !!impl->h_array || !!impl->d_array; 138*ff1e7120SSebastian Grimberg 139*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 140*ff1e7120SSebastian Grimberg } 141*ff1e7120SSebastian Grimberg 142*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 143*ff1e7120SSebastian Grimberg // Check if has array of given type 144*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 145*ff1e7120SSebastian Grimberg static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) { 146*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 147*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 148*ff1e7120SSebastian Grimberg 149*ff1e7120SSebastian Grimberg switch (mem_type) { 150*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 151*ff1e7120SSebastian Grimberg *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned; 152*ff1e7120SSebastian Grimberg break; 153*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 154*ff1e7120SSebastian Grimberg *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned; 155*ff1e7120SSebastian Grimberg break; 156*ff1e7120SSebastian Grimberg } 157*ff1e7120SSebastian Grimberg 158*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 159*ff1e7120SSebastian Grimberg } 160*ff1e7120SSebastian Grimberg 161*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 162*ff1e7120SSebastian Grimberg // Check if has borrowed array of given type 163*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 164*ff1e7120SSebastian Grimberg static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) { 165*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 166*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 167*ff1e7120SSebastian Grimberg 168*ff1e7120SSebastian Grimberg switch (mem_type) { 169*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 170*ff1e7120SSebastian Grimberg *has_borrowed_array_of_type = !!impl->h_array_borrowed; 171*ff1e7120SSebastian Grimberg break; 172*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 173*ff1e7120SSebastian Grimberg *has_borrowed_array_of_type = !!impl->d_array_borrowed; 174*ff1e7120SSebastian Grimberg break; 175*ff1e7120SSebastian Grimberg } 176*ff1e7120SSebastian Grimberg 177*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 178*ff1e7120SSebastian Grimberg } 179*ff1e7120SSebastian Grimberg 180*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 181*ff1e7120SSebastian Grimberg // Set array from host 182*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 183*ff1e7120SSebastian Grimberg static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 184*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 185*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 186*ff1e7120SSebastian Grimberg 187*ff1e7120SSebastian Grimberg switch (copy_mode) { 188*ff1e7120SSebastian Grimberg case CEED_COPY_VALUES: { 189*ff1e7120SSebastian Grimberg CeedSize length; 190*ff1e7120SSebastian Grimberg if (!impl->h_array_owned) { 191*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 192*ff1e7120SSebastian Grimberg CeedCallBackend(CeedMalloc(length, &impl->h_array_owned)); 193*ff1e7120SSebastian Grimberg } 194*ff1e7120SSebastian Grimberg impl->h_array_borrowed = NULL; 195*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_owned; 196*ff1e7120SSebastian Grimberg if (array) { 197*ff1e7120SSebastian Grimberg CeedSize length; 198*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 199*ff1e7120SSebastian Grimberg size_t bytes = length * sizeof(CeedScalar); 200*ff1e7120SSebastian Grimberg memcpy(impl->h_array, array, bytes); 201*ff1e7120SSebastian Grimberg } 202*ff1e7120SSebastian Grimberg } break; 203*ff1e7120SSebastian Grimberg case CEED_OWN_POINTER: 204*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl->h_array_owned)); 205*ff1e7120SSebastian Grimberg impl->h_array_owned = array; 206*ff1e7120SSebastian Grimberg impl->h_array_borrowed = NULL; 207*ff1e7120SSebastian Grimberg impl->h_array = array; 208*ff1e7120SSebastian Grimberg break; 209*ff1e7120SSebastian Grimberg case CEED_USE_POINTER: 210*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl->h_array_owned)); 211*ff1e7120SSebastian Grimberg impl->h_array_borrowed = array; 212*ff1e7120SSebastian Grimberg impl->h_array = array; 213*ff1e7120SSebastian Grimberg break; 214*ff1e7120SSebastian Grimberg } 215*ff1e7120SSebastian Grimberg 216*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 217*ff1e7120SSebastian Grimberg } 218*ff1e7120SSebastian Grimberg 219*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 220*ff1e7120SSebastian Grimberg // Set array from device 221*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 222*ff1e7120SSebastian Grimberg static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) { 223*ff1e7120SSebastian Grimberg Ceed ceed; 224*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 225*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 226*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 227*ff1e7120SSebastian Grimberg 228*ff1e7120SSebastian Grimberg switch (copy_mode) { 229*ff1e7120SSebastian Grimberg case CEED_COPY_VALUES: { 230*ff1e7120SSebastian Grimberg CeedSize length; 231*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 232*ff1e7120SSebastian Grimberg size_t bytes = length * sizeof(CeedScalar); 233*ff1e7120SSebastian Grimberg if (!impl->d_array_owned) { 234*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes)); 235*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_owned; 236*ff1e7120SSebastian Grimberg } 237*ff1e7120SSebastian Grimberg if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice)); 238*ff1e7120SSebastian Grimberg } break; 239*ff1e7120SSebastian Grimberg case CEED_OWN_POINTER: 240*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_array_owned)); 241*ff1e7120SSebastian Grimberg impl->d_array_owned = array; 242*ff1e7120SSebastian Grimberg impl->d_array_borrowed = NULL; 243*ff1e7120SSebastian Grimberg impl->d_array = array; 244*ff1e7120SSebastian Grimberg break; 245*ff1e7120SSebastian Grimberg case CEED_USE_POINTER: 246*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_array_owned)); 247*ff1e7120SSebastian Grimberg impl->d_array_owned = NULL; 248*ff1e7120SSebastian Grimberg impl->d_array_borrowed = array; 249*ff1e7120SSebastian Grimberg impl->d_array = array; 250*ff1e7120SSebastian Grimberg break; 251*ff1e7120SSebastian Grimberg } 252*ff1e7120SSebastian Grimberg 253*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 254*ff1e7120SSebastian Grimberg } 255*ff1e7120SSebastian Grimberg 256*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 257*ff1e7120SSebastian Grimberg // Set the array used by a vector, 258*ff1e7120SSebastian Grimberg // freeing any previously allocated array if applicable 259*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 260*ff1e7120SSebastian Grimberg static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) { 261*ff1e7120SSebastian Grimberg Ceed ceed; 262*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 263*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 264*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 265*ff1e7120SSebastian Grimberg 266*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec)); 267*ff1e7120SSebastian Grimberg switch (mem_type) { 268*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 269*ff1e7120SSebastian Grimberg return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array); 270*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 271*ff1e7120SSebastian Grimberg return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array); 272*ff1e7120SSebastian Grimberg } 273*ff1e7120SSebastian Grimberg 274*ff1e7120SSebastian Grimberg return CEED_ERROR_UNSUPPORTED; 275*ff1e7120SSebastian Grimberg } 276*ff1e7120SSebastian Grimberg 277*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 278*ff1e7120SSebastian Grimberg // Set host array to value 279*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 280*ff1e7120SSebastian Grimberg static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedInt length, CeedScalar val) { 281*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) h_array[i] = val; 282*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 283*ff1e7120SSebastian Grimberg } 284*ff1e7120SSebastian Grimberg 285*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 286*ff1e7120SSebastian Grimberg // Set device array to value (impl in .cu file) 287*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 288*ff1e7120SSebastian Grimberg int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedInt length, CeedScalar val); 289*ff1e7120SSebastian Grimberg 290*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 291*ff1e7120SSebastian Grimberg // Set a vector to a value, 292*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 293*ff1e7120SSebastian Grimberg static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) { 294*ff1e7120SSebastian Grimberg Ceed ceed; 295*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 296*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 297*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 298*ff1e7120SSebastian Grimberg CeedSize length; 299*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 300*ff1e7120SSebastian Grimberg 301*ff1e7120SSebastian Grimberg // Set value for synced device/host array 302*ff1e7120SSebastian Grimberg if (!impl->d_array && !impl->h_array) { 303*ff1e7120SSebastian Grimberg if (impl->d_array_borrowed) { 304*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_borrowed; 305*ff1e7120SSebastian Grimberg } else if (impl->h_array_borrowed) { 306*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_borrowed; 307*ff1e7120SSebastian Grimberg } else if (impl->d_array_owned) { 308*ff1e7120SSebastian Grimberg impl->d_array = impl->d_array_owned; 309*ff1e7120SSebastian Grimberg } else if (impl->h_array_owned) { 310*ff1e7120SSebastian Grimberg impl->h_array = impl->h_array_owned; 311*ff1e7120SSebastian Grimberg } else { 312*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL)); 313*ff1e7120SSebastian Grimberg } 314*ff1e7120SSebastian Grimberg } 315*ff1e7120SSebastian Grimberg if (impl->d_array) { 316*ff1e7120SSebastian Grimberg CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val)); 317*ff1e7120SSebastian Grimberg impl->h_array = NULL; 318*ff1e7120SSebastian Grimberg } 319*ff1e7120SSebastian Grimberg if (impl->h_array) { 320*ff1e7120SSebastian Grimberg CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val)); 321*ff1e7120SSebastian Grimberg impl->d_array = NULL; 322*ff1e7120SSebastian Grimberg } 323*ff1e7120SSebastian Grimberg 324*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 325*ff1e7120SSebastian Grimberg } 326*ff1e7120SSebastian Grimberg 327*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 328*ff1e7120SSebastian Grimberg // Vector Take Array 329*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 330*ff1e7120SSebastian Grimberg static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) { 331*ff1e7120SSebastian Grimberg Ceed ceed; 332*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 333*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 334*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 335*ff1e7120SSebastian Grimberg 336*ff1e7120SSebastian Grimberg // Sync array to requested mem_type 337*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 338*ff1e7120SSebastian Grimberg 339*ff1e7120SSebastian Grimberg // Update pointer 340*ff1e7120SSebastian Grimberg switch (mem_type) { 341*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 342*ff1e7120SSebastian Grimberg (*array) = impl->h_array_borrowed; 343*ff1e7120SSebastian Grimberg impl->h_array_borrowed = NULL; 344*ff1e7120SSebastian Grimberg impl->h_array = NULL; 345*ff1e7120SSebastian Grimberg break; 346*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 347*ff1e7120SSebastian Grimberg (*array) = impl->d_array_borrowed; 348*ff1e7120SSebastian Grimberg impl->d_array_borrowed = NULL; 349*ff1e7120SSebastian Grimberg impl->d_array = NULL; 350*ff1e7120SSebastian Grimberg break; 351*ff1e7120SSebastian Grimberg } 352*ff1e7120SSebastian Grimberg 353*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 354*ff1e7120SSebastian Grimberg } 355*ff1e7120SSebastian Grimberg 356*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 357*ff1e7120SSebastian Grimberg // Core logic for array syncronization for GetArray. 358*ff1e7120SSebastian Grimberg // If a different memory type is most up to date, this will perform a copy 359*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 360*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 361*ff1e7120SSebastian Grimberg Ceed ceed; 362*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 363*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 364*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 365*ff1e7120SSebastian Grimberg 366*ff1e7120SSebastian Grimberg // Sync array to requested mem_type 367*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(vec, mem_type)); 368*ff1e7120SSebastian Grimberg 369*ff1e7120SSebastian Grimberg // Update pointer 370*ff1e7120SSebastian Grimberg switch (mem_type) { 371*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 372*ff1e7120SSebastian Grimberg *array = impl->h_array; 373*ff1e7120SSebastian Grimberg break; 374*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 375*ff1e7120SSebastian Grimberg *array = impl->d_array; 376*ff1e7120SSebastian Grimberg break; 377*ff1e7120SSebastian Grimberg } 378*ff1e7120SSebastian Grimberg 379*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 380*ff1e7120SSebastian Grimberg } 381*ff1e7120SSebastian Grimberg 382*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 383*ff1e7120SSebastian Grimberg // Get read-only access to a vector via the specified mem_type 384*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 385*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) { 386*ff1e7120SSebastian Grimberg return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array); 387*ff1e7120SSebastian Grimberg } 388*ff1e7120SSebastian Grimberg 389*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 390*ff1e7120SSebastian Grimberg // Get read/write access to a vector via the specified mem_type 391*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 392*ff1e7120SSebastian Grimberg static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 393*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 394*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 395*ff1e7120SSebastian Grimberg 396*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array)); 397*ff1e7120SSebastian Grimberg 398*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec)); 399*ff1e7120SSebastian Grimberg switch (mem_type) { 400*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 401*ff1e7120SSebastian Grimberg impl->h_array = *array; 402*ff1e7120SSebastian Grimberg break; 403*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 404*ff1e7120SSebastian Grimberg impl->d_array = *array; 405*ff1e7120SSebastian Grimberg break; 406*ff1e7120SSebastian Grimberg } 407*ff1e7120SSebastian Grimberg 408*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 409*ff1e7120SSebastian Grimberg } 410*ff1e7120SSebastian Grimberg 411*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 412*ff1e7120SSebastian Grimberg // Get write access to a vector via the specified mem_type 413*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 414*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) { 415*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 416*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 417*ff1e7120SSebastian Grimberg 418*ff1e7120SSebastian Grimberg bool has_array_of_type = true; 419*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type)); 420*ff1e7120SSebastian Grimberg if (!has_array_of_type) { 421*ff1e7120SSebastian Grimberg // Allocate if array is not yet allocated 422*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL)); 423*ff1e7120SSebastian Grimberg } else { 424*ff1e7120SSebastian Grimberg // Select dirty array 425*ff1e7120SSebastian Grimberg switch (mem_type) { 426*ff1e7120SSebastian Grimberg case CEED_MEM_HOST: 427*ff1e7120SSebastian Grimberg if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed; 428*ff1e7120SSebastian Grimberg else impl->h_array = impl->h_array_owned; 429*ff1e7120SSebastian Grimberg break; 430*ff1e7120SSebastian Grimberg case CEED_MEM_DEVICE: 431*ff1e7120SSebastian Grimberg if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed; 432*ff1e7120SSebastian Grimberg else impl->d_array = impl->d_array_owned; 433*ff1e7120SSebastian Grimberg } 434*ff1e7120SSebastian Grimberg } 435*ff1e7120SSebastian Grimberg 436*ff1e7120SSebastian Grimberg return CeedVectorGetArray_Cuda(vec, mem_type, array); 437*ff1e7120SSebastian Grimberg } 438*ff1e7120SSebastian Grimberg 439*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 440*ff1e7120SSebastian Grimberg // Get the norm of a CeedVector 441*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 442*ff1e7120SSebastian Grimberg static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) { 443*ff1e7120SSebastian Grimberg Ceed ceed; 444*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 445*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 446*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 447*ff1e7120SSebastian Grimberg CeedSize length; 448*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 449*ff1e7120SSebastian Grimberg cublasHandle_t handle; 450*ff1e7120SSebastian Grimberg CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle)); 451*ff1e7120SSebastian Grimberg 452*ff1e7120SSebastian Grimberg // Compute norm 453*ff1e7120SSebastian Grimberg const CeedScalar *d_array; 454*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array)); 455*ff1e7120SSebastian Grimberg switch (type) { 456*ff1e7120SSebastian Grimberg case CEED_NORM_1: { 457*ff1e7120SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 458*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasSasum(handle, length, (float *)d_array, 1, (float *)norm)); 459*ff1e7120SSebastian Grimberg } else { 460*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasDasum(handle, length, (double *)d_array, 1, (double *)norm)); 461*ff1e7120SSebastian Grimberg } 462*ff1e7120SSebastian Grimberg break; 463*ff1e7120SSebastian Grimberg } 464*ff1e7120SSebastian Grimberg case CEED_NORM_2: { 465*ff1e7120SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 466*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasSnrm2(handle, length, (float *)d_array, 1, (float *)norm)); 467*ff1e7120SSebastian Grimberg } else { 468*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasDnrm2(handle, length, (double *)d_array, 1, (double *)norm)); 469*ff1e7120SSebastian Grimberg } 470*ff1e7120SSebastian Grimberg break; 471*ff1e7120SSebastian Grimberg } 472*ff1e7120SSebastian Grimberg case CEED_NORM_MAX: { 473*ff1e7120SSebastian Grimberg CeedInt indx; 474*ff1e7120SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 475*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasIsamax(handle, length, (float *)d_array, 1, &indx)); 476*ff1e7120SSebastian Grimberg } else { 477*ff1e7120SSebastian Grimberg CeedCallCublas(ceed, cublasIdamax(handle, length, (double *)d_array, 1, &indx)); 478*ff1e7120SSebastian Grimberg } 479*ff1e7120SSebastian Grimberg CeedScalar normNoAbs; 480*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost)); 481*ff1e7120SSebastian Grimberg *norm = fabs(normNoAbs); 482*ff1e7120SSebastian Grimberg break; 483*ff1e7120SSebastian Grimberg } 484*ff1e7120SSebastian Grimberg } 485*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array)); 486*ff1e7120SSebastian Grimberg 487*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 488*ff1e7120SSebastian Grimberg } 489*ff1e7120SSebastian Grimberg 490*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 491*ff1e7120SSebastian Grimberg // Take reciprocal of a vector on host 492*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 493*ff1e7120SSebastian Grimberg static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedInt length) { 494*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) { 495*ff1e7120SSebastian Grimberg if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i]; 496*ff1e7120SSebastian Grimberg } 497*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 498*ff1e7120SSebastian Grimberg } 499*ff1e7120SSebastian Grimberg 500*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 501*ff1e7120SSebastian Grimberg // Take reciprocal of a vector on device (impl in .cu file) 502*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 503*ff1e7120SSebastian Grimberg int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedInt length); 504*ff1e7120SSebastian Grimberg 505*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 506*ff1e7120SSebastian Grimberg // Take reciprocal of a vector 507*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 508*ff1e7120SSebastian Grimberg static int CeedVectorReciprocal_Cuda(CeedVector vec) { 509*ff1e7120SSebastian Grimberg Ceed ceed; 510*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 511*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 512*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 513*ff1e7120SSebastian Grimberg CeedSize length; 514*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(vec, &length)); 515*ff1e7120SSebastian Grimberg 516*ff1e7120SSebastian Grimberg // Set value for synced device/host array 517*ff1e7120SSebastian Grimberg if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length)); 518*ff1e7120SSebastian Grimberg if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length)); 519*ff1e7120SSebastian Grimberg 520*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 521*ff1e7120SSebastian Grimberg } 522*ff1e7120SSebastian Grimberg 523*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 524*ff1e7120SSebastian Grimberg // Compute x = alpha x on the host 525*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 526*ff1e7120SSebastian Grimberg static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length) { 527*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) x_array[i] *= alpha; 528*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 529*ff1e7120SSebastian Grimberg } 530*ff1e7120SSebastian Grimberg 531*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 532*ff1e7120SSebastian Grimberg // Compute x = alpha x on device (impl in .cu file) 533*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 534*ff1e7120SSebastian Grimberg int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length); 535*ff1e7120SSebastian Grimberg 536*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 537*ff1e7120SSebastian Grimberg // Compute x = alpha x 538*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 539*ff1e7120SSebastian Grimberg static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) { 540*ff1e7120SSebastian Grimberg Ceed ceed; 541*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(x, &ceed)); 542*ff1e7120SSebastian Grimberg CeedVector_Cuda *x_impl; 543*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(x, &x_impl)); 544*ff1e7120SSebastian Grimberg CeedSize length; 545*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(x, &length)); 546*ff1e7120SSebastian Grimberg 547*ff1e7120SSebastian Grimberg // Set value for synced device/host array 548*ff1e7120SSebastian Grimberg if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length)); 549*ff1e7120SSebastian Grimberg if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length)); 550*ff1e7120SSebastian Grimberg 551*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 552*ff1e7120SSebastian Grimberg } 553*ff1e7120SSebastian Grimberg 554*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 555*ff1e7120SSebastian Grimberg // Compute y = alpha x + y on the host 556*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 557*ff1e7120SSebastian Grimberg static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) { 558*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i]; 559*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 560*ff1e7120SSebastian Grimberg } 561*ff1e7120SSebastian Grimberg 562*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 563*ff1e7120SSebastian Grimberg // Compute y = alpha x + y on device (impl in .cu file) 564*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 565*ff1e7120SSebastian Grimberg int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length); 566*ff1e7120SSebastian Grimberg 567*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 568*ff1e7120SSebastian Grimberg // Compute y = alpha x + y 569*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 570*ff1e7120SSebastian Grimberg static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) { 571*ff1e7120SSebastian Grimberg Ceed ceed; 572*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 573*ff1e7120SSebastian Grimberg CeedVector_Cuda *y_impl, *x_impl; 574*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(y, &y_impl)); 575*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(x, &x_impl)); 576*ff1e7120SSebastian Grimberg CeedSize length; 577*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(y, &length)); 578*ff1e7120SSebastian Grimberg 579*ff1e7120SSebastian Grimberg // Set value for synced device/host array 580*ff1e7120SSebastian Grimberg if (y_impl->d_array) { 581*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 582*ff1e7120SSebastian Grimberg CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length)); 583*ff1e7120SSebastian Grimberg } 584*ff1e7120SSebastian Grimberg if (y_impl->h_array) { 585*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 586*ff1e7120SSebastian Grimberg CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length)); 587*ff1e7120SSebastian Grimberg } 588*ff1e7120SSebastian Grimberg 589*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 590*ff1e7120SSebastian Grimberg } 591*ff1e7120SSebastian Grimberg 592*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 593*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on the host 594*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 595*ff1e7120SSebastian Grimberg static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) { 596*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i]; 597*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 598*ff1e7120SSebastian Grimberg } 599*ff1e7120SSebastian Grimberg 600*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 601*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on device (impl in .cu file) 602*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 603*ff1e7120SSebastian Grimberg int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length); 604*ff1e7120SSebastian Grimberg 605*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 606*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y 607*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 608*ff1e7120SSebastian Grimberg static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) { 609*ff1e7120SSebastian Grimberg Ceed ceed; 610*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(y, &ceed)); 611*ff1e7120SSebastian Grimberg CeedVector_Cuda *y_impl, *x_impl; 612*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(y, &y_impl)); 613*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(x, &x_impl)); 614*ff1e7120SSebastian Grimberg CeedSize length; 615*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(y, &length)); 616*ff1e7120SSebastian Grimberg 617*ff1e7120SSebastian Grimberg // Set value for synced device/host array 618*ff1e7120SSebastian Grimberg if (y_impl->d_array) { 619*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 620*ff1e7120SSebastian Grimberg CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length)); 621*ff1e7120SSebastian Grimberg } 622*ff1e7120SSebastian Grimberg if (y_impl->h_array) { 623*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 624*ff1e7120SSebastian Grimberg CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length)); 625*ff1e7120SSebastian Grimberg } 626*ff1e7120SSebastian Grimberg 627*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 628*ff1e7120SSebastian Grimberg } 629*ff1e7120SSebastian Grimberg 630*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 631*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on the host 632*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 633*ff1e7120SSebastian Grimberg static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) { 634*ff1e7120SSebastian Grimberg for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i]; 635*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 636*ff1e7120SSebastian Grimberg } 637*ff1e7120SSebastian Grimberg 638*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 639*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on device (impl in .cu file) 640*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 641*ff1e7120SSebastian Grimberg int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length); 642*ff1e7120SSebastian Grimberg 643*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 644*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y 645*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 646*ff1e7120SSebastian Grimberg static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) { 647*ff1e7120SSebastian Grimberg Ceed ceed; 648*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(w, &ceed)); 649*ff1e7120SSebastian Grimberg CeedVector_Cuda *w_impl, *x_impl, *y_impl; 650*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(w, &w_impl)); 651*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(x, &x_impl)); 652*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(y, &y_impl)); 653*ff1e7120SSebastian Grimberg CeedSize length; 654*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetLength(w, &length)); 655*ff1e7120SSebastian Grimberg 656*ff1e7120SSebastian Grimberg // Set value for synced device/host array 657*ff1e7120SSebastian Grimberg if (!w_impl->d_array && !w_impl->h_array) { 658*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetValue(w, 0.0)); 659*ff1e7120SSebastian Grimberg } 660*ff1e7120SSebastian Grimberg if (w_impl->d_array) { 661*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE)); 662*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE)); 663*ff1e7120SSebastian Grimberg CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length)); 664*ff1e7120SSebastian Grimberg } 665*ff1e7120SSebastian Grimberg if (w_impl->h_array) { 666*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST)); 667*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST)); 668*ff1e7120SSebastian Grimberg CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length)); 669*ff1e7120SSebastian Grimberg } 670*ff1e7120SSebastian Grimberg 671*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 672*ff1e7120SSebastian Grimberg } 673*ff1e7120SSebastian Grimberg 674*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 675*ff1e7120SSebastian Grimberg // Destroy the vector 676*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 677*ff1e7120SSebastian Grimberg static int CeedVectorDestroy_Cuda(const CeedVector vec) { 678*ff1e7120SSebastian Grimberg Ceed ceed; 679*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 680*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 681*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetData(vec, &impl)); 682*ff1e7120SSebastian Grimberg 683*ff1e7120SSebastian Grimberg CeedCallCuda(ceed, cudaFree(impl->d_array_owned)); 684*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl->h_array_owned)); 685*ff1e7120SSebastian Grimberg CeedCallBackend(CeedFree(&impl)); 686*ff1e7120SSebastian Grimberg 687*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 688*ff1e7120SSebastian Grimberg } 689*ff1e7120SSebastian Grimberg 690*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 691*ff1e7120SSebastian Grimberg // Create a vector of the specified length (does not allocate memory) 692*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 693*ff1e7120SSebastian Grimberg int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) { 694*ff1e7120SSebastian Grimberg CeedVector_Cuda *impl; 695*ff1e7120SSebastian Grimberg Ceed ceed; 696*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorGetCeed(vec, &ceed)); 697*ff1e7120SSebastian Grimberg 698*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda)); 699*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda)); 700*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda)); 701*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda)); 702*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda)); 703*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda)); 704*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda)); 705*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda)); 706*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda)); 707*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda)); 708*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda)); 709*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda)); 710*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda)); 711*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda)); 712*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda)); 713*ff1e7120SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda)); 714*ff1e7120SSebastian Grimberg 715*ff1e7120SSebastian Grimberg CeedCallBackend(CeedCalloc(1, &impl)); 716*ff1e7120SSebastian Grimberg CeedCallBackend(CeedVectorSetData(vec, impl)); 717*ff1e7120SSebastian Grimberg 718*ff1e7120SSebastian Grimberg return CEED_ERROR_SUCCESS; 719*ff1e7120SSebastian Grimberg } 720*ff1e7120SSebastian Grimberg 721*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------ 722