xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-ref/ceed-cuda-ref-vector.c (revision ff1e7120ad38c28723000cabebb8ede4ff31c408)
1*ff1e7120SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*ff1e7120SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*ff1e7120SSebastian Grimberg //
4*ff1e7120SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5*ff1e7120SSebastian Grimberg //
6*ff1e7120SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7*ff1e7120SSebastian Grimberg 
8*ff1e7120SSebastian Grimberg #include <ceed.h>
9*ff1e7120SSebastian Grimberg #include <ceed/backend.h>
10*ff1e7120SSebastian Grimberg #include <cublas_v2.h>
11*ff1e7120SSebastian Grimberg #include <cuda_runtime.h>
12*ff1e7120SSebastian Grimberg #include <math.h>
13*ff1e7120SSebastian Grimberg #include <stdbool.h>
14*ff1e7120SSebastian Grimberg #include <string.h>
15*ff1e7120SSebastian Grimberg 
16*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h"
17*ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h"
18*ff1e7120SSebastian Grimberg 
19*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
20*ff1e7120SSebastian Grimberg // Check if host/device sync is needed
21*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
22*ff1e7120SSebastian Grimberg static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
23*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
24*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
25*ff1e7120SSebastian Grimberg 
26*ff1e7120SSebastian Grimberg   bool has_valid_array = false;
27*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
28*ff1e7120SSebastian Grimberg   switch (mem_type) {
29*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
30*ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->h_array;
31*ff1e7120SSebastian Grimberg       break;
32*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
33*ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->d_array;
34*ff1e7120SSebastian Grimberg       break;
35*ff1e7120SSebastian Grimberg   }
36*ff1e7120SSebastian Grimberg 
37*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
38*ff1e7120SSebastian Grimberg }
39*ff1e7120SSebastian Grimberg 
40*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
41*ff1e7120SSebastian Grimberg // Sync host to device
42*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
43*ff1e7120SSebastian Grimberg static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
44*ff1e7120SSebastian Grimberg   Ceed ceed;
45*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
46*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
47*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
48*ff1e7120SSebastian Grimberg 
49*ff1e7120SSebastian Grimberg   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
50*ff1e7120SSebastian Grimberg 
51*ff1e7120SSebastian Grimberg   CeedSize length;
52*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
53*ff1e7120SSebastian Grimberg   size_t bytes = length * sizeof(CeedScalar);
54*ff1e7120SSebastian Grimberg 
55*ff1e7120SSebastian Grimberg   if (impl->d_array_borrowed) {
56*ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_borrowed;
57*ff1e7120SSebastian Grimberg   } else if (impl->d_array_owned) {
58*ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
59*ff1e7120SSebastian Grimberg   } else {
60*ff1e7120SSebastian Grimberg     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
61*ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
62*ff1e7120SSebastian Grimberg   }
63*ff1e7120SSebastian Grimberg 
64*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
65*ff1e7120SSebastian Grimberg 
66*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
67*ff1e7120SSebastian Grimberg }
68*ff1e7120SSebastian Grimberg 
69*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
70*ff1e7120SSebastian Grimberg // Sync device to host
71*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
72*ff1e7120SSebastian Grimberg static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
73*ff1e7120SSebastian Grimberg   Ceed ceed;
74*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
76*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
77*ff1e7120SSebastian Grimberg 
78*ff1e7120SSebastian Grimberg   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79*ff1e7120SSebastian Grimberg 
80*ff1e7120SSebastian Grimberg   if (impl->h_array_borrowed) {
81*ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_borrowed;
82*ff1e7120SSebastian Grimberg   } else if (impl->h_array_owned) {
83*ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
84*ff1e7120SSebastian Grimberg   } else {
85*ff1e7120SSebastian Grimberg     CeedSize length;
86*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetLength(vec, &length));
87*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88*ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
89*ff1e7120SSebastian Grimberg   }
90*ff1e7120SSebastian Grimberg 
91*ff1e7120SSebastian Grimberg   CeedSize length;
92*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
93*ff1e7120SSebastian Grimberg   size_t bytes = length * sizeof(CeedScalar);
94*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95*ff1e7120SSebastian Grimberg 
96*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
97*ff1e7120SSebastian Grimberg }
98*ff1e7120SSebastian Grimberg 
99*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
100*ff1e7120SSebastian Grimberg // Sync arrays
101*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
102*ff1e7120SSebastian Grimberg static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
103*ff1e7120SSebastian Grimberg   // Check whether device/host sync is needed
104*ff1e7120SSebastian Grimberg   bool need_sync = false;
105*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106*ff1e7120SSebastian Grimberg   if (!need_sync) return CEED_ERROR_SUCCESS;
107*ff1e7120SSebastian Grimberg 
108*ff1e7120SSebastian Grimberg   switch (mem_type) {
109*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
110*ff1e7120SSebastian Grimberg       return CeedVectorSyncD2H_Cuda(vec);
111*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
112*ff1e7120SSebastian Grimberg       return CeedVectorSyncH2D_Cuda(vec);
113*ff1e7120SSebastian Grimberg   }
114*ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
115*ff1e7120SSebastian Grimberg }
116*ff1e7120SSebastian Grimberg 
117*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
118*ff1e7120SSebastian Grimberg // Set all pointers as invalid
119*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
120*ff1e7120SSebastian Grimberg static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
122*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
123*ff1e7120SSebastian Grimberg 
124*ff1e7120SSebastian Grimberg   impl->h_array = NULL;
125*ff1e7120SSebastian Grimberg   impl->d_array = NULL;
126*ff1e7120SSebastian Grimberg 
127*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
128*ff1e7120SSebastian Grimberg }
129*ff1e7120SSebastian Grimberg 
130*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
131*ff1e7120SSebastian Grimberg // Check if CeedVector has any valid pointer
132*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
133*ff1e7120SSebastian Grimberg static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
134*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
135*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
136*ff1e7120SSebastian Grimberg 
137*ff1e7120SSebastian Grimberg   *has_valid_array = !!impl->h_array || !!impl->d_array;
138*ff1e7120SSebastian Grimberg 
139*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
140*ff1e7120SSebastian Grimberg }
141*ff1e7120SSebastian Grimberg 
142*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
143*ff1e7120SSebastian Grimberg // Check if has array of given type
144*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
145*ff1e7120SSebastian Grimberg static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
146*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
147*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
148*ff1e7120SSebastian Grimberg 
149*ff1e7120SSebastian Grimberg   switch (mem_type) {
150*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
151*ff1e7120SSebastian Grimberg       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
152*ff1e7120SSebastian Grimberg       break;
153*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
154*ff1e7120SSebastian Grimberg       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
155*ff1e7120SSebastian Grimberg       break;
156*ff1e7120SSebastian Grimberg   }
157*ff1e7120SSebastian Grimberg 
158*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
159*ff1e7120SSebastian Grimberg }
160*ff1e7120SSebastian Grimberg 
161*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
162*ff1e7120SSebastian Grimberg // Check if has borrowed array of given type
163*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
164*ff1e7120SSebastian Grimberg static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
165*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
166*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
167*ff1e7120SSebastian Grimberg 
168*ff1e7120SSebastian Grimberg   switch (mem_type) {
169*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
170*ff1e7120SSebastian Grimberg       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
171*ff1e7120SSebastian Grimberg       break;
172*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
173*ff1e7120SSebastian Grimberg       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
174*ff1e7120SSebastian Grimberg       break;
175*ff1e7120SSebastian Grimberg   }
176*ff1e7120SSebastian Grimberg 
177*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
178*ff1e7120SSebastian Grimberg }
179*ff1e7120SSebastian Grimberg 
180*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
181*ff1e7120SSebastian Grimberg // Set array from host
182*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
183*ff1e7120SSebastian Grimberg static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
184*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
185*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
186*ff1e7120SSebastian Grimberg 
187*ff1e7120SSebastian Grimberg   switch (copy_mode) {
188*ff1e7120SSebastian Grimberg     case CEED_COPY_VALUES: {
189*ff1e7120SSebastian Grimberg       CeedSize length;
190*ff1e7120SSebastian Grimberg       if (!impl->h_array_owned) {
191*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedVectorGetLength(vec, &length));
192*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
193*ff1e7120SSebastian Grimberg       }
194*ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
195*ff1e7120SSebastian Grimberg       impl->h_array          = impl->h_array_owned;
196*ff1e7120SSebastian Grimberg       if (array) {
197*ff1e7120SSebastian Grimberg         CeedSize length;
198*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedVectorGetLength(vec, &length));
199*ff1e7120SSebastian Grimberg         size_t bytes = length * sizeof(CeedScalar);
200*ff1e7120SSebastian Grimberg         memcpy(impl->h_array, array, bytes);
201*ff1e7120SSebastian Grimberg       }
202*ff1e7120SSebastian Grimberg     } break;
203*ff1e7120SSebastian Grimberg     case CEED_OWN_POINTER:
204*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedFree(&impl->h_array_owned));
205*ff1e7120SSebastian Grimberg       impl->h_array_owned    = array;
206*ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
207*ff1e7120SSebastian Grimberg       impl->h_array          = array;
208*ff1e7120SSebastian Grimberg       break;
209*ff1e7120SSebastian Grimberg     case CEED_USE_POINTER:
210*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedFree(&impl->h_array_owned));
211*ff1e7120SSebastian Grimberg       impl->h_array_borrowed = array;
212*ff1e7120SSebastian Grimberg       impl->h_array          = array;
213*ff1e7120SSebastian Grimberg       break;
214*ff1e7120SSebastian Grimberg   }
215*ff1e7120SSebastian Grimberg 
216*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
217*ff1e7120SSebastian Grimberg }
218*ff1e7120SSebastian Grimberg 
219*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
220*ff1e7120SSebastian Grimberg // Set array from device
221*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
222*ff1e7120SSebastian Grimberg static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
223*ff1e7120SSebastian Grimberg   Ceed ceed;
224*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
226*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
227*ff1e7120SSebastian Grimberg 
228*ff1e7120SSebastian Grimberg   switch (copy_mode) {
229*ff1e7120SSebastian Grimberg     case CEED_COPY_VALUES: {
230*ff1e7120SSebastian Grimberg       CeedSize length;
231*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedVectorGetLength(vec, &length));
232*ff1e7120SSebastian Grimberg       size_t bytes = length * sizeof(CeedScalar);
233*ff1e7120SSebastian Grimberg       if (!impl->d_array_owned) {
234*ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
235*ff1e7120SSebastian Grimberg         impl->d_array = impl->d_array_owned;
236*ff1e7120SSebastian Grimberg       }
237*ff1e7120SSebastian Grimberg       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238*ff1e7120SSebastian Grimberg     } break;
239*ff1e7120SSebastian Grimberg     case CEED_OWN_POINTER:
240*ff1e7120SSebastian Grimberg       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241*ff1e7120SSebastian Grimberg       impl->d_array_owned    = array;
242*ff1e7120SSebastian Grimberg       impl->d_array_borrowed = NULL;
243*ff1e7120SSebastian Grimberg       impl->d_array          = array;
244*ff1e7120SSebastian Grimberg       break;
245*ff1e7120SSebastian Grimberg     case CEED_USE_POINTER:
246*ff1e7120SSebastian Grimberg       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247*ff1e7120SSebastian Grimberg       impl->d_array_owned    = NULL;
248*ff1e7120SSebastian Grimberg       impl->d_array_borrowed = array;
249*ff1e7120SSebastian Grimberg       impl->d_array          = array;
250*ff1e7120SSebastian Grimberg       break;
251*ff1e7120SSebastian Grimberg   }
252*ff1e7120SSebastian Grimberg 
253*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
254*ff1e7120SSebastian Grimberg }
255*ff1e7120SSebastian Grimberg 
256*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
257*ff1e7120SSebastian Grimberg // Set the array used by a vector,
258*ff1e7120SSebastian Grimberg //   freeing any previously allocated array if applicable
259*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
260*ff1e7120SSebastian Grimberg static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261*ff1e7120SSebastian Grimberg   Ceed ceed;
262*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
264*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
265*ff1e7120SSebastian Grimberg 
266*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267*ff1e7120SSebastian Grimberg   switch (mem_type) {
268*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
269*ff1e7120SSebastian Grimberg       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
271*ff1e7120SSebastian Grimberg       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272*ff1e7120SSebastian Grimberg   }
273*ff1e7120SSebastian Grimberg 
274*ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
275*ff1e7120SSebastian Grimberg }
276*ff1e7120SSebastian Grimberg 
277*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
278*ff1e7120SSebastian Grimberg // Set host array to value
279*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
280*ff1e7120SSebastian Grimberg static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedInt length, CeedScalar val) {
281*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) h_array[i] = val;
282*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
283*ff1e7120SSebastian Grimberg }
284*ff1e7120SSebastian Grimberg 
285*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
286*ff1e7120SSebastian Grimberg // Set device array to value (impl in .cu file)
287*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
288*ff1e7120SSebastian Grimberg int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedInt length, CeedScalar val);
289*ff1e7120SSebastian Grimberg 
290*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
291*ff1e7120SSebastian Grimberg // Set a vector to a value,
292*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
293*ff1e7120SSebastian Grimberg static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
294*ff1e7120SSebastian Grimberg   Ceed ceed;
295*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
297*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
298*ff1e7120SSebastian Grimberg   CeedSize length;
299*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
300*ff1e7120SSebastian Grimberg 
301*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
302*ff1e7120SSebastian Grimberg   if (!impl->d_array && !impl->h_array) {
303*ff1e7120SSebastian Grimberg     if (impl->d_array_borrowed) {
304*ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_borrowed;
305*ff1e7120SSebastian Grimberg     } else if (impl->h_array_borrowed) {
306*ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_borrowed;
307*ff1e7120SSebastian Grimberg     } else if (impl->d_array_owned) {
308*ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_owned;
309*ff1e7120SSebastian Grimberg     } else if (impl->h_array_owned) {
310*ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_owned;
311*ff1e7120SSebastian Grimberg     } else {
312*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
313*ff1e7120SSebastian Grimberg     }
314*ff1e7120SSebastian Grimberg   }
315*ff1e7120SSebastian Grimberg   if (impl->d_array) {
316*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
317*ff1e7120SSebastian Grimberg     impl->h_array = NULL;
318*ff1e7120SSebastian Grimberg   }
319*ff1e7120SSebastian Grimberg   if (impl->h_array) {
320*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
321*ff1e7120SSebastian Grimberg     impl->d_array = NULL;
322*ff1e7120SSebastian Grimberg   }
323*ff1e7120SSebastian Grimberg 
324*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
325*ff1e7120SSebastian Grimberg }
326*ff1e7120SSebastian Grimberg 
327*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
328*ff1e7120SSebastian Grimberg // Vector Take Array
329*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
330*ff1e7120SSebastian Grimberg static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331*ff1e7120SSebastian Grimberg   Ceed ceed;
332*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
334*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
335*ff1e7120SSebastian Grimberg 
336*ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
337*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338*ff1e7120SSebastian Grimberg 
339*ff1e7120SSebastian Grimberg   // Update pointer
340*ff1e7120SSebastian Grimberg   switch (mem_type) {
341*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
342*ff1e7120SSebastian Grimberg       (*array)               = impl->h_array_borrowed;
343*ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
344*ff1e7120SSebastian Grimberg       impl->h_array          = NULL;
345*ff1e7120SSebastian Grimberg       break;
346*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
347*ff1e7120SSebastian Grimberg       (*array)               = impl->d_array_borrowed;
348*ff1e7120SSebastian Grimberg       impl->d_array_borrowed = NULL;
349*ff1e7120SSebastian Grimberg       impl->d_array          = NULL;
350*ff1e7120SSebastian Grimberg       break;
351*ff1e7120SSebastian Grimberg   }
352*ff1e7120SSebastian Grimberg 
353*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
354*ff1e7120SSebastian Grimberg }
355*ff1e7120SSebastian Grimberg 
356*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
357*ff1e7120SSebastian Grimberg // Core logic for array synchronization for GetArray.
358*ff1e7120SSebastian Grimberg //   If a different memory type is most up to date, this will perform a copy
359*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
360*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361*ff1e7120SSebastian Grimberg   Ceed ceed;
362*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
364*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
365*ff1e7120SSebastian Grimberg 
366*ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
367*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368*ff1e7120SSebastian Grimberg 
369*ff1e7120SSebastian Grimberg   // Update pointer
370*ff1e7120SSebastian Grimberg   switch (mem_type) {
371*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
372*ff1e7120SSebastian Grimberg       *array = impl->h_array;
373*ff1e7120SSebastian Grimberg       break;
374*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
375*ff1e7120SSebastian Grimberg       *array = impl->d_array;
376*ff1e7120SSebastian Grimberg       break;
377*ff1e7120SSebastian Grimberg   }
378*ff1e7120SSebastian Grimberg 
379*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
380*ff1e7120SSebastian Grimberg }
381*ff1e7120SSebastian Grimberg 
382*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
383*ff1e7120SSebastian Grimberg // Get read-only access to a vector via the specified mem_type
384*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
385*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386*ff1e7120SSebastian Grimberg   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
387*ff1e7120SSebastian Grimberg }
388*ff1e7120SSebastian Grimberg 
389*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
390*ff1e7120SSebastian Grimberg // Get read/write access to a vector via the specified mem_type
391*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
392*ff1e7120SSebastian Grimberg static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
394*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
395*ff1e7120SSebastian Grimberg 
396*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
397*ff1e7120SSebastian Grimberg 
398*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
399*ff1e7120SSebastian Grimberg   switch (mem_type) {
400*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
401*ff1e7120SSebastian Grimberg       impl->h_array = *array;
402*ff1e7120SSebastian Grimberg       break;
403*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
404*ff1e7120SSebastian Grimberg       impl->d_array = *array;
405*ff1e7120SSebastian Grimberg       break;
406*ff1e7120SSebastian Grimberg   }
407*ff1e7120SSebastian Grimberg 
408*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
409*ff1e7120SSebastian Grimberg }
410*ff1e7120SSebastian Grimberg 
411*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
412*ff1e7120SSebastian Grimberg // Get write access to a vector via the specified mem_type
413*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
414*ff1e7120SSebastian Grimberg static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
416*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
417*ff1e7120SSebastian Grimberg 
418*ff1e7120SSebastian Grimberg   bool has_array_of_type = true;
419*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
420*ff1e7120SSebastian Grimberg   if (!has_array_of_type) {
421*ff1e7120SSebastian Grimberg     // Allocate if array is not yet allocated
422*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423*ff1e7120SSebastian Grimberg   } else {
424*ff1e7120SSebastian Grimberg     // Select dirty array
425*ff1e7120SSebastian Grimberg     switch (mem_type) {
426*ff1e7120SSebastian Grimberg       case CEED_MEM_HOST:
427*ff1e7120SSebastian Grimberg         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428*ff1e7120SSebastian Grimberg         else impl->h_array = impl->h_array_owned;
429*ff1e7120SSebastian Grimberg         break;
430*ff1e7120SSebastian Grimberg       case CEED_MEM_DEVICE:
431*ff1e7120SSebastian Grimberg         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432*ff1e7120SSebastian Grimberg         else impl->d_array = impl->d_array_owned;
433*ff1e7120SSebastian Grimberg     }
434*ff1e7120SSebastian Grimberg   }
435*ff1e7120SSebastian Grimberg 
436*ff1e7120SSebastian Grimberg   return CeedVectorGetArray_Cuda(vec, mem_type, array);
437*ff1e7120SSebastian Grimberg }
438*ff1e7120SSebastian Grimberg 
439*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
440*ff1e7120SSebastian Grimberg // Get the norm of a CeedVector
441*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
442*ff1e7120SSebastian Grimberg static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443*ff1e7120SSebastian Grimberg   Ceed ceed;
444*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
446*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
447*ff1e7120SSebastian Grimberg   CeedSize length;
448*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
449*ff1e7120SSebastian Grimberg   cublasHandle_t handle;
450*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
451*ff1e7120SSebastian Grimberg 
452*ff1e7120SSebastian Grimberg   // Compute norm
453*ff1e7120SSebastian Grimberg   const CeedScalar *d_array;
454*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
455*ff1e7120SSebastian Grimberg   switch (type) {
456*ff1e7120SSebastian Grimberg     case CEED_NORM_1: {
457*ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
458*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasSasum(handle, length, (float *)d_array, 1, (float *)norm));
459*ff1e7120SSebastian Grimberg       } else {
460*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasDasum(handle, length, (double *)d_array, 1, (double *)norm));
461*ff1e7120SSebastian Grimberg       }
462*ff1e7120SSebastian Grimberg       break;
463*ff1e7120SSebastian Grimberg     }
464*ff1e7120SSebastian Grimberg     case CEED_NORM_2: {
465*ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
466*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasSnrm2(handle, length, (float *)d_array, 1, (float *)norm));
467*ff1e7120SSebastian Grimberg       } else {
468*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasDnrm2(handle, length, (double *)d_array, 1, (double *)norm));
469*ff1e7120SSebastian Grimberg       }
470*ff1e7120SSebastian Grimberg       break;
471*ff1e7120SSebastian Grimberg     }
472*ff1e7120SSebastian Grimberg     case CEED_NORM_MAX: {
473*ff1e7120SSebastian Grimberg       CeedInt indx;
474*ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
475*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasIsamax(handle, length, (float *)d_array, 1, &indx));
476*ff1e7120SSebastian Grimberg       } else {
477*ff1e7120SSebastian Grimberg         CeedCallCublas(ceed, cublasIdamax(handle, length, (double *)d_array, 1, &indx));
478*ff1e7120SSebastian Grimberg       }
479*ff1e7120SSebastian Grimberg       CeedScalar normNoAbs;
480*ff1e7120SSebastian Grimberg       CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
481*ff1e7120SSebastian Grimberg       *norm = fabs(normNoAbs);
482*ff1e7120SSebastian Grimberg       break;
483*ff1e7120SSebastian Grimberg     }
484*ff1e7120SSebastian Grimberg   }
485*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
486*ff1e7120SSebastian Grimberg 
487*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
488*ff1e7120SSebastian Grimberg }
489*ff1e7120SSebastian Grimberg 
490*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
491*ff1e7120SSebastian Grimberg // Take reciprocal of a vector on host
492*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
493*ff1e7120SSebastian Grimberg static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedInt length) {
494*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) {
495*ff1e7120SSebastian Grimberg     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
496*ff1e7120SSebastian Grimberg   }
497*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
498*ff1e7120SSebastian Grimberg }
499*ff1e7120SSebastian Grimberg 
500*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
501*ff1e7120SSebastian Grimberg // Take reciprocal of a vector on device (impl in .cu file)
502*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Prototype only; the definition is not in this file (see the banner: implemented in a .cu file).
// CeedVectorReciprocal_Cuda calls it with the vector's device array and length.
int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedInt length);
504*ff1e7120SSebastian Grimberg 
505*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
506*ff1e7120SSebastian Grimberg // Take reciprocal of a vector
507*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
508*ff1e7120SSebastian Grimberg static int CeedVectorReciprocal_Cuda(CeedVector vec) {
509*ff1e7120SSebastian Grimberg   Ceed ceed;
510*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
511*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
512*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
513*ff1e7120SSebastian Grimberg   CeedSize length;
514*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
515*ff1e7120SSebastian Grimberg 
516*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
517*ff1e7120SSebastian Grimberg   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
518*ff1e7120SSebastian Grimberg   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
519*ff1e7120SSebastian Grimberg 
520*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
521*ff1e7120SSebastian Grimberg }
522*ff1e7120SSebastian Grimberg 
523*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
524*ff1e7120SSebastian Grimberg // Compute x = alpha x on the host
525*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
526*ff1e7120SSebastian Grimberg static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
527*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) x_array[i] *= alpha;
528*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
529*ff1e7120SSebastian Grimberg }
530*ff1e7120SSebastian Grimberg 
531*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
532*ff1e7120SSebastian Grimberg // Compute x = alpha x on device (impl in .cu file)
533*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Prototype only; the definition is not in this file (see the banner: implemented in a .cu file).
// CeedVectorScale_Cuda calls it with the vector's device array, the scale factor, and the length.
int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedInt length);
535*ff1e7120SSebastian Grimberg 
536*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
537*ff1e7120SSebastian Grimberg // Compute x = alpha x
538*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
539*ff1e7120SSebastian Grimberg static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
540*ff1e7120SSebastian Grimberg   Ceed ceed;
541*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
542*ff1e7120SSebastian Grimberg   CeedVector_Cuda *x_impl;
543*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
544*ff1e7120SSebastian Grimberg   CeedSize length;
545*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(x, &length));
546*ff1e7120SSebastian Grimberg 
547*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
548*ff1e7120SSebastian Grimberg   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
549*ff1e7120SSebastian Grimberg   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
550*ff1e7120SSebastian Grimberg 
551*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
552*ff1e7120SSebastian Grimberg }
553*ff1e7120SSebastian Grimberg 
554*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
555*ff1e7120SSebastian Grimberg // Compute y = alpha x + y on the host
556*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
557*ff1e7120SSebastian Grimberg static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
558*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
559*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
560*ff1e7120SSebastian Grimberg }
561*ff1e7120SSebastian Grimberg 
562*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
563*ff1e7120SSebastian Grimberg // Compute y = alpha x + y on device (impl in .cu file)
564*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Prototype only; the definition is not in this file (see the banner: implemented in a .cu file).
// CeedVectorAXPY_Cuda calls it with the device arrays of y and x.
int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length);
566*ff1e7120SSebastian Grimberg 
567*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
568*ff1e7120SSebastian Grimberg // Compute y = alpha x + y
569*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
570*ff1e7120SSebastian Grimberg static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
571*ff1e7120SSebastian Grimberg   Ceed ceed;
572*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
573*ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
574*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
575*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
576*ff1e7120SSebastian Grimberg   CeedSize length;
577*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
578*ff1e7120SSebastian Grimberg 
579*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
580*ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
581*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
582*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
583*ff1e7120SSebastian Grimberg   }
584*ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
585*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
586*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
587*ff1e7120SSebastian Grimberg   }
588*ff1e7120SSebastian Grimberg 
589*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
590*ff1e7120SSebastian Grimberg }
591*ff1e7120SSebastian Grimberg 
592*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
593*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on the host
594*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
595*ff1e7120SSebastian Grimberg static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) {
596*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
597*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
598*ff1e7120SSebastian Grimberg }
599*ff1e7120SSebastian Grimberg 
600*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
601*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on device (impl in .cu file)
602*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Prototype only; the definition is not in this file (see the banner: implemented in a .cu file).
// CeedVectorAXPBY_Cuda calls it with the device arrays of y and x.
int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length);
604*ff1e7120SSebastian Grimberg 
605*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
606*ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y
607*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
608*ff1e7120SSebastian Grimberg static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
609*ff1e7120SSebastian Grimberg   Ceed ceed;
610*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
611*ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
612*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
613*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
614*ff1e7120SSebastian Grimberg   CeedSize length;
615*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
616*ff1e7120SSebastian Grimberg 
617*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
618*ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
619*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
620*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
621*ff1e7120SSebastian Grimberg   }
622*ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
623*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
624*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
625*ff1e7120SSebastian Grimberg   }
626*ff1e7120SSebastian Grimberg 
627*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
628*ff1e7120SSebastian Grimberg }
629*ff1e7120SSebastian Grimberg 
630*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
631*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on the host
632*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
633*ff1e7120SSebastian Grimberg static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
634*ff1e7120SSebastian Grimberg   for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
635*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
636*ff1e7120SSebastian Grimberg }
637*ff1e7120SSebastian Grimberg 
638*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
639*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
640*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Prototype only; the definition is not in this file (see the banner: implemented in a .cu file).
// CeedVectorPointwiseMult_Cuda calls it with the device arrays of w, x, and y.
int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length);
642*ff1e7120SSebastian Grimberg 
643*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
644*ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y
645*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
646*ff1e7120SSebastian Grimberg static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
647*ff1e7120SSebastian Grimberg   Ceed ceed;
648*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
649*ff1e7120SSebastian Grimberg   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
650*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(w, &w_impl));
651*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
652*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
653*ff1e7120SSebastian Grimberg   CeedSize length;
654*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(w, &length));
655*ff1e7120SSebastian Grimberg 
656*ff1e7120SSebastian Grimberg   // Set value for synced device/host array
657*ff1e7120SSebastian Grimberg   if (!w_impl->d_array && !w_impl->h_array) {
658*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetValue(w, 0.0));
659*ff1e7120SSebastian Grimberg   }
660*ff1e7120SSebastian Grimberg   if (w_impl->d_array) {
661*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
662*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
663*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
664*ff1e7120SSebastian Grimberg   }
665*ff1e7120SSebastian Grimberg   if (w_impl->h_array) {
666*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
667*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
668*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
669*ff1e7120SSebastian Grimberg   }
670*ff1e7120SSebastian Grimberg 
671*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
672*ff1e7120SSebastian Grimberg }
673*ff1e7120SSebastian Grimberg 
674*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
675*ff1e7120SSebastian Grimberg // Destroy the vector
676*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
677*ff1e7120SSebastian Grimberg static int CeedVectorDestroy_Cuda(const CeedVector vec) {
678*ff1e7120SSebastian Grimberg   Ceed ceed;
679*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
680*ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
681*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
682*ff1e7120SSebastian Grimberg 
683*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
684*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl->h_array_owned));
685*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl));
686*ff1e7120SSebastian Grimberg 
687*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
688*ff1e7120SSebastian Grimberg }
689*ff1e7120SSebastian Grimberg 
690*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
691*ff1e7120SSebastian Grimberg // Create a vector of the specified length (does not allocate memory)
692*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
// Set up the CUDA backend side of a vector: register this file's implementations of the
// "Vector" operations on vec, then allocate one CeedVector_Cuda and attach it as the
// vector's backend data.
//   n   - requested length; not used in this function (no array is allocated here)
//   vec - vector being set up
// Returns CEED_ERROR_SUCCESS on the normal path.
int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
  CeedVector_Cuda *impl;
  Ceed             ceed;
  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));

  // Register the backend functions by name.
  // NOTE(review): the (int (*)()) casts appear on SetValue, Scale, AXPY and AXPBY; the three
  // visible here take a CeedScalar argument - presumably that is why they need the cast, confirm
  // against the CeedSetBackendFunction declaration.
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));

  // Allocate the backend data struct (one element via CeedCalloc) and hand it to the vector
  CeedCallBackend(CeedCalloc(1, &impl));
  CeedCallBackend(CeedVectorSetData(vec, impl));

  return CEED_ERROR_SUCCESS;
}
720*ff1e7120SSebastian Grimberg 
721*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
722